Assume we have the following epochs.
Epoch Generation
import os.path
import autoreject
import mne
import numpy as np
import numpy.random
import pandas as pd
from mne_features.feature_extraction import extract_features
# config.DISABLE_JIT = True  # NOTE(review): leftover toggle; `config` is never imported here.
numpy.random.seed(0)  # seed the global RNG so the random labels below are reproducible

# NOTE(review): user-specific absolute path; adjust to your MNE sample-data location.
dpath = os.path.join('/home/rpb/mne_data/MNE-sample-data/MEG/sample/',
                     'sample_audvis_filt-0-40_raw.fif')
raw = mne.io.read_raw_fif(dpath, preload=False)
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False)  # EEG channels only

epoch_size = 6  # epoch length in seconds
event_id = 1  # initial code for every fixed-length event (re-coded below)
# mne.Epochs includes the samples at both tmin and tmax, so end one sample
# before the next event; otherwise consecutive epochs share a boundary sample.
tmin, tmax = 0, epoch_size - 1.0 / raw.info['sfreq']
events = mne.make_fixed_length_events(raw, event_id, start=0, stop=None,
                                      duration=epoch_size)

# Re-code each event with its own row index so every epoch is uniquely
# identifiable by its event code. The middle column of an MNE events array
# holds the previous trigger value, not a duration.
df = pd.DataFrame(events, columns=['time_point', 'prev_value', 'event'])
df['event'] = np.arange(len(df))
event_ids = {str(idx): idx for idx in df['event'].tolist()}
epochs = mne.Epochs(raw, events=df.to_numpy(), event_id=event_ids,
                    baseline=None, verbose=True, tmin=tmin, tmax=tmax,
                    preload=True, picks=picks)
Epoch labels
Assume each epoch has the following (randomly generated) class label:
# One random class label per epoch, drawn from the globally seeded NumPy RNG;
# randint's upper bound is exclusive, so labels fall in 1..9.
df = pd.DataFrame({'class': np.random.randint(low=1, high=10, size=len(epochs))})
Clean the data using autoreject
# Fit AutoReject on the first 10 epochs only (to save time), then apply it to
# every epoch; `reject_log` records the per-epoch decisions.
ar = autoreject.AutoReject(
    n_interpolate=[1, 2, 3, 4],
    cv=5,
    random_state=11,
    n_jobs=-1,
    verbose=False,
)
ar.fit(epochs[:10])
epochs_ar, reject_log = ar.transform(epochs, return_log=True)
Extract the univariate features with mne-features
# Univariate features of the cleaned epochs. Sampling rate and channel names
# are taken from epochs_ar itself so they always describe the data passed in.
features_drops = extract_features(epochs_ar.get_data(), epochs_ar.info['sfreq'],
                                  selected_funcs=['skewness'], return_as_df=True,
                                  ch_names=epochs_ar.ch_names, n_jobs=1)
For simplicity, we keep only the first feature column.
features_drops = features_drops.iloc[:, :1]  # keep only the first feature column
Map the epoch labels onto the mne_features output, keeping only the labels of the epochs that survived AutoReject.
# Attach labels to the surviving epochs. Rows of df align with `epochs`; rows
# of features_drops align with `epochs_ar`. `selection` holds each epoch's
# index into the original events array, so this mask keeps exactly the df rows
# whose epoch is still present in epochs_ar. (drop_log has one entry per
# original event, so it misaligns with df if mne.Epochs dropped any event at
# construction time.)
features_drops['dlabel'] = df.loc[np.isin(epochs.selection, epochs_ar.selection),
                                  'class'].to_numpy()
Sanity check
We can visually check whether the step above is correct by comparing it with the features of the original (uncleaned) epochs below.
# The same feature extraction on the original (uncleaned) epochs, for a
# row-by-row comparison with features_drops.
features_ = extract_features(epochs.get_data(), epochs.info['sfreq'],
                             selected_funcs=['skewness'], return_as_df=True,
                             ch_names=epochs.ch_names, n_jobs=1)
## For simplicity, keep only the first feature column
features_ = features_.iloc[:, 0:1]
# Select the 1-D 'class' column; df.values would be an (n_epochs, 1) 2-D array.
features_['dlabel'] = df['class'].to_numpy()
## Reference for cross check: row positions (in features_ / df) of the epochs
## that are missing from epochs_ar, i.e. the ones rejected by AutoReject.
## Positions come from `selection`, not drop_log, because drop_log is indexed
## by original event and would include events mne.Epochs dropped at build time.
drop_idx = np.flatnonzero(~np.isin(epochs.selection, epochs_ar.selection)).tolist()
REMARK:
I purposely put this here for my future reference. However, should others have better suggestions, please share them below.