Hi Mainak,
That’s a good idea. There was actually a typo in my previous code, which meant it did not interpolate, so the plot actually showed the comparison between the missing channels and the “ground truth”.
When comparing the interpolated versus the missing-channel source estimates, the correlation was much higher.
I did as you suggested and tested it on 50 sample subjects from the eegbci dataset,
using 3 neighboring bad channels, a configuration I observed relatively often in my own dataset.
Results:
Conclusions:
- Average Pearson's r² between raw STC values for Missing vs Interpolated was 0.996
- Average Pearson's r² between raw STC values for Original vs Missing was 0.959
- Average Pearson's r² between raw STC values for Original vs Interpolated was 0.955
Thus, direct source localization with the missing channels (without interpolation) was on average 0.4% better correlated with the original data in the raw source time-series values (0.959 / 0.955 ≈ 1.004).
You were right that the difference is small, but, as Britta also pointed out, the problem is smaller for minimum norm estimates, and other source localization methods might reveal a bigger effect.
The final code I used (beware: it is not optimized and is very memory intensive):
(I added parallel processing with 6 cores and thus needed 90 GB of RAM to run it.)
### Source localization with or without interpolated channels
# Adapted from "EEG forward operator with template MRI" tutorial
# And "Source localization with MNE/dSPM/sLORETA/eLORETA" tutorial
# Libraries
import os.path as op
import numpy as np
import concurrent.futures # for multiprocessing
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import mne
from mne.datasets import eegbci
from mne.datasets import fetch_fsaverage
from mne.minimum_norm import make_inverse_operator, apply_inverse_raw
# Fetch the fsaverage template subject and derive where its files live
fs_dir = fetch_fsaverage(verbose=True)
subjects_dir = op.dirname(fs_dir)

# Template subject name and the files under <fs_dir>/bem used for modelling
subject = 'fsaverage'
trans = 'fsaverage'  # MNE has a built-in fsaverage transformation
src, bem = (
    op.join(fs_dir, 'bem', fname)
    for fname in ('fsaverage-ico-5-src.fif',
                  'fsaverage-5120-5120-5120-bem-sol.fif')
)
def load_sample(subject_no):
    """Load run 4 of one EEGBCI subject, ready for source modelling.

    Channel names are rewritten so they match the ``standard_1005``
    montage, the montage is attached, and an average-reference projector
    is added with ``projection=True``.

    Parameters
    ----------
    subject_no : int
        Subject number forwarded to ``eegbci.load_data``.

    Returns
    -------
    raw
        The preloaded recording returned by ``mne.io.read_raw_edf``.
    """
    edf_path, = eegbci.load_data(subject=subject_no, runs=[4])
    raw = mne.io.read_raw_edf(edf_path, preload=True)

    def _to_montage_name(name):
        # Strip trailing dots and upper-case, then restore the lower-case
        # "z" and the "Fp" prefix spelled that way in the montage
        cleaned = name.rstrip('.').upper()
        return cleaned.replace('Z', 'z').replace('FP', 'Fp')

    raw.rename_channels({name: _to_montage_name(name) for name in raw.ch_names})

    # Attach standard 10-05 electrode positions to the renamed channels
    raw.set_montage(mne.channels.make_standard_montage('standard_1005'))
    raw.set_eeg_reference("average", projection=True)  # needed for inverse modeling
    return raw
# Compute STC
def compute_stc(raw, method="MNE", snr=3.):
    """Compute a source time course for ``raw`` on the fsaverage template.

    A forward model is built from ``raw.info`` (so only the channels present
    in ``raw`` enter the leadfield), a noise covariance is estimated from
    ``raw`` itself, and the inverse solution is applied to the whole
    recording.

    Parameters
    ----------
    raw
        Preloaded MNE Raw object with a montage and EEG reference set.
    method : str
        Inverse method passed to ``apply_inverse_raw`` (e.g. "MNE", "dSPM",
        "sLORETA", "eLORETA"). Defaults to "MNE".
    snr : float
        Assumed signal-to-noise ratio; the regularisation is
        ``lambda2 = 1 / snr ** 2``. Defaults to 3.

    Returns
    -------
    stc
        Source estimate returned by ``apply_inverse_raw``.

    Raises
    ------
    ValueError
        If ``snr`` is not strictly positive.
    """
    if snr <= 0:
        # lambda2 would be infinite (snr == 0) or meaningless (snr < 0)
        raise ValueError(f"snr must be > 0, got {snr!r}")

    # Forward operator using the module-level fsaverage trans/src/bem
    fwd = mne.make_forward_solution(raw.info, trans=trans, src=src,
                                    bem=bem, eeg=True, mindist=5.0, n_jobs=1)

    # Noise covariance estimated from the same recording
    noise_cov = mne.compute_raw_covariance(raw, rank=None, verbose=True)

    inverse_operator = make_inverse_operator(
        raw.info, fwd, noise_cov, loose=0.2, depth=0.8)

    lambda2 = 1. / snr ** 2
    return apply_inverse_raw(raw, inverse_operator, lambda2,
                             method=method, pick_ori=None, verbose=True)
# Channels to treat as bad: three neighbouring electrodes.
# A cluster of adjacent bad channels is harder to interpolate and smooth
# over than the same number of bad channels spread away from each other.
bad_channels = ["FC6", "C6", "CP6"]
def _subsample_stc(stc, vertex_step=100, time_step=10):
    """Return every ``vertex_step``-th source and ``time_step``-th sample of ``stc``.

    Slicing ``stc.data`` (sources x times) directly avoids materialising the
    full ``stc.to_data_frame()`` table (one column per vertex for every time
    sample) only to keep about 0.1% of it.
    """
    return stc.data[::vertex_step, ::time_step]


def _pearson_r2(a, b):
    """Return the squared Pearson correlation of the flattened arrays ``a`` and ``b``."""
    return np.square(np.corrcoef(np.ravel(a), np.ravel(b))[0, 1])


def test_drop_versus_interpol(i):
    """Compare source estimates with bad channels present, dropped or interpolated.

    For EEGBCI subject ``i`` three source time courses are computed with
    ``compute_stc``: from all channels (ground truth), with ``bad_channels``
    dropped, and with ``bad_channels`` interpolated. A subsample of each is
    compared pairwise with the squared Pearson correlation.

    Parameters
    ----------
    i : int
        EEGBCI subject number.

    Returns
    -------
    i : int
        The subject number, echoed back so parallel results can be matched.
    res : list of float
        ``[r2(original, missing), r2(original, interpolated),
        r2(missing, interpolated)]``.
    """
    # Load the recording once; both variants start from identical copies
    # instead of re-reading the same file from disk.
    raw = load_sample(i)
    raw_missing = raw.copy()
    raw_interp = raw.copy()

    # Ground truth: all channels present
    stc = compute_stc(raw)
    data = _subsample_stc(stc)
    del stc  # free some space

    # Variant 1: mark the bad channels and drop them so they are not used
    # for inverse modelling. A copy of the list is assigned so Info never
    # aliases the module-level bad_channels.
    raw_missing.info["bads"] = list(bad_channels)
    raw_missing.drop_channels(raw_missing.info["bads"])
    stc_missing = compute_stc(raw_missing)
    data_missing = _subsample_stc(stc_missing)
    del stc_missing  # free some space

    # Variant 2: interpolate the bad channels instead of dropping them
    raw_interp.info["bads"] = list(bad_channels)
    raw_interp.interpolate_bads(reset_bads=True)
    stc_interp = compute_stc(raw_interp)
    data_interp = _subsample_stc(stc_interp)
    del stc_interp  # free some space

    res = [
        _pearson_r2(data, data_missing),         # Original vs Missing
        _pearson_r2(data, data_interp),          # Original vs Interpolated
        _pearson_r2(data_missing, data_interp),  # Missing vs Interpolated
    ]
    return i, res
if __name__ == "__main__":
    # The guard stops worker processes from re-running this section when they
    # import the module. Without it, the "spawn" start method (default on
    # Windows and macOS) fails when the workers try to start a new pool.
    n_subjects = 50
    subject_ids = range(1, n_subjects + 1)
    r2_cols = ["Original_vs_Missing", "Original_vs_Interpolated",
               "Missing_vs_Interpolated"]

    # Executor.map yields results in submission order, so rows line up
    # with subject_ids.
    with concurrent.futures.ProcessPoolExecutor(max_workers=6) as executor:
        r2_tests = [res for _, res in
                    executor.map(test_drop_versus_interpol, subject_ids)]

    r2_df = pd.DataFrame(r2_tests, columns=r2_cols)
    # Derived from n_subjects so the column always matches the row count
    r2_df["Subject_ID"] = subject_ids
    r2_df.to_pickle("Miss_vs_interpol_prior_source_test.pkl")
    r2_dfm = pd.melt(r2_df, id_vars="Subject_ID",
                     var_name="Comparison", value_name="Pearson_r2")

    # Mean + SD with all points plot for group
    g = sns.FacetGrid(data=r2_dfm,
                      margin_titles=True, height=4, aspect=0.7)
    g = g.map(sns.stripplot, "Comparison", "Pearson_r2", "Comparison",
              palette=sns.color_palette(),
              dodge=0.45, jitter=0.25, alpha=0.2)
    g.add_legend()
    # NOTE: seaborn >= 0.12 deprecates ci="sd" in favour of errorbar="sd";
    # ci is kept here for compatibility with older seaborn versions.
    g = g.map(sns.pointplot, "Comparison", "Pearson_r2", "Comparison",
              dodge=0.5, capsize=0.18, ci="sd", linestyles=["", "", ""],
              markers=["o", "o", "o"], color="black")
    plt.subplots_adjust(top=0.9, right=0.95, left=0.1)
    g.fig.suptitle("Comparison of Missing channels versus Interpolated channels (Mean with SD)", fontsize=15)
    g.set_axis_labels(x_var="Comparison", y_var="Pearson r2 for raw STC values")

    print(r2_df[r2_cols].describe())
    ratio = (r2_df["Original_vs_Missing"].mean()
             / r2_df["Original_vs_Interpolated"].mean())
    print("Mean r2 ratio (Original_vs_Missing / Original_vs_Interpolated):",
          ratio)

    # Display the figure when run as a plain script
    plt.show()