Reject trials with an amplitude greater than 3 times the standard deviation

Hi all,

I want to reject trials in which any amplitude exceeds the mean + 3*std computed across all trials and channels. What I’m trying so far is below, where raw is a single subject’s raw .fif file that I’ve already loaded:

# Get mean and std across all channels
data = raw.get_data(picks=esg_chans)  # (n_channels, n_times) ndarray
mean = np.mean(data, axis=(0, 1))
std = np.std(data, axis=(0, 1))

threshold = abs((mean + 3*std)*1e6)

df = epochs.to_data_frame()  # Automatically scales values by 1e6 - to uV
df = df.drop(columns=["ECG", "TH6"])
drop_jump = []
for i in set(df.epoch):
    sub_df = df[df.epoch == i]
    amps = sub_df.iloc[:, 3:]  # Extract the channel amplitudes for this epoch
    check = amps.ge(threshold).values.any()
    drop_jump.append(check)
epochs.drop(drop_jump, reason='3*std')  # True values removed, False values retained

But it’s dropping far more trials than I would expect (only 11 of 2000 trials are retained), so I’m likely doing something wrong, but I’m not sure what!

Any tips welcome :slight_smile:

  • MNE version: e.g. 1.2.0
  • operating system: Debian GNU/Linux 11
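As a point of comparison, the per-epoch check in the question can be done directly on the epoch array instead of going through a DataFrame. The sketch below assumes the data has shape (n_epochs, n_channels, n_times), as returned by epochs.get_data() (in volts, so mean and std must also be in volts, without the 1e6 scaling), and that mean and std are scalars computed beforehand; the helper name bad_epoch_mask is hypothetical. It flags both tails (values above mean + 3*std or below mean - 3*std), and because the mask is built in epoch order it lines up with the epochs when passed to epochs.drop.

```python
import numpy as np


def bad_epoch_mask(epoch_data, mean, std, n_std=3.0):
    """Return a boolean mask with one entry per epoch (True = reject).

    epoch_data: array of shape (n_epochs, n_channels, n_times).
    mean, std: scalars computed beforehand (hypothetical inputs).
    An epoch is flagged if any sample on any channel lies outside
    mean +/- n_std * std.
    """
    deviation = np.abs(epoch_data - mean)
    # Reduce over channels and time points, keeping one value per epoch
    return (deviation > n_std * std).any(axis=(1, 2))


# Toy example: 3 epochs, 2 channels, 5 samples
epoch_data = np.zeros((3, 2, 5))
epoch_data[1, 0, 2] = 10.0   # large positive excursion in epoch 1
epoch_data[2, 1, 4] = -5.0   # large negative excursion in epoch 2
mask = bad_epoch_mask(epoch_data, mean=0.0, std=1.0)
print(mask)
# With MNE (assumed usage): epochs.drop(mask, reason="3*std")
```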

Hello,

I’m not following your code completely, but if you want to create epochs from a continuous (raw) recording and automatically drop epochs where the peak-to-peak amplitude exceeds a threshold, you have 2 options:

  • Provide the argument reject when creating the epochs
  • Annotate the continuous recording to mark segments with large amplitude as bad

Below is an example of both:

import numpy as np
from mne import Epochs, create_info, find_events
from mne.io import RawArray
from mne.preprocessing import annotate_amplitude


info = create_info(["EEG 01", "EEG 02", "EEG 03"], sfreq=512, ch_types="eeg")
data = np.random.randn(3, 30720)  # 60 seconds of continuous data
data *= 1e-5  # scaling to get at least the correct order of magnitude
raw = RawArray(data, info)

# add an event channel with fake events every second
info = create_info(["STI"], sfreq=raw.info["sfreq"], ch_types="stim")
data_stim = np.zeros(shape=(1, len(raw.times)))
data_stim[0, 0::512] = 1
stim = RawArray(data_stim, info)
raw.add_channels([stim], force_update_info=True)

#%% Option 1: PTP rejection threshold
# create epochs with a peak-to-peak amplitude rejection criterion
threshold = 60 * 1e-6  # 60 uV peak-to-peak amplitude within an epoch
events = find_events(raw, stim_channel="STI")
epochs = Epochs(
    raw, 
    events, 
    event_id=dict(test=1), 
    tmin=0, 
    tmax=0.5, 
    picks="eeg", 
    reject=dict(eeg=threshold),
    reject_by_annotation=False,  # there are no annotations anyway
    baseline=None,
    preload=True,
)

#%% Option 2: Annotation of bad segments based on PTP amplitude
# create annotations that will mark the segments of continuous data exceeding
# a threshold
threshold = 30 * 1e-6  # 30 uV peak-to-peak between consecutive samples
annotations, bads = annotate_amplitude(
    raw, 
    peak=dict(eeg=threshold),
    picks="eeg",
)
raw.set_annotations(annotations)
# N.B: If you already had annotations in your recording and want to keep them
# alongside the added BAD_peak annotations, you need to pass both:
# raw.set_annotations(raw.annotations + annotations)

# then create epochs and reject by annotations
epochs = Epochs(
    raw, 
    events, 
    event_id=dict(test=1), 
    tmin=0, 
    tmax=0.5, 
    picks="eeg", 
    reject=None,
    reject_by_annotation=True,
    baseline=None,
    preload=True,
)

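Note that the two options do not do exactly the same thing: option 1 drops an epoch when the peak-to-peak amplitude within that epoch exceeds the threshold, while option 2 marks segments of the continuous data as bad, and every epoch overlapping one of those BAD segments is then dropped. After adding the annotations, you can use raw.plot() to display the data together with the annotations; it should make clear which parts of the signal have been rejected.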
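To make the difference between the two peak-to-peak measures concrete, here is a small NumPy illustration. It assumes, as the annotate_amplitude documentation describes, that annotate_amplitude looks at differences between consecutive samples, whereas the reject argument of Epochs uses the peak-to-peak range over the whole epoch window. The function names are hypothetical and only mimic those two measures; they are not MNE functions.

```python
import numpy as np


def epoch_ptp(epoch):
    """Peak-to-peak per channel over the whole epoch window (like `reject`)."""
    return epoch.max(axis=-1) - epoch.min(axis=-1)


def max_consecutive_diff(epoch):
    """Largest absolute difference between neighbouring samples per channel."""
    return np.abs(np.diff(epoch, axis=-1)).max(axis=-1)


# One channel, 257 samples: a slow linear drift from 0 to 100 uV (in volts)
epoch = np.linspace(0.0, 100e-6, 257)[np.newaxis, :]
print(epoch_ptp(epoch))             # about 1e-4 V: above a 60 uV PTP limit
print(max_consecutive_diff(epoch))  # about 3.9e-7 V: far below 30 uV
```

A slow drift like this would be rejected by option 1 but left unannotated by option 2.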

Mathieu

Thanks for this! I think option 2 is essentially what I want, but rather than a fixed threshold, I want threshold = max(mean + 3*std, mean - 3*std) computed across all trials and all channels :slight_smile:
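Written out, one reading of this criterion (assuming it is meant to be two-sided) is below, with μ and σ the mean and standard deviation over all channels and time samples. Since σ ≥ 0, max(μ + 3σ, μ − 3σ) is simply μ + 3σ, so a check in both directions is better expressed through |x − μ|. The second line is only a rough estimate that assumes N independent Gaussian samples per epoch (all channels and time points together), which real recordings are not; it illustrates why an "any sample exceeds" rule can reject most epochs. For N = 1000, (1 − 0.0027)^1000 ≈ 0.07.

```latex
\begin{aligned}
\text{bad}(x) &\iff x > \mu + 3\sigma \ \lor\ x < \mu - 3\sigma \iff |x - \mu| > 3\sigma \\
P(\text{epoch kept}) &\approx (1 - p)^{N}, \qquad p = P(|x - \mu| > 3\sigma) \approx 0.0027
\end{aligned}
```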

OK, having dug into the docs a bit more, the problem is that MNE works with peak-to-peak (PTP) values, whereas I want to annotate extreme values based solely on absolute amplitude, not PTP.

You could:

  • retrieve the data array with data = raw.get_data();
  • look for the indices where the values exceed your criterion with np.where;
  • group those indices by consecutiveness (python - Detecting consecutive integers in a list - Stack Overflow) to determine the onset and duration of each segment exceeding your threshold;
  • convert those indices to times (either via the sampling frequency or by indexing into raw.times);
  • and finally create the annotations:

from mne import Annotations

# onset and duration: arrays (in seconds) computed in the previous steps
annotations = Annotations(onset, duration, "BAD")
raw.set_annotations(raw.annotations + annotations)
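The steps above can be sketched as a minimal NumPy-only version; the helper name find_bad_segments is hypothetical. It flags a time sample as bad when any channel's absolute amplitude exceeds threshold (in the same units as the data, i.e. volts for raw.get_data()), groups consecutive bad samples with np.diff and np.split rather than the Stack Overflow recipe, and returns onsets and durations in seconds computed from sample indices, i.e. relative to the start of the array. For a mean ± 3*std rule, one option would be to pass data - mean and threshold = 3 * std.

```python
import numpy as np


def find_bad_segments(data, sfreq, threshold):
    """Return (onset, duration) in seconds of spans where any channel's
    absolute amplitude exceeds `threshold`.

    data: array of shape (n_channels, n_times), e.g. from raw.get_data().
    sfreq: sampling frequency in Hz.
    """
    # True at every time sample where at least one channel is above threshold
    bad = (np.abs(data) > threshold).any(axis=0)
    idx = np.flatnonzero(bad)
    if idx.size == 0:
        return np.array([]), np.array([])
    # Split the bad indices wherever neighbouring indices are not adjacent
    breaks = np.flatnonzero(np.diff(idx) > 1) + 1
    runs = np.split(idx, breaks)
    starts = np.array([run[0] for run in runs])
    stops = np.array([run[-1] + 1 for run in runs])  # exclusive end sample
    onset = starts / sfreq
    duration = (stops - starts) / sfreq
    return onset, duration


# Toy example: 2 channels, 10 samples at 100 Hz
data = np.zeros((2, 10))
data[0, 2:4] = 5.0   # samples 2-3 exceed the threshold on channel 0
data[1, 7] = -6.0    # sample 7 exceeds it (negative) on channel 1
onset, duration = find_bad_segments(data, sfreq=100.0, threshold=1.0)
print(onset, duration)
# With MNE (assumed usage, mirroring the snippet above):
# raw.set_annotations(raw.annotations + Annotations(onset, duration, "BAD"))
```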

Yes, I think this is the way forward: adding annotations is a much better approach than what I was trying to do by converting to epochs first :slight_smile: Thank you for your help!

Cross-reference: annotate_amplitude for a specific time window (#13 by FranziMNE), where the question of rejecting on PTP amplitude vs. absolute amplitude was already discussed. The suggestion there was to allow reject to take a callable function; happy to review a PR if anyone gets excited about implementing it.

Here is what I ended up with, using the solution suggested by @mscheltienne; it seems to work quite nicely:

import numpy as np
from mne import Annotations

# absMinMaxAllChan, amplitudeThreshold, sampling_rate and raw are defined earlier in the script
badPoints = absMinMaxAllChan > amplitudeThreshold
# Boolean mask: True where the threshold is exceeded, False elsewhere

# Get the sample indices where the threshold is exceeded (badPoints is 1D)
sample_indices = np.flatnonzero(badPoints)
if sample_indices.size != 0:
    # Convert sample indices to times in seconds, then shift back by 2.5 ms
    bad_amp_events = sample_indices / sampling_rate - 0.0025
    # Each annotation spans 2.5 ms before to 2.5 ms after the detected bad sample
    annotations = Annotations(bad_amp_events, duration=0.005, description="BAD_amp")
    raw.set_annotations(raw.annotations + annotations)