This is an extension of my earlier post, "Return Spectral Connectivity of Each Epoch". Since I'm unable to add an answer to that post, I am creating a new one instead.
The spectral connectivity of each epoch is computed by the function connectivity_per_epochs().
For comparison, the typical MNE-Python implementation, which computes connectivity across trials, is also included and can be run via connectivity_average_epochs().
import mne
from mne.connectivity import spectral_connectivity
from mne.viz import circular_layout, plot_connectivity_circle
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
def data_setting():
    """Return the configuration dict shared by every step of the script.

    Keys:
      * ``label_names`` -- the 12 EEG channel names, in data order.
      * ``Freq_Bands`` -- band name -> ``[fmin, fmax]``; these edges are later
        passed to ``spectral_connectivity`` as ``fmin``/``fmax``.
      * ``lh_labels`` / ``rh_labels`` -- left/right channel groups, used to
        order the nodes of the circular connectivity plot.
      * ``connectivity_methods`` -- method list for ``spectral_connectivity``.

    A fresh dict (with fresh lists) is built on every call.
    """
    left_hemisphere = ['FP1', 'F7', 'F3', 'C3', 'T3', 'O1']
    right_hemisphere = ['FP2', 'F8', 'F4', 'C4', 'T4', 'O2']
    bands = dict(
        delta=[1.25, 4.0],
        theta=[4.0, 8.0],
        alpha=[8.0, 13.0],
        beta=[13.0, 30.0],
        gamma=[30.0, 49.0],
    )
    return dict(
        label_names=['FP1', 'FP2', 'F3', 'F4', 'F7', 'F8', 'C3', 'C4',
                     'T3', 'T4', 'O1', 'O2'],
        Freq_Bands=bands,
        lh_labels=left_hemisphere,
        rh_labels=right_hemisphere,
        connectivity_methods=["coh"],
    )
def create_epochs(set_data):
    """Build a synthetic ``mne.EpochsArray`` of ~10 Hz sine waves.

    Every (epoch, channel) pair gets its own random phase offset. Within one
    epoch the phase lag between two channels is therefore constant, so
    connectivity computed over time is expected to be ~1. Across epochs the
    lag changes, which is what distinguishes it from connectivity computed
    over trials.

    Parameters
    ----------
    set_data : dict
        Configuration from ``data_setting``; only ``'label_names'`` is read.

    Returns
    -------
    mne.EpochsArray
        5 epochs x ``len(label_names)`` EEG channels x 1000 samples, 256 Hz.

    Side effect: reseeds NumPy's global RNG with 42 (reproducible output, but
    the global random state is changed for the rest of the process).
    """
    np.random.seed(42)
    ch_names = set_data['label_names']
    n_epochs, n_channels, n_times = 5, len(ch_names), 1000
    sfreq = 256  # A reasonable random choice

    # Every sample of this array is overwritten below. The draw is kept
    # because it advances the seeded RNG, which fixes the phases drawn next.
    data = np.random.rand(n_epochs, n_channels, n_times)

    wave_freq = 10
    epoch_len = n_times / sfreq
    half_span = wave_freq * epoch_len * np.pi
    for epoch_idx in range(n_epochs):
        for ch_idx in range(n_channels):
            phase = np.random.rand(1) * 10  # random phase for this channel
            x = np.linspace(-half_span + phase, half_span + phase, n_times)
            data[epoch_idx, ch_idx] = np.squeeze(np.sin(x))

    info = mne.create_info(ch_names=ch_names,
                           ch_types=['eeg'] * n_channels,
                           sfreq=sfreq)
    return mne.EpochsArray(data, info)
def define_freq_bands(set_data):
    """Split the configured frequency bands into lower and upper edges.

    Parameters
    ----------
    set_data : dict
        Configuration whose ``'Freq_Bands'`` entry maps band name ->
        ``[fmin, fmax]``.

    Returns
    -------
    tuple of (tuple, tuple)
        ``(fmins, fmaxs)`` in the dict's iteration order. This is the form
        passed to ``spectral_connectivity(fmin=..., fmax=...)``. Both tuples
        are empty when no bands are configured.
    """
    # A dict values view can be iterated more than once, so there is no need
    # to rebuild a list per element as the previous version did.
    bands = set_data['Freq_Bands'].values()
    fmins = tuple(band[0] for band in bands)
    fmaxs = tuple(band[1] for band in bands)
    return fmins, fmaxs
def _connectivity_average_epochs(epochs=None, set_data=None):
    """Compute band-averaged connectivity across all epochs (trials).

    This is the standard MNE usage: all epochs go into one
    ``spectral_connectivity`` call, giving a single estimate per band.

    Parameters
    ----------
    epochs : mne.Epochs
        Data to analyse; ``epochs.info['sfreq']`` supplies the sampling rate.
        Required despite the ``None`` default.
    set_data : dict
        Configuration from ``data_setting``; ``'Freq_Bands'`` and
        ``'connectivity_methods'`` are read. Required despite the default.

    Returns
    -------
    con : ndarray
        Connectivity whose last axis indexes the configured bands, in
        ``set_data['Freq_Bands']`` order. NOTE(review): with a single method
        (as ``data_setting`` configures) MNE returns one array of shape
        (n_channels, n_channels, n_bands); with several methods it returns a
        list of such arrays. Confirm against the installed MNE version.
    ch_names : list of str
        ``epochs.ch_names``.
    """
    fmin, fmax = define_freq_bands(set_data)
    # The method is coherence ("coh" in data_setting), not PLV.
    # faverage=True averages inside each (fmin[k], fmax[k]) band, so the last
    # axis of ``con`` indexes bands. With faverage=False it indexes every
    # frequency bin inside the bands, and con[:, :, k] would NOT be band k.
    con, _freqs, _times, _n_epochs, _n_tapers = spectral_connectivity(
        epochs, method=set_data['connectivity_methods'],
        mode="multitaper", sfreq=epochs.info['sfreq'], fmin=fmin, fmax=fmax,
        faverage=True, verbose=0)
    return con, epochs.ch_names
def _connectivity_per_epochs(epochs, set_data):
    """Compute band-averaged connectivity separately for every epoch.

    Each epoch goes to ``spectral_connectivity`` on its own, so one epoch's
    estimate does not pool data from the others (connectivity over time
    rather than over trials).

    Parameters
    ----------
    epochs : mne.Epochs
        Data to analyse; ``epochs.info['sfreq']`` supplies the sampling rate.
    set_data : dict
        Configuration from ``data_setting``; ``'Freq_Bands'`` and
        ``'connectivity_methods'`` are read.

    Returns
    -------
    all_con : list
        One connectivity result per epoch, in epoch order. The last axis of
        each result indexes the configured bands in ``set_data['Freq_Bands']``
        order.
    ch_names : list of str
        ``epochs.ch_names``.
    """
    fmin, fmax = define_freq_bands(set_data)
    sfreq = epochs.info['sfreq']
    all_con = []
    for epoch_data in epochs:
        # Iterating an mne.Epochs yields one epoch's data array at a time.
        # Wrapping it in a list presents it as a one-epoch dataset.
        # NOTE(review): with one epoch the spectra are averaged over
        # multitaper tapers only, so estimates are noisier than across trials.
        # faverage=True makes the last axis index bands rather than bins.
        con_each, _freqs, _times, _n_epochs, _n_tapers = spectral_connectivity(
            [epoch_data], method=set_data['connectivity_methods'],
            mode="multitaper", sfreq=sfreq, fmin=fmin, fmax=fmax,
            faverage=True, verbose=0)
        all_con.append(con_each)
    return all_con, epochs.ch_names
def plot_conn(conmat, all_ch, idx, bands, set_data):
    """Render one connectivity matrix as a circular graph; return its pixels.

    Parameters
    ----------
    conmat : ndarray
        2-D channel x channel connectivity matrix for a single band.
    all_ch : list of str
        Channel names matching the rows/columns of ``conmat``.
    idx : int
        Band index. Unused here; kept so existing callers keep working.
    bands : str
        Band name, shown in the plot title.
    set_data : dict
        Configuration; ``'lh_labels'`` followed by ``'rh_labels'`` defines the
        node order around the circle.

    Returns
    -------
    ndarray
        uint8 RGBA image of shape (height, width, 4).
    """
    # Left-hemisphere channels first, then right, split into two groups.
    node_order = set_data['lh_labels'] + set_data['rh_labels']
    node_angles = circular_layout(all_ch, node_order, start_pos=90,
                                  group_boundaries=[0, len(all_ch) // 2])
    fig = plt.figure(num=None, figsize=(8, 8), facecolor='black')
    try:
        # Attach an Agg canvas so the figure can be rasterised off-screen.
        canvas = FigureCanvas(fig)
        # show=False: only the pixels are needed. Don't open or block on a window.
        plot_connectivity_circle(conmat, all_ch, n_lines=300,
                                 node_angles=node_angles,
                                 title=f'All-to-All Connectivity_ band_{bands}',
                                 fig=fig, show=False)
        canvas.draw()
        buf, (width, height) = canvas.print_to_buffer()
        return np.frombuffer(buf, np.uint8).reshape((height, width, 4))
    finally:
        # Callers create one figure per band (and per epoch). Without closing,
        # pyplot keeps every one of them alive for the rest of the process.
        plt.close(fig)
# Module-level configuration shared by both entry points.
set_data = data_setting()


def connectivity_average_epochs():
    """Plot across-epoch (trial-pooled) connectivity and save it as one PNG.

    For each band name in ``set_data['Freq_Bands']`` (in order),
    ``con[:, :, idx]`` is drawn as a circular graph. The images are placed
    side by side in 'connectivity_average_epochs.png'.
    """
    con, all_ch = _connectivity_average_epochs(epochs=create_epochs(set_data),
                                               set_data=set_data)
    # The connectivity is computed for the bands in set_data['Freq_Bands']
    # (via define_freq_bands). Take the labels from the same dict instead of
    # a hard-coded copy that could drift out of sync.
    all_fig = [plot_conn(con[:, :, idx], all_ch, idx, band, set_data)
               for idx, band in enumerate(set_data['Freq_Bands'])]
    plt.imsave('connectivity_average_epochs.png', np.hstack(all_fig))
def connectivity_per_epochs():
    """Plot within-epoch connectivity, saving one PNG per epoch.

    For epoch ``i``, ``con[:, :, idx]`` is drawn for each band name in
    ``set_data['Freq_Bands']``. The images are placed side by side in
    'epoch_{i}_per_epoch_connectivity.png'.
    """
    con_per_epochs, all_ch = _connectivity_per_epochs(
        epochs=create_epochs(set_data), set_data=set_data)
    # Same band order that define_freq_bands used for the computation;
    # built once rather than per epoch.
    band_names = list(set_data['Freq_Bands'])
    for idx_epoch, con in enumerate(con_per_epochs):
        all_fig = [plot_conn(con[:, :, idx], all_ch, idx, band, set_data)
                   for idx, band in enumerate(band_names)]
        plt.imsave(f'epoch_{idx_epoch}_per_epoch_connectivity.png',
                   np.hstack(all_fig))
# Run both analyses only when executed as a script. Importing this file
# (e.g. to reuse the helpers) then does not render plots or write PNG files.
if __name__ == "__main__":
    connectivity_average_epochs()
    connectivity_per_epochs()
This produces the following output:
connectivity_average_epochs
The connectivity plots for epochs 0, 1, 2, 3 and 4 are shown below:
epoch 0
epoch 1
epoch 2
epoch 3
epoch 4
If these steps and outputs are correct, could one of the experts please confirm so that this thread can be closed?
Related discussion: