Hi @datalw,
Is it possible to get the result of `mne_connectivity.viz.plot_sensors_connectivity`
as a 2D plot, and to pass an `axes` argument so that the output can be used as a subplot?
I am not aware of a way to do this with a 3D figure. You could try to save the 3D figure as a 2D image, but without properly projecting it onto a 2D surface there will be many overlapping sensor locations, especially around the edges.
However, if you have a channel montage for your data, you could use the 2D figure returned by `mne.viz.plot_montage()`
and then plot the connectivity between sensors on top of it, e.g.:
I generated this with the code below. You can substitute in whatever all-to-all connectivity data you have, as long as its channel order matches the order of the sensors in the montage:
import os.path as op
import mne
import numpy as np
from mne.datasets import sample
from matplotlib import pyplot as plt
from mne_connectivity import spectral_connectivity_epochs
from matplotlib import colors
# Read data: the filtered sample recording and its matching event file,
# both stored under <data_path>/MEG/sample
data_path = sample.data_path()
raw_fname, event_fname = (
    op.join(data_path, "MEG", "sample", fname)
    for fname in (
        "sample_audvis_filt-0-40_raw.fif",
        "sample_audvis_filt-0-40_raw-eve.fif",
    )
)
raw = mne.io.read_raw_fif(raw_fname)
raw = raw.pick("eeg")  # keep only the EEG channels
events = mne.read_events(event_fname)

# Create epochs around event id 3, spanning -0.2 s to 1.5 s, with the baseline
# taken from the start of the epoch up to 0 s
event_id = 3
tmin, tmax = -0.2, 1.5
epochs = mne.Epochs(
    raw,
    events,
    event_id=event_id,
    tmin=tmin,
    tmax=tmax,
    baseline=(None, 0),
)

# Compute PLI for the frequency band containing the evoked response, averaged
# over that band (faverage=True) and restricted to samples from 0 s onwards.
# Note: ``tmin`` is re-bound here to the connectivity window start.
fmin, fmax = 4.0, 9.0
tmin = 0.0
con = spectral_connectivity_epochs(
    data=epochs,
    method="pli",
    fmin=fmin,
    fmax=fmax,
    faverage=True,
    tmin=tmin,
    n_jobs=1,
)
########################################################################################
# Define connections to plot
min_distance = 0.05  # skip pairs closer than this on the 2D layout (plot coordinate units)
n_con = 10  # how many of the strongest connections to keep (ties at the cut-off are all kept)
# Draw the 2D sensor layout, then read back the plotted sensor positions
fig, axs = plt.subplots(1, 1, figsize=(10, 10))
fig = mne.viz.plot_montage(raw.get_montage(), axes=axs, show_names=False, show=False)
# np.asarray handles both a masked array and a plain ndarray from get_offsets();
# ``.data`` on a plain ndarray would return a raw memoryview buffer instead.
# NOTE(review): assumes the plotted sensor order matches the channel order of ``con``
# (both derive from the same EEG picks here) -- confirm when substituting other data.
sens_loc = np.asarray(axs.collections[0].get_offsets())
# Get the strongest n_con connections
con_vals = con.get_data("dense")[:, :, 0]  # single frequency bin because faverage=True
# Clamp so requesting more connections than exist cannot index out of range
n_top = min(n_con, con_vals.size)
threshold = np.sort(con_vals, axis=None)[-n_top]
ii, jj = np.where(con_vals >= threshold)
# Remove connections between sensors that sit close together on the 2D layout
con_nodes = []
con_val = []
for i, j in zip(ii, jj):
    if np.linalg.norm(sens_loc[i] - sens_loc[j]) > min_distance:
        con_nodes.append((i, j))
        con_val.append(con_vals[i, j])
con_val = np.array(con_val)
if con_val.size == 0:
    # Fail here with an actionable message instead of an opaque
    # "zero-size array" error when the colour range is computed later.
    raise ValueError(
        f"No connections left after removing sensor pairs closer than {min_distance}; "
        "decrease min_distance or increase n_con."
    )
# Plot the connections as lines between sensor positions, coloured by their PLI value
vmin = con_val.min()
vmax = con_val.max()
cmap = plt.get_cmap("viridis")
norm = colors.Normalize(vmin=vmin, vmax=vmax)
colours = cmap(norm(con_val))  # one RGBA colour per kept connection
for (i, j), colour in zip(con_nodes, colours):
    axs.plot(
        (sens_loc[i, 0], sens_loc[j, 0]),
        (sens_loc[i, 1], sens_loc[j, 1]),
        color=colour,
    )
fig.colorbar(
    plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=axs, shrink=0.5, label="PLI (A.U.)"
)
# The montage was drawn with show=False so the lines could be added first;
# display the finished figure now (otherwise nothing appears when run as a script).
plt.show()