Using autoreject to compute local thresholds impacts the epochs dropped using a global threshold

Hello MNE community,

I am trying to use autoreject to drop bad epochs with the function get_rejection_threshold. I recently discovered the channel-level (local) thresholds function compute_thresholds (Plot channel-level thresholds - autoreject 0.2.1 documentation) and I cannot make sense of the behavior below:

Code 1:

epochs = mne.Epochs(raw, events, event_id=event_dict['audio'], picks=['eeg', 'eog', 'ecg'],
                    tmin=tmin, tmax=tmax, reject=None,
                    proj=True, baseline=(None, 0), preload=True)
reject = get_rejection_threshold(epochs, ch_types='eeg', decim=2)
epochs.drop_bad(reject=reject)

OUT: <Epochs | 69 events (all good), -0.2 - 0.798 sec, baseline [-0.2, 0] sec, ~17.5 MB, data loaded, '4': 69>

Code 2:

epochs = mne.Epochs(raw, events, event_id=event_dict['audio'], picks=['eeg', 'eog', 'ecg'],
                    tmin=tmin, tmax=tmax, reject=None,
                    proj=True, baseline=(None, 0), preload=True)
reject = get_rejection_threshold(epochs, ch_types='eeg', decim=2)

picks = mne.pick_types(epochs.info, meg=False, eeg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
local_rejection_threshes = compute_thresholds(epochs, picks=picks, method='random_search',
                                               augment=True, verbose='progressbar')

epochs.drop_bad(reject=reject)

OUT: <Epochs | 26 events (all good), -0.2 - 0.798 sec, baseline [-0.2, 0] sec, ~6.6 MB, data loaded, '4': 26>

Why does the computation of the local thresholds (heavily) impact the number of epochs retained, when the same global rejection threshold is used in both cases?
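For context, my understanding of what drop_bad(reject=...) does is a simple peak-to-peak check: an epoch is dropped as soon as any picked channel exceeds the threshold. A plain numpy sketch of that criterion (ptp_keep_mask is just an illustrative helper of mine, not MNE API):

```python
import numpy as np

def ptp_keep_mask(data, threshold):
    """Boolean mask of epochs to keep: an epoch survives only if the
    peak-to-peak amplitude of every channel stays below the threshold.
    data has shape (n_epochs, n_channels, n_times)."""
    ptp = data.max(axis=2) - data.min(axis=2)   # (n_epochs, n_channels)
    return (ptp < threshold).all(axis=1)

# Deterministic toy data: 10 epochs, 4 channels, a 10 Hz sine of
# ~100 uV peak-to-peak, with one artifactual channel in epoch 3.
t = np.linspace(0, 1, 100)
clean = 50e-6 * np.sin(2 * np.pi * 10 * t)
data = np.tile(clean, (10, 4, 1))
data[3, 2] *= 5                      # ~500 uV peak-to-peak on channel 2

keep = ptp_keep_mask(data, threshold=150e-6)
print(keep.sum())                    # 9: epoch 3 is dropped
```

With a criterion this simple, the retained-epoch count should depend only on the threshold and on which channels are checked, which is why the result above surprises me.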

Best,
Mathieu Scheltienne

Hi Mathieu,

Can you supply a fully reproducible script with the sample dataset so we can investigate further?

You should also supply the random_state parameter so that the outputs are not stochastic.

Mainak

@mainakjas You can find below the code to reproduce my output. The .fif file to use is available for 1 week at this location. I will reupload it if requested.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import mne
import numpy as np

from autoreject import get_rejection_threshold, compute_thresholds

# Load
fname = r'path to .fif'
raw = mne.io.read_raw_fif(fname, preload=True)
raw.set_channel_types(mapping={'ECG':'ecg', 'EOG':'eog'}) # Mapping for ECG and EOG channels.
raw.set_montage('standard_1020')
raw.info['bads'] = ['O1', 'Oz', 'O2', 'Fp2', 'AF7', 'AF8', 'PO8', 'Fp1']

# Fix the TRIGGER channel
def brainvisionMarkersCh2StimCh(timearr):
    timearr[np.where(timearr==-1)] = 0
    return timearr
raw.apply_function(brainvisionMarkersCh2StimCh, picks=['stim'], dtype=None,
                   n_jobs=1, channel_wise=True, verbose=None)

# Events
event_dict = {'rest': 1, 'blink': 2, 'audio': 4}
event_duration_mapping = {1: 1, 2: 60, 4: 0.8}
events = mne.find_events(raw, stim_channel='TRIGGER')
    
# Filters - CAR
raw.set_eeg_reference(ref_channels='average', ch_type='eeg', projection=False) # bads are excluded
mne.io.Raw.filter(raw, l_freq=1., h_freq=40., picks=['ecg', 'eog', 'eeg'], method='iir', 
                  iir_params=dict(order=4, ftype='butter', output='sos'))
# Powerline noise: 50Hz for EU
raw.notch_filter(np.arange(50, 101, 50), picks=['ecg', 'eog'], filter_length='auto', phase='zero')

# EOG SSP Projection
# crop raw around the blink paradigm, or around the audio stimulus paradigm if blink is not present
if event_dict['blink'] in (ev[2] for ev in events):
    event_start = (events[np.where(events[:, 2] == event_dict['blink'])][0][0]) / raw.info['sfreq']
    event_stop = event_start + event_duration_mapping[event_dict['blink']]
else:
    event_start = (events[np.where(events[:, 2] == event_dict['rest'])][0][0]) / raw.info['sfreq']
    event_stop = (events[np.where(events[:, 2] == event_dict['audio'])][-1][0]) / raw.info['sfreq'] + event_duration_mapping[event_dict['audio']]
eog_raw = raw.copy().crop(tmin=event_start, tmax=event_stop, include_tmax=True)
            
# bads are automatically excluded when set in info['bads']
eog_projs, _ = mne.preprocessing.compute_proj_eog(eog_raw, n_grad=0, n_mag=0, n_eeg=1, reject=None,
                                                  no_proj=True, n_jobs=1)

raw.add_proj(eog_projs)
raw.apply_proj()

# Filter for N1-P2
mne.io.Raw.filter(raw, l_freq=1., h_freq=15., picks='eeg', method='iir', 
                  iir_params=dict(order=4, ftype='butter', output='sos'))

# Epochs
tmin = -0.2
tmax = 0.798
epochs = mne.Epochs(raw, events, event_id=event_dict['audio'], picks=['eeg', 'eog', 'ecg'],
                    tmin=tmin, tmax=tmax, reject=None,
                    proj=True, baseline=(None, 0), preload=True)

# Rejection
reject = get_rejection_threshold(epochs, ch_types='eeg', decim=2)

# (optional)
picks = mne.pick_types(epochs.info, meg=False, eeg=True, stim=False, ecg=False,
                       eog=False, exclude='bads')
local_rejection_threshes = compute_thresholds(epochs, picks=picks, method='random_search',
                                              augment=True, verbose='progressbar')

# Apply rejection
epochs.drop_bad(reject=reject)

For the random seed, sure, I can fix it. However, I tried this about 20 times and never got a different output, so in this case it doesn't seem to make a difference. My current workaround is to change the local thresholds computation line to:

local_rejection_threshes = compute_thresholds(epochs.copy(), picks=picks, method='random_search',
                                              augment=True, verbose='progressbar')

I also noticed that this behavior seems to be caused by the augment argument: if it is set to False, I do get 69 epochs even after running compute_thresholds. This argument is not very clear to me, and the documentation does not mention that it modifies the epochs instance passed in.

Some insight into what augment does would be very helpful, as at the moment I am not sure whether I want it ON or OFF.

Hi Mathieu,

It turns out that autoreject was resetting the bad channels when compute_thresholds was called with augment=True.

The augment parameter performs a leave-one-out interpolation (across channels): each channel is interpolated from the other channels, generating twice the number of epochs that already existed. The reason is to be able to deal with channels which are globally bad or mostly bad. Since autoreject uses cross-validation (across epochs) to determine the threshold, it needs a good mix of good and bad data to work; the leave-one-out interpolation makes this happen for really bad channels.
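To make the mechanism concrete, here is a heavily simplified numpy sketch of what leave-one-out augmentation does to the epoch count. The real autoreject implementation interpolates with spherical splines via MNE; the mean across the other channels used here is only a stand-in for that:

```python
import numpy as np

def augment_leave_one_out(data):
    """Toy augment step: for each epoch, build a second epoch in which
    every channel is replaced by an estimate from the *other* channels,
    doubling the number of epochs. The estimate here is a plain mean;
    autoreject uses spherical-spline interpolation instead.
    data has shape (n_epochs, n_channels, n_times)."""
    n_epochs, n_channels, n_times = data.shape
    interpolated = np.empty_like(data)
    for ch in range(n_channels):
        others = [c for c in range(n_channels) if c != ch]
        interpolated[:, ch, :] = data[:, others, :].mean(axis=1)
    # original epochs first, interpolated copies appended after
    return np.concatenate([data, interpolated], axis=0)

rng = np.random.default_rng(0)
data = rng.standard_normal((10, 4, 100))   # 10 epochs, 4 channels, 100 samples
augmented = augment_leave_one_out(data)
print(augmented.shape)                     # (20, 4, 100)
```

The interpolated copies give the cross-validation something sensible to compare against even when a channel is bad in nearly every epoch.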

Now, what was happening is that epochs.info['bads'] was being reset during this procedure, so as to tell MNE which channel to interpolate. For now, I would suggest a simple fix:

bads_before = epochs.info['bads'].copy()
compute_thresholds(epochs, ...)
epochs.info['bads'] = bads_before

and this should make it work properly. I will work on patching this in autoreject as well. Hope that helps.

Mainak

@mainakjas Thank you for the reply. Very clear, and I tested myself that restoring the bad channels does fix this. Maybe the documentation of autoreject could be improved a bit, especially around this resetting of the bad channels in the epochs instance?
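In the meantime I wrapped the save-and-restore in a small context manager, so the bads list is put back even if the call inside raises (preserve_bads is my own helper, not autoreject or MNE API):

```python
from contextlib import contextmanager

@contextmanager
def preserve_bads(info):
    """Snapshot info['bads'] before the body runs and restore it
    afterwards, even on exception. info can be any mapping with a
    'bads' key, e.g. epochs.info."""
    bads_before = list(info['bads'])
    try:
        yield
    finally:
        info['bads'] = bads_before

# Stand-in for epochs.info, simulating autoreject clearing the list.
info = {'bads': ['O1', 'Oz']}
with preserve_bads(info):
    info['bads'] = []
print(info['bads'])   # restored: ['O1', 'Oz']
```

With real data this would be used as `with preserve_bads(epochs.info): compute_thresholds(epochs, picks=picks, ...)`.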

Indeed, fix is on the way: https://github.com/autoreject/autoreject/pull/203 !

Autoreject should not meddle with the bad channels list.

Mainak
