Channel colors in PSD plot()

  • MNE version: 1.3.1
  • mne-qt-browser: 0.4.0
  • operating system: Windows 10

Hi,
I would like to know whether it is possible to change the channel colors in a PSD plot (raw.compute_psd().plot()). I searched for a method but couldn’t find an easy way to do it.
The spatial_colors argument is nice, but for some channels it assigns colors that are hard to read (like light cyan).

Thanks for your help!

From the docstring of the plot method:

color : str | tuple

    A matplotlib-compatible color to use. Has no effect when spatial_colors=True.

So if you’re OK with all channels being the same color, you can pass any color you want. If you want control over the color of each channel separately, I think currently that is not possible.

Thanks for your reply.
I didn’t specify it, but yes, the idea is to control the color of each channel separately.
I will try to find a solution.

You’ll probably need to do the plot manually with matplotlib commands. I think your options are:

  1. a for-loop over zip(channels, colors) with a call to axes.plot() in the loop
  2. use a matplotlib LineCollection (see "Plotting multiple lines with a LineCollection" in the Matplotlib 3.7.1 documentation)
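
Both options can be sketched roughly as below. This is only an illustration under stated assumptions: the PSD values are synthetic random data (not MNE output), and the channel names and colors are made up for the example.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection

# Synthetic stand-in for PSD data (hypothetical values, NOT real MNE output)
rng = np.random.default_rng(0)
freqs = np.linspace(1, 50, 100)
channels = ["Fz", "Cz", "Pz", "Oz"]          # hypothetical channel names
colors = ["red", "green", "blue", "black"]   # one color per channel
psd_db = -10 * np.log10(freqs) + rng.normal(0, 1, (len(channels), freqs.size))

fig, (ax_loop, ax_lc) = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

# Option 1: for-loop over channels and colors, one ax.plot() call per channel
for chan, color, values in zip(channels, colors, psd_db):
    ax_loop.plot(freqs, values, color=color, label=chan)
ax_loop.legend()
ax_loop.set_xlabel("Frequency (Hz)")
ax_loop.set_ylabel("Power (dB)")

# Option 2: one LineCollection holding all channels;
# each segment is an (n_freqs, 2) array of (x, y) points
segments = [np.column_stack([freqs, values]) for values in psd_db]
collection = LineCollection(segments, colors=colors, linewidths=1.0)
ax_lc.add_collection(collection)
ax_lc.autoscale_view()  # collections do not rescale the axes on their own
ax_lc.set_xlabel("Frequency (Hz)")
```

With real data, the frequencies and power values would have to come from the spectrum object instead of the random arrays (probably something like spect.get_data(return_freqs=True), but please check the docs for your MNE version).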

I ended up testing two solutions:

  1. Directly modifying a library source file so that only the colors that are too light get darkened (the other MNE colors are kept)

File path: …\eenv\lib\site-packages\mne\viz\evoked.py (line 173)

def _rgb(x, y, z):
    """Transform x, y, z values into RGB colors."""
    rgb = np.array([x, y, z]).T
    rgb -= np.nanmin(rgb, 0)
    rgb /= np.maximum(np.nanmax(rgb, 0), 1e-16)  # avoid div by zero
    # Modification to avoid near-invisible colors:
    # darken the rows whose RGB values sum to more than 2.5 (very light colors)
    too_light = rgb.sum(axis=1) > 2.5
    rgb[too_light] -= 0.3
    return rgb
  2. As you suggested, modifying the colors with matplotlib after plotting (tested with 4 and 32 channels):
spect = raw.compute_psd(fmin=0, fmax=50, n_fft=int(raw.info['sfreq']*5))
fig = spect.plot()

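channels = raw.ch_names
# Repeat the 4-color palette so there is exactly one color per channel,
# even when the channel count is not a multiple of 4
palette = ['red', 'green', 'blue', 'black']
colors = [palette[i % len(palette)] for i in range(len(channels))]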

# Get PSD subplot from figure
fig_psd = fig.axes[0]
# Get Head with sensors subplot from figure
fig_head = fig.axes[1]
# Set sensor color on Head subplot (get the collection)
fig_head.get_children()[0].set_color(colors)
# Set sensor width on Head subplot (get the collection)
fig_head.get_children()[0].set_linewidth(5.0)

# Loop for each channel/color
for nl, (chan, color) in enumerate(zip(channels, colors)):
    # Get the channel line (use nl+2 because of axis)
    line = fig_psd.get_lines()[nl+2]
    # Set color to this line
    line.set_color(color)
    if nl<4:
        # Set color to head lines subplot just for fun
        fig_head.get_lines()[nl].set_color('green')
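
The two ideas can probably be combined so that no library file has to be edited: darken only the too-light lines after plotting. A minimal sketch, assuming a plain matplotlib axes with synthetic lines (not a real MNE figure). The helper name darken_light_lines is made up for this example, and the 2.5 threshold and 0.3 shift are simply taken from the patch above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgb


def darken_light_lines(ax, threshold=2.5, shift=0.3):
    """Darken every line in `ax` whose R + G + B sum exceeds `threshold`.

    Hypothetical helper (not part of MNE or matplotlib).
    Returns the number of lines that were changed.
    """
    changed = 0
    for line in ax.get_lines():
        rgb = np.array(to_rgb(line.get_color()))
        if rgb.sum() > threshold:
            # Subtract the shift from each component, staying inside [0, 1]
            line.set_color(tuple(np.clip(rgb - shift, 0.0, 1.0)))
            changed += 1
    return changed


# Demo on synthetic lines: one very light color, one dark color
fig, ax = plt.subplots()
x = np.linspace(1, 50, 100)
ax.plot(x, np.log10(x), color="lightcyan")  # too light, will be darkened
ax.plot(x, np.log10(x) + 1, color="navy")   # dark enough, left unchanged
n_changed = darken_light_lines(ax)
```

On an MNE figure this would be called on the PSD axes (fig.axes[0] in the code above). Note that it only recolors the lines, so the sensor colors on the head subplot would still need to be updated separately.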

Thanks for your help.