Advantages of epochs.average().compute_psd() vs. epochs.compute_psd().average()

  • MNE version 1.3.1
  • Ubuntu 20.04 running in WSL vscode Jupyter

Hello! I am experimenting with spectral analysis and was wondering if anyone could provide some insight into the advantages of averaging across epochs and then computing PSD vs. computing PSD then averaging across epochs.


The two resulting spectra have roughly the same shape, but the one computed by averaging across epochs first looks much noisier than the one computed by taking the PSD first, and I am curious about the advantages of each method.

Thanks!
Jacob


Hi,

Computing the PSD of each epoch and then averaging (epochs.compute_psd().average()) answers the question “Is there a certain frequency content (here, power) present across (most of) the trials?”, whereas averaging first (epochs.average().compute_psd()) answers “What is the frequency content of the average signal across trials?”. Generally, I’d say the former is the more interesting question most of the time.
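To make the two questions concrete, here is a minimal NumPy-only sketch (not MNE code) that simulates hypothetical trials, each containing a 20 Hz sinusoid with a random phase plus some noise. It compares the mean of the per-trial power spectra (the analogue of epochs.compute_psd().average()) with the power spectrum of the trial average (the analogue of epochs.average().compute_psd()). The plain, untapered periodogram and all parameter values are assumptions for illustration; MNE's compute_psd() uses Welch or multitaper estimators instead.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the result is reproducible

sfreq = 1000.0    # hypothetical sampling rate in Hz
n_times = 1000    # 1 s per trial -> 1 Hz frequency resolution
n_trials = 100
f = 20.0          # oscillation frequency in Hz
t = np.arange(n_times) / sfreq

# Each trial: the same 20 Hz oscillation with a random phase
# (i.e. not phase-locked across trials), plus Gaussian noise
phases = rng.uniform(0, 2 * np.pi, size=n_trials)
trials = np.sin(2 * np.pi * f * t + phases[:, None])
trials += 0.5 * rng.standard_normal(trials.shape)


def power(x):
    """Simple periodogram-style power |FFT|^2 / N along the last axis (no taper)."""
    return np.abs(np.fft.rfft(x, axis=-1)) ** 2 / x.shape[-1]


freqs = np.fft.rfftfreq(n_times, d=1 / sfreq)

psd_then_avg = power(trials).mean(axis=0)  # power of each trial, then average
avg_then_psd = power(trials.mean(axis=0))  # average the trials, then power

idx = np.argmin(np.abs(freqs - f))  # frequency bin closest to 20 Hz
print(f"{freqs[idx]:.0f} Hz | PSD then average: {psd_then_avg[idx]:.1f} "
      f"| average then PSD: {avg_then_psd[idx]:.3f}")
```

The 20 Hz peak is clearly present when the PSD is computed per trial and then averaged, but it largely disappears when the trials are averaged first, because oscillations with random phases cancel in the average. A component with the same phase on every trial (phase-locked) would survive both computations.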

Consider this extremely simplified example with only two signals: 20 Hz sinusoids with a phase shift of π between them. The Fourier transform (not exactly what compute_psd() computes) is a linear transformation, so averaging the signals and then transforming gives the same complex spectrum as transforming each signal and then averaging. The FT power (the squared magnitude; the code below plots the magnitude, np.abs) is not linear, however. Hence, the power of each signal considered separately (and therefore the average of the two) shows a 20 Hz peak, but averaging the two signals first results in a signal with no power at all, precisely because the phase shift makes them cancel out.
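The argument can be written out in one line. Writing X_1 and X_2 for the Fourier transforms of s_1 and s_2 (f1 and f2 in the code), linearity gives the first equality, while the squared magnitude (power) is in general not linear:

```latex
\mathcal{F}\left\{\tfrac{1}{2}(s_1 + s_2)\right\} = \tfrac{1}{2}(X_1 + X_2),
\qquad
\left|\tfrac{1}{2}(X_1 + X_2)\right|^2 \neq \tfrac{1}{2}\left(|X_1|^2 + |X_2|^2\right).
```

With s_2(t) = sin(ωt + π) = -s_1(t) we get X_2 = -X_1, so the left-hand side of the inequality is 0 at every frequency, while the right-hand side equals |X_1|^2, which peaks at 20 Hz.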

Here is a sample code snippet to try for yourself (sorry I can’t upload the figure for some reason):

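import numpy as np
import matplotlib.pyplot as plt

T = 5        # duration in seconds
dt = 0.001   # sampling interval in seconds (1 kHz sampling rate)
# endpoint=False keeps the sample spacing exactly dt
time = np.linspace(0, T, int(round(T / dt)), endpoint=False)

f = 20               # sinusoid frequency in Hz
w = 2. * np.pi * f   # angular frequency

s1 = np.sin(w * time)
s2 = np.sin(w * time + np.pi)              # same sinusoid, phase-shifted by pi
s3 = np.mean(np.vstack((s1, s2)), axis=0)  # average of the two signals (~0 everywhere)

f1 = np.fft.fft(s1)
f2 = np.fft.fft(s2)
f3 = np.fft.fft(s3)

# Frequency axis in Hz (without d=dt the unit would be cycles/sample)
freq = np.fft.fftfreq(time.shape[-1], d=dt)

fig = plt.figure(figsize=(15, 15))
plt.subplot(4, 1, 1)
plt.plot(time, s1, time, s2, time, s3)
plt.title('Signals s1, s2 and their average s3')
plt.subplot(4, 1, 2)
plt.plot(freq, f1.real, freq, f2.real, freq, f3.real)
plt.title('Real part of the FFT')
plt.subplot(4, 1, 3)
plt.plot(freq, f1.imag, freq, f2.imag, freq, f3.imag)
plt.title('Imaginary part of the FFT')
plt.subplot(4, 1, 4)
# Magnitude first, then average vs. average first, then magnitude
plt.plot(freq, np.mean((np.abs(f1), np.abs(f2)), axis=0), freq, np.abs(f3))
plt.title('Mean of |FFT| vs. |FFT| of the mean')
plt.xlabel('Frequency (Hz)')
plt.tight_layout()
plt.show()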

Hope this helps!
