Looping through a DataLoader of EDF files

- MNE version: 1.6.0
- operating system: Ubuntu 22.04.3 LTS

When looping through the train dataloader it takes 24 minutes to iterate over the entire dataset, which is too expensive if I want to increase the number of training epochs. Here is my custom Dataset code.

class TUH_Dataset(Dataset):
    '''def __init__(self, dataset, t_max, mode, preprocess):
        # maybe pass the arguments as a dictionary and then fill in the values
        self.dataset = dataset
        self.t_max = t_max
        self.mode = mode
        self.preprocess = preprocess'''
    def __init__(self, params):
        self.dataset = params.get("dataset_path")
        self.t_max = params.get("t_max")
        self.mode = params.get("mode")
        self.preprocess = params.get("preprocess")
        self.high_pass_freq = params.get("high_pass_freq")
        self.notch_freq = params.get("notch_freq")
        self.resample_freq = params.get("resample_freq")
    
        if self.mode == 'train':
            self.train_normal = os.path.join(self.dataset, 'train/normal/01_tcp_ar/')
            self.train_files_normal = sorted(glob.glob(os.path.join(self.train_normal, '*.edf')))
            self.train_abnormal = os.path.join(self.dataset, 'train/abnormal/01_tcp_ar/')
            self.train_files_abnormal = sorted(glob.glob(os.path.join(self.train_abnormal, '*.edf')))
            self.all = self.train_files_normal + self.train_files_abnormal
            # 0 for normal, 1 for abnormal (same order as self.all)
            self.labels = [0] * len(self.train_files_normal) + [1] * len(self.train_files_abnormal)
        
        elif self.mode == 'test':
            self.test_normal = os.path.join(self.dataset, 'eval/normal/01_tcp_ar/')
            self.test_files_normal = sorted(glob.glob(os.path.join(self.test_normal, '*.edf')))
            self.test_abnormal = os.path.join(self.dataset, 'eval/abnormal/01_tcp_ar/')
            self.test_files_abnormal = sorted(glob.glob(os.path.join(self.test_abnormal, '*.edf')))
            self.all = self.test_files_normal + self.test_files_abnormal
            # 0 for normal, 1 for abnormal (same order as self.all)
            self.labels = [0] * len(self.test_files_normal) + [1] * len(self.test_files_abnormal)

    def __getitem__(self, idx):
        # read one EDF file and crop it to the first t_max seconds
        signal = mne.io.read_raw_edf(input_fname=self.all[idx], preload=True, verbose=False)
        signal = signal.crop(tmin=0, tmax=self.t_max, include_tmax=False)

        if self.preprocess == 'yes':
            # filter, re-reference to the TCP (double banana) montage and resample
            obj = Preprocess(self.high_pass_freq, self.notch_freq, self.resample_freq)
            filtered_data = obj.filter(signal)
            tcp_montage_filtered = obj.montage_technique(filtered_data)
            signal = obj.resample(tcp_montage_filtered)

        # (n, m) --> n EEG channels, m samples
        signal_values = signal.get_data()

        return signal_values, float(self.labels[idx]), self.all[idx]

    def __len__(self):
        return len(self.all)

In the preprocessing step I filter the data, re-reference it to a double banana (TCP) montage, and resample every file to 250 Hz (the EDF files in my dataset do not all have the same sampling frequency). Here is the code for the preprocessing module:

class Preprocess:
    
    def __init__(self, high_pass_freq, notch_freq, resample_freq):
        self.high_pass_freq = high_pass_freq
        self.notch_freq = notch_freq
        self.resample_freq = resample_freq
    
    def filter(self, data):
        # fourth-order Butterworth high-pass filter
        raw_highpass = data.copy().filter(l_freq=self.high_pass_freq, h_freq=None, method='iir', verbose=False)
        raw_notch = raw_highpass.copy().notch_filter(freqs=self.notch_freq, verbose=False)

        return raw_notch
    
    def montage_technique(self, data):
        # TODO: this is hardcoded; think about how to make it configurable
        electrode_names_standard = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Cz']
        electrode_names = ['EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG C3-REF', 'EEG C4-REF', 'EEG P3-REF', 
                   'EEG P4-REF', 'EEG O1-REF', 'EEG O2-REF', 'EEG F7-REF', 'EEG F8-REF', 'EEG T3-REF', 'EEG T4-REF', 
                   'EEG T5-REF', 'EEG T6-REF', 'EEG CZ-REF']
        name_mapping = dict(zip(electrode_names, electrode_names_standard))
        # rename the channel to have shorter format
        data.rename_channels(name_mapping)
        # TCP montage, 20 channels, the pair for differencing channels
        bipolar_pairs = [('Fp1', 'F7'), ('F7', 'T3'), ('T3', 'T5'), ('T5', 'O1'), ('T3', 'C3'), ('C3', 'Cz'), ('Fp1', 'F3'),
                         ('F3', 'C3'), ('C3', 'P3'), ('P3', 'O1'), ('Fp2', 'F8'), ('F8', 'F4'), ('T4', 'T6'), ('T6', 'O2'),
                         ('T4', 'C4'), ('Cz', 'C4'), ('Fp2', 'F4'), ('F4', 'C4'), ('C4', 'P4'), ('P4', 'O2')]
        
        new_ch_data = []
        new_ch_names = []

        for pair in bipolar_pairs:
            ch1, ch2 = pair
            ch1_idx, ch2_idx = data.ch_names.index(ch1), data.ch_names.index(ch2)
            # data contains two arrays, first is the signal values, second the time in seconds
            # alternatively: use data.times and data.pick_channels(channel); change this later
            data_signal_1, time = data[ch1_idx, 0:len(data)]
            data_signal_2, time = data[ch2_idx, 0:len(data)]
            # apply differencing and ravel it to get the correct shape for numpy array
            diff_data = (data_signal_1 - data_signal_2).ravel()
            new_ch_data.append(diff_data)
            new_ch_names.append(f'{ch1}-{ch2}')
        
        new_ch_data = np.array(object=new_ch_data)
        new_info = mne.create_info(new_ch_names, sfreq=data.info['sfreq'], ch_types='eeg', verbose=False)
        new_raw = mne.io.RawArray(new_ch_data, new_info, verbose=False)

        return new_raw

    def resample(self, montage):
        # all recordings need to end up at the same sampling frequency (250 Hz in my case)
        raw_resampled = montage.copy().resample(sfreq=self.resample_freq, verbose=False)

        return raw_resampled

Looping through the dataloader:

from torch.utils.data import DataLoader
from tqdm import tqdm
from models import *

dataset_path = r'/home/iva/Desktop/maborg/Zacasno/davidsusic/EEG/v3.0.0/edf'
dict_params = {"dataset_path": dataset_path, "t_max": 60, "mode": "train", "preprocess": "yes"}
if dict_params.get("preprocess") == 'yes':
    dict_params['high_pass_freq'] = int(input("Choose the high pass frequency: "))
    dict_params['notch_freq'] = int(input("Choose the notch frequency: "))
    dict_params['resample_freq'] = int(input("Choose the resampling frequency: "))

train_set = TUH_Dataset(dict_params)
train_dataloader = DataLoader(train_set, shuffle=True, batch_size=64)

for i, data in enumerate(tqdm(train_dataloader)):
    inputs, labels, edf_names = data

According to cProfile, the most expensive parts are reading the EDF file with MNE and filtering it (a raw signal is roughly (30, 317000), i.e. 30 channels by 317,000 samples). Since I preload the data into memory in order to filter and resample it, I cannot use num_workers in the DataLoader: all of the memory is used up while preloading the files. After cropping, every recording in the train dataloader has shape (20, 15000). Any ideas how to optimize this?
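
For reference, this is roughly how I profile a single item; the exact script differs a bit, but the idea is just to time one __getitem__ call:

import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()
_ = train_set[0]  # one __getitem__ call: read_raw_edf + filter + montage + resample
profiler.disable()
pstats.Stats(profiler).sort_stats('cumulative').print_stats(10)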

Hello, try setting preload=False, then crop, then call load_data() before filtering.
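
Something along these lines inside your __getitem__ (untested sketch; with preload=False the crop only adjusts the time span, and load_data() then reads just that segment from disk):

signal = mne.io.read_raw_edf(input_fname=self.all[idx], preload=False, verbose=False)
signal = signal.crop(tmin=0, tmax=self.t_max, include_tmax=False)
signal.load_data()  # only the cropped t_max seconds are read into memory
# ... filtering / montage / resampling as before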

Best wishes,
Richard

I got this warning:

RuntimeWarning: Loading an EDF with mixed sampling frequencies and preload=False will result in edge artifacts. It is recommended to use preload=True. See also When loading EDF+ with different sampling frequencies, don't resample blockwise · Issue #10635 · mne-tools/mne-python · GitHub.