How to combine figures from plot_connectivity_circle into subplots

The objective is to arrange a list of figures produced by plot_connectivity_circle as
subplots of a single figure, as shown below.

Is there a built-in way to do this with MNE?

The three figures were produced with the code below and are collected in the list all_fig:

import numpy as np
import mne
from mne.connectivity import spectral_connectivity
from mne.viz import circular_layout, plot_connectivity_circle
import matplotlib.pyplot as plt



def generate_conn():
    """Simulate sinusoidal EEG epochs and compute band-averaged PLV connectivity.

    Builds 5 epochs of 12 channels (1000 samples at 250 Hz), each channel
    holding a 10 Hz sine wave with a random phase offset, wraps them in an
    ``mne.EpochsArray`` and runs ``spectral_connectivity`` (method ``"plv"``,
    multitaper mode) averaged over five frequency bands.

    Returns
    -------
    con
        Connectivity estimate as returned by ``spectral_connectivity``.
        Callers in this file index it as ``con[:, :, band_idx]``, i.e. it is
        expected to be (n_channels, n_channels, n_bands) -- confirm against
        the MNE documentation for the installed version.
    all_ch : list of str
        Channel names taken from ``epochs.ch_names``.
    """
    label_names = ['FP1', 'FP2', 'F3', 'F4', 'F7', 'F8', 'C3', 'C4',
                   'T3', 'T4', 'O1', 'O2']

    np.random.seed(42)
    n_epochs = 5
    n_channels = len(label_names)
    n_times = 1000
    # Every value drawn here is overwritten below; the draw is kept on
    # purpose so the seeded RNG stream (and therefore the phases) stays
    # identical to the original script.
    data = np.random.rand(n_epochs, n_channels, n_times)
    # Set sampling freq
    sfreq = 250  # A reasonable random choice

    # 10Hz sinus waves with random phase differences in each channel and epoch
    # Generate 10Hz sinus waves to show difference between connectivity
    # over time and over trials. Here we expect con over time = 1
    wave_freq = 10
    epoch_len = n_times / sfreq
    half_span = wave_freq * epoch_len * np.pi  # loop-invariant, hoisted
    for i in range(n_epochs):
        for c in range(n_channels):
            # Introduce random phase for each channel
            phase = np.random.rand(1) * 10
            # Generate sinus wave; linspace with 1-element endpoints yields
            # shape (n_times, 1), hence the squeeze.
            x = np.linspace(-half_span + phase, half_span + phase, n_times)
            data[i, c] = np.squeeze(np.sin(x))

    info = mne.create_info(ch_names=label_names,
                           ch_types=['eeg'] * len(label_names),
                           sfreq=sfreq)

    epochs = mne.EpochsArray(data, info)

    # Define freq bands as (low, high) edges in Hz
    freq_bands = {"delta": [1.25, 4.0],
                  "theta": [4.0, 8.0],
                  "alpha": [8.0, 13.0],
                  "beta": [13.0, 30.0],
                  "gamma": [30.0, 49.0]}

    # Convert to tuples for the mne function
    fmin = tuple(low for low, _ in freq_bands.values())
    fmax = tuple(high for _, high in freq_bands.values())

    # Connectivity methods
    connectivity_methods = ["plv"]

    # Calculate PLV - the MNE python implementation is over trials.
    # Only ``con`` is needed; the remaining return values get throwaway
    # names so the local ``n_epochs`` above is not silently rebound.
    con, _freqs, _times, _n_epochs_used, _n_tapers = spectral_connectivity(
        epochs, method=connectivity_methods,
        mode="multitaper", sfreq=sfreq, fmin=fmin, fmax=fmax,
        faverage=True, verbose=0)
    all_ch = epochs.ch_names

    return con, all_ch

def plot_conn(conmat, all_ch, title=None):
    """Draw one connectivity matrix as a circular graph on a new 8x8 figure.

    Parameters
    ----------
    conmat
        Connectivity matrix for a single band, passed straight to
        ``plot_connectivity_circle``.
    all_ch : list of str
        Node (channel) names, in the same order as the rows/columns of
        ``conmat``.
    title : str or None
        Plot title. ``None`` keeps the original hard-coded title so existing
        callers behave exactly as before; pass a band-specific title to avoid
        labelling every band as "Delta".

    Returns
    -------
    The return value of ``plot_connectivity_circle``, unchanged. NOTE(review):
    per the MNE documentation this is a ``(fig, axes)`` tuple rather than a
    bare Figure, so callers should unpack it before using it as a figure.
    """
    if title is None:
        title = 'All-to-All Connectivity Condition (PLI)_Delta'

    lh_labels = ['FP1', 'F7', 'F3', 'C3', 'T3', 'O1']
    rh_labels = ['FP2', 'F8', 'F4', 'C4', 'T4', 'O2']
    # node_order only controls where each channel sits on the circle.
    # Per the MNE docs circular_layout returns the angles ordered like its
    # first argument (all_ch), so they line up with conmat's rows/columns.
    node_order = lh_labels + rh_labels
    node_angles = circular_layout(all_ch, node_order, start_pos=90,
                                  group_boundaries=[0, len(all_ch) // 2])

    target_fig = plt.figure(num=None, figsize=(8, 8), facecolor='black')
    return plot_connectivity_circle(conmat, all_ch, n_lines=300,
                                    node_angles=node_angles,
                                    title=title, fig=target_fig)


# Compute the connectivity once, then draw the first three frequency bands
# (last axis of ``con``), collecting whatever plot_conn returns for each.
con, all_ch = generate_conn()
all_fig = [plot_conn(con[:, :, idx], all_ch) for idx in range(3)]

I would appreciate any hints.

One quick-and-dirty solution is to convert each Figure into a NumPy array and then stack the arrays either vertically or horizontally.

  1. Generate the NumPy array
  • Render the plot_connectivity_circle output onto the canvas using canvas.draw ()
  • Convert the rendered image into an array using np.frombuffer


    # Fragment: expects fig, conmat, all_ch, node_angles and bands to be
    # defined by the surrounding code (see the complete script).
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    # Attach an Agg (off-screen raster) canvas to the figure so it can be
    # rendered into a pixel buffer.
    canvas = FigureCanvas ( fig )
    plot_connectivity_circle ( conmat, all_ch, n_lines=300,
                                    node_angles=node_angles,
                                    title=f'All-to-All Connectivity_ band_{bands}', fig=fig )
    
    # Render the figure, then pull out the raw bytes and the pixel size.
    canvas.draw ()
    s, (width, height) = canvas.print_to_buffer ()
    # View the bytes as a (height, width, 4) uint8 image -- four channels per
    # pixel (RGBA per the matplotlib docs).
    im0 = np.frombuffer ( s, np.uint8 ).reshape ( (height, width, 4) )
  2. Create the combined image by stacking the arrays

np.hstack ( all_fig ) # all_fig is a list of image arrays (one per figure), joined side by side

The complete code is below:

import mne
from mne.connectivity import spectral_connectivity
from mne.viz import circular_layout, plot_connectivity_circle
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

def generate_conn ():
    """Create synthetic 10 Hz EEG epochs and return their PLV connectivity.

    Five epochs of twelve channels (1000 samples, 250 Hz sampling rate) are
    filled with 10 Hz sine waves carrying a random phase per channel and
    epoch. The epochs are passed to ``spectral_connectivity`` (method
    ``"plv"``, multitaper mode, ``faverage=True``) over five frequency bands.

    Returns
    -------
    con
        The connectivity estimate from ``spectral_connectivity``. The script
        below slices it as ``con[:, :, idx]`` with ``idx`` selecting a band;
        the exact shape should be confirmed against the MNE documentation.
    all_ch : list of str
        ``epochs.ch_names``.
    """
    label_names = ['FP1', 'FP2', 'F3', 'F4', 'F7', 'F8', 'C3', 'C4',
                   'T3', 'T4', 'O1', 'O2']

    np.random.seed(42)
    n_epochs = 5
    n_channels = len(label_names)
    n_times = 1000
    # The random values themselves are overwritten below, but the draw is
    # retained so the seeded RNG sequence used for the phases is unchanged.
    data = np.random.rand(n_epochs, n_channels, n_times)
    # Set sampling freq
    sfreq = 250  # A reasonable random choice

    # 10Hz sinus waves with random phase differences in each channel and epoch
    # Generate 10Hz sinus waves to show difference between connectivity
    # over time and over trials. Here we expect con over time = 1
    wave_freq = 10
    epoch_len = n_times / sfreq
    half_span = wave_freq * epoch_len * np.pi  # same for every channel/epoch
    for i in range(n_epochs):
        for c in range(n_channels):
            # Introduce random phase for each channel
            phase = np.random.rand(1) * 10
            # Generate sinus wave (squeeze drops the length-1 axis that the
            # 1-element phase array introduces in linspace's output)
            x = np.linspace(-half_span + phase, half_span + phase, n_times)
            data[i, c] = np.squeeze(np.sin(x))

    info = mne.create_info(ch_names=label_names,
                           ch_types=['eeg'] * len(label_names),
                           sfreq=sfreq)

    epochs = mne.EpochsArray(data, info)

    # Define freq bands as (low, high) edges in Hz
    freq_bands = {"delta": [1.25, 4.0],
                  "theta": [4.0, 8.0],
                  "alpha": [8.0, 13.0],
                  "beta": [13.0, 30.0],
                  "gamma": [30.0, 49.0]}

    # Convert to tuples for the mne function
    fmin = tuple(low for low, _ in freq_bands.values())
    fmax = tuple(high for _, high in freq_bands.values())

    # Connectivity methods
    connectivity_methods = ["plv"]

    # Calculate PLV - the MNE python implementation is over trials.
    # Extra return values are bound to throwaway names so that the local
    # ``n_epochs`` defined above is not overwritten.
    con, _freqs, _times, _n_epochs_used, _n_tapers = spectral_connectivity(
        epochs, method=connectivity_methods,
        mode="multitaper", sfreq=sfreq, fmin=fmin, fmax=fmax,
        faverage=True, verbose=0)
    all_ch = epochs.ch_names

    return con, all_ch


def plot_conn (conmat, all_ch, idx, bands):
    """Render one connectivity circle plot off-screen and return it as an image.

    Parameters
    ----------
    conmat
        Connectivity matrix for a single band, forwarded to
        ``plot_connectivity_circle``.
    all_ch : list of str
        Node (channel) names, in the same order as ``conmat``'s rows/columns.
    idx
        Unused; kept so existing callers that pass it keep working.
    bands : str
        Band label interpolated into the plot title.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape (height, width, 4) holding the rendered
        figure (four channels per pixel; RGBA per the matplotlib docs).
    """
    lh_labels = ['FP1', 'F7', 'F3', 'C3', 'T3', 'O1']
    rh_labels = ['FP2', 'F8', 'F4', 'C4', 'T4', 'O2']
    # node_order only fixes the position of each channel on the circle; per
    # the MNE docs circular_layout returns angles ordered like all_ch.
    node_order = lh_labels + rh_labels
    node_angles = circular_layout(all_ch, node_order, start_pos=90,
                                  group_boundaries=[0, len(all_ch) // 2])

    fig = plt.figure(num=None, figsize=(8, 8), facecolor='black')
    try:
        # Attach an Agg canvas so the figure can be rasterised to a buffer.
        canvas = FigureCanvas(fig)
        # show=False: only the pixel buffer is wanted, so do not open (or
        # block on) a GUI window for every band.
        plot_connectivity_circle(conmat, all_ch, n_lines=300,
                                 node_angles=node_angles,
                                 title=f'All-to-All Connectivity_ band_{bands}',
                                 fig=fig, show=False)

        canvas.draw()
        s, (width, height) = canvas.print_to_buffer()
        im0 = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    finally:
        # The figure was registered with pyplot by plt.figure(); close it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
    return im0


# Compute connectivity once, render the first three bands to image arrays,
# then join the images side by side and write the result to disk.
con, all_ch = generate_conn()

all_fig = [
    plot_conn(con[:, :, idx], all_ch, idx, band)
    for idx, band in enumerate(("delta", "theta", "alpha"))
]

# Horizontal stacking places the per-band images next to each other.
SUBPLOT = np.hstack(all_fig)

plt.imsave('myimage.png', SUBPLOT)