Exemplo n.º 1
0
def test_chpi_adjust():
    """Check cHPI debug logging and isotrak adjustment against MaxFilter."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield='yes')

    def _logged_lines():
        # Run the cHPI info extraction and return everything it logged,
        # split into individual lines.
        with catch_logging() as captured:
            _get_hpi_info(raw.info, adjust=True, verbose='debug')
        return captured.getvalue().splitlines()

    def _assert_same_lines(got, want):
        # Line order is ignored. On failure the message lists the expected
        # lines that were not logged.
        assert_true(set(got) == set(want),
                    '\n' + '\n'.join(set(want) - set(got)))

    first_log = _logged_lines()

    # Reference output from running MaxFilter (with -list, -v, -movecomp,
    # etc.) on the same file, split into the parts shared by both checks.
    preamble = [
        'HPIFIT: 5 coils digitized in order 5 1 4 3 2',
        'HPIFIT: 3 coils accepted: 1 2 4',
        'Hpi coil moments (3 5):',
        '2.08542e-15 -1.52486e-15 -1.53484e-15',
        '2.14516e-15 2.09608e-15 7.30303e-16',
        '-3.2318e-16 -4.25666e-16 2.69997e-15',
        '5.21717e-16 1.28406e-15 1.95335e-15',
        '1.21199e-15 -1.25801e-19 1.18321e-15',
    ]
    closing = [
        'HP fitting limits: err = 5.0 mm, gval = 0.980.',
        # this line actually came earlier in the MaxFilter output
        'Using 5 HPI coils: 83 143 203 263 323 Hz',
    ]
    expected = preamble + [
        'HPIFIT errors:  0.3, 0.3, 5.3, 0.4, 3.2 mm.',
        'HPI consistency of isotrak and hpifit is OK.',
    ] + closing
    _assert_same_lines(first_log, expected)

    # Perturb the third coordinate of one digitized point in place ...
    raw.info['dig'][5]['r'][2] += 1.
    # ... which, according to MaxFilter run on the modified file, replaces
    # the error/consistency lines with these:
    expected = preamble + [
        'HPIFIT errors:  0.3, 0.3, 5.3, 999.7, 3.2 mm.',
        'Note: HPI coil 3 isotrak is adjusted by 5.3 mm!',
        'Note: HPI coil 5 isotrak is adjusted by 3.2 mm!',
    ] + closing
    _assert_same_lines(_logged_lines(), expected)
Exemplo n.º 2
0
def test_chpi_adjust():
    """Verify the cHPI debug log, before and after moving a dig point."""
    raw = read_raw_fif(chpi_fif_fname, allow_maxshield='yes')
    with catch_logging() as log_stream:
        _get_hpi_info(raw.info, adjust=True, verbose='debug')
    observed = set(log_stream.getvalue().splitlines())

    # Lines reported by MaxFilter (run with -list, -v, -movecomp, etc.)
    reference = [
        'HPIFIT: 5 coils digitized in order 5 1 4 3 2',
        'HPIFIT: 3 coils accepted: 1 2 4',
        'Hpi coil moments (3 5):',
        '2.08542e-15 -1.52486e-15 -1.53484e-15',
        '2.14516e-15 2.09608e-15 7.30303e-16',
        '-3.2318e-16 -4.25666e-16 2.69997e-15',
        '5.21717e-16 1.28406e-15 1.95335e-15',
        '1.21199e-15 -1.25801e-19 1.18321e-15',
        'HPIFIT errors:  0.3, 0.3, 5.3, 0.4, 3.2 mm.',
        'HPI consistency of isotrak and hpifit is OK.',
        'HP fitting limits: err = 5.0 mm, gval = 0.980.',
        # the next line actually came earlier in the MaxFilter output
        'Using 5 HPI coils: 83 143 203 263 323 Hz',
    ]
    wanted = set(reference)
    # Compare as sets (order-insensitive); the failure message shows the
    # reference lines that are missing from the log.
    assert_true(observed == wanted, '\n' + '\n'.join(wanted - observed))

    # Shift the third coordinate of digitization point 5 in place
    raw.info['dig'][5]['r'][2] += 1.
    # MaxFilter on the modified file keeps the first eight and the last two
    # reference lines and reports these in between:
    adjusted_lines = [
        'HPIFIT errors:  0.3, 0.3, 5.3, 999.7, 3.2 mm.',
        'Note: HPI coil 3 isotrak is adjusted by 5.3 mm!',
        'Note: HPI coil 5 isotrak is adjusted by 3.2 mm!',
    ]
    reference = reference[:8] + adjusted_lines + reference[-2:]
    with catch_logging() as log_stream:
        _get_hpi_info(raw.info, adjust=True, verbose='debug')
    observed = set(log_stream.getvalue().splitlines())
    wanted = set(reference)
    assert_true(observed == wanted, '\n' + '\n'.join(wanted - observed))
Exemplo n.º 3
0
def test_simulate_raw_chpi():
    """Simulate raw data with and without cHPI and compare the results."""
    with warnings.catch_warnings(record=True):  # MaxShield
        raw = Raw(raw_chpi_fname, allow_maxshield=True)
    sphere = make_sphere_model('auto', 'auto', raw.info)
    # Sparse volume source space in the fitted sphere; centre and radius are
    # both scaled by 1000 before being passed on.
    center = tuple(sphere['r0'] * 1000.)
    src = setup_volume_source_space(
        'sample', sphere=center + (sphere.radius * 1000.,), pos=70.)
    stc = _make_stc(raw, src)

    # One simulation with cHPI disabled, one with cHPI enabled and head
    # positions read from ``pos_fname``.
    sim_args = (raw, stc, None, src, sphere)
    raw_sim = simulate_raw(*sim_args, cov=None, chpi=False)
    raw_chpi = simulate_raw(*sim_args, cov=None, chpi=True,
                            head_pos=pos_fname)

    # The spectra must share a frequency axis
    psd_sim, freqs_sim = compute_raw_psd(raw_sim)
    psd_chpi, freqs_chpi = compute_raw_psd(raw_chpi)
    assert_array_equal(freqs_sim, freqs_chpi)

    # PSD bins closest to each cHPI frequency
    hpi_freqs = _get_hpi_info(raw.info)[0]
    nearest_bins = [np.argmin(np.abs(freqs_sim - f)) for f in hpi_freqs]
    freq_idx = np.sort(nearest_bins)

    picks_meg = pick_types(raw.info, meg=True, eeg=False)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    # EEG spectra should be unaffected by cHPI ...
    assert_allclose(psd_sim[picks_eeg], psd_chpi[picks_eeg], atol=1e-20)
    # ... while MEG power at the cHPI bins must exceed 100x the baseline
    meg_with_chpi = psd_chpi[picks_meg][:, freq_idx]
    meg_without_chpi = psd_sim[picks_meg][:, freq_idx]
    assert_true((meg_with_chpi > 100 * meg_without_chpi).all())

    # Localization from the simulated cHPI signals must match the head
    # positions stored in the file (times shifted by the recording offset).
    est_trans, est_rot, est_t = _calculate_chpi_positions(raw_chpi)
    ref_trans, ref_rot, ref_t = get_chpi_positions(pos_fname)
    ref_t -= raw.first_samp / raw.info['sfreq']
    _compare_positions((ref_trans, ref_rot, ref_t),
                       (est_trans, est_rot, est_t),
                       max_dist=0.005)
Exemplo n.º 4
0
def test_simulate_raw_chpi():
    """Test simulation of raw data with cHPI.

    Simulates the same source estimate twice (cHPI off / cHPI on with head
    positions), then checks the cHPI status channel, the power spectra at
    the cHPI frequencies, and the head positions recovered from the
    simulated cHPI signals.
    """
    raw = read_raw_fif(raw_chpi_fname, allow_maxshield='yes')
    # Drop every fourth MEG/EEG channel (presumably to keep the test fast)
    all_idx = np.arange(len(raw.ch_names))
    keep = np.setdiff1d(all_idx,
                        pick_types(raw.info, meg=True, eeg=True)[::4])
    raw.load_data().pick_channels([raw.ch_names[pick] for pick in keep])
    raw.info.normalize_proj()
    sphere = make_sphere_model('auto', 'auto', raw.info)
    # make sparse spherical source space
    sphere_vol = tuple(sphere['r0'] * 1000.) + (sphere.radius * 1000., )
    src = setup_volume_source_space('sample', sphere=sphere_vol, pos=70.)
    stc = _make_stc(raw, src)
    # simulate data with cHPI off
    raw_sim = simulate_raw(raw,
                           stc,
                           None,
                           src,
                           sphere,
                           cov=None,
                           chpi=False,
                           interp='zero',
                           use_cps=True)
    # simulate data with cHPI on, using head positions from pos_fname
    raw_chpi = simulate_raw(raw,
                            stc,
                            None,
                            src,
                            sphere,
                            cov=None,
                            chpi=True,
                            head_pos=pos_fname,
                            interp='zero',
                            use_cps=True)
    # test cHPI indication
    hpi_freqs, hpi_pick, hpi_ons = _get_hpi_info(raw.info)
    assert_allclose(raw_sim[hpi_pick][0], 0.)
    assert_allclose(raw_chpi[hpi_pick][0], hpi_ons.sum())
    # test that the cHPI signals make some reasonable values
    picks_meg = pick_types(raw.info, meg=True, eeg=False)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)

    # The channel type is carried as an explicit flag: slicing creates new
    # array objects, so an identity check such as ``picks is picks_meg``
    # would never be true and the MEG assertion would silently be skipped.
    for picks, is_meg in ((picks_meg[:3], True), (picks_eeg[:3], False)):
        psd_sim, freqs_sim = psd_welch(raw_sim, picks=picks)
        psd_chpi, freqs_chpi = psd_welch(raw_chpi, picks=picks)

        assert_array_equal(freqs_sim, freqs_chpi)
        # PSD bins closest to each cHPI frequency
        freq_idx = np.sort(
            [np.argmin(np.abs(freqs_sim - f)) for f in hpi_freqs])
        if is_meg:
            # cHPI must dominate the MEG spectrum at its own frequencies
            assert_true(
                (psd_chpi[:, freq_idx] > 100 * psd_sim[:, freq_idx]).all())
        else:
            # EEG channels must be unaffected by cHPI
            assert_allclose(psd_sim, psd_chpi, atol=1e-20)

    # test localization based on cHPI information
    quats_sim = _calculate_chpi_positions(raw_chpi, t_step_min=10.)
    quats = read_head_pos(pos_fname)
    _assert_quats(quats, quats_sim, dist_tol=5e-3, angle_tol=3.5)
Exemplo n.º 5
0
def test_simulate_raw_chpi():
    """Exercise raw-data simulation with and without cHPI signals."""
    with warnings.catch_warnings(record=True):  # MaxShield
        raw = Raw(raw_chpi_fname, allow_maxshield=True)
    sphere = make_sphere_model('auto', 'auto', raw.info)
    # Sparse volume source space in the fitted sphere; centre and radius are
    # both scaled by 1000 before being passed on.
    center = tuple(sphere['r0'] * 1000.)
    src = setup_volume_source_space(
        'sample', sphere=center + (sphere.radius * 1000., ), pos=70.)
    stc = _make_stc(raw, src)

    # Same inputs for both runs: first without cHPI, then with cHPI and the
    # head positions from ``pos_fname``.
    sim_args = (raw, stc, None, src, sphere)
    raw_sim = simulate_raw(*sim_args, cov=None, chpi=False)
    raw_chpi = simulate_raw(*sim_args, cov=None, chpi=True,
                            head_pos=pos_fname)

    # The cHPI status channel is zero without cHPI and equals the "on" value
    # when cHPI is simulated.
    hpi_freqs, _, hpi_pick, hpi_on, _ = _get_hpi_info(raw.info)
    assert_allclose(raw_sim[hpi_pick][0], 0.)
    assert_allclose(raw_chpi[hpi_pick][0], hpi_on)

    # Spectral check: MEG channels gain power at the cHPI frequencies, EEG
    # channels are left untouched.
    picks_meg = pick_types(raw.info, meg=True, eeg=False)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    for picks, expect_chpi in ((picks_meg, True), (picks_eeg, False)):
        psd_sim, freqs_sim = psd_welch(raw_sim, picks=picks)
        psd_chpi, freqs_chpi = psd_welch(raw_chpi, picks=picks)
        assert_array_equal(freqs_sim, freqs_chpi)

        # PSD bins closest to each cHPI frequency
        nearest_bins = [np.argmin(np.abs(freqs_sim - f)) for f in hpi_freqs]
        freq_idx = np.sort(nearest_bins)
        if not expect_chpi:
            assert_allclose(psd_sim, psd_chpi, atol=1e-20)
        else:
            with_chpi = psd_chpi[:, freq_idx]
            without_chpi = psd_sim[:, freq_idx]
            assert_true((with_chpi > 100 * without_chpi).all())

    # Head positions estimated from the simulated cHPI signals must match
    # those stored in the file (times shifted by the recording offset).
    est_trans, est_rot, est_t = _calculate_chpi_positions(raw_chpi)
    ref_trans, ref_rot, ref_t = get_chpi_positions(pos_fname)
    ref_t -= raw.first_samp / raw.info['sfreq']
    _compare_positions((ref_trans, ref_rot, ref_t),
                       (est_trans, est_rot, est_t),
                       max_dist=0.005)
Exemplo n.º 6
0
def test_simulate_raw_chpi():
    """Test simulation of raw data with cHPI.

    Simulates the same source estimate with and without ``chpi=True``, then
    checks the cHPI status channel, the power spectra at the cHPI
    frequencies, and the head positions recovered from the simulated cHPI
    signals.
    """
    raw = read_raw_fif(raw_chpi_fname, allow_maxshield='yes')
    # Drop every fourth MEG/EEG channel (presumably to keep the test fast)
    all_idx = np.arange(len(raw.ch_names))
    keep = np.setdiff1d(all_idx,
                        pick_types(raw.info, meg=True, eeg=True)[::4])
    raw.load_data().pick_channels([raw.ch_names[pick] for pick in keep])
    raw.info.normalize_proj()
    sphere = make_sphere_model('auto', 'auto', raw.info)
    # make sparse spherical source space
    sphere_vol = tuple(sphere['r0'] * 1000.) + (sphere.radius * 1000.,)
    src = setup_volume_source_space(sphere=sphere_vol, pos=70.)
    stc = _make_stc(raw, src)
    # simulate data without requesting cHPI (call is expected to emit a
    # deprecation warning)
    with pytest.deprecated_call():
        raw_sim = simulate_raw(raw, stc, None, src, sphere, cov=None,
                               head_pos=pos_fname, interp='zero')
    # simulate data with cHPI on (also expected to warn about deprecation)
    with pytest.deprecated_call():
        raw_chpi = simulate_raw(raw, stc, None, src, sphere, cov=None,
                                chpi=True, head_pos=pos_fname, interp='zero')
    # test cHPI indication
    hpi_freqs, hpi_pick, hpi_ons = _get_hpi_info(raw.info)
    assert_allclose(raw_sim[hpi_pick][0], 0.)
    assert_allclose(raw_chpi[hpi_pick][0], hpi_ons.sum())
    # test that the cHPI signals make some reasonable values
    picks_meg = pick_types(raw.info, meg=True, eeg=False)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)

    # The channel type is carried as an explicit flag: slicing creates new
    # array objects, so an identity check such as ``picks is picks_meg``
    # would never be true and the MEG assertion would silently be skipped.
    for picks, is_meg in ((picks_meg[:3], True), (picks_eeg[:3], False)):
        psd_sim, freqs_sim = psd_welch(raw_sim, picks=picks)
        psd_chpi, freqs_chpi = psd_welch(raw_chpi, picks=picks)

        assert_array_equal(freqs_sim, freqs_chpi)
        # PSD bins closest to each cHPI frequency
        freq_idx = np.sort([np.argmin(np.abs(freqs_sim - f))
                            for f in hpi_freqs])
        if is_meg:
            # cHPI must dominate the MEG spectrum at its own frequencies
            assert (psd_chpi[:, freq_idx] >
                    100 * psd_sim[:, freq_idx]).all()
        else:
            # EEG channels must be unaffected by cHPI
            assert_allclose(psd_sim, psd_chpi, atol=1e-20)

    # test localization based on cHPI information
    quats_sim = _calculate_chpi_positions(raw_chpi, t_step_min=10.)
    quats = read_head_pos(pos_fname)
    _assert_quats(quats, quats_sim, dist_tol=5e-3, angle_tol=3.5)
Exemplo n.º 7
0
def test_simulate_raw_chpi():
    """Exercise raw-data simulation with and without cHPI signals."""
    raw = read_raw_fif(raw_chpi_fname, allow_maxshield='yes',
                       add_eeg_ref=False)
    sphere = make_sphere_model('auto', 'auto', raw.info)
    # Sparse volume source space in the fitted sphere; centre and radius are
    # both scaled by 1000 before being passed on.
    center = tuple(sphere['r0'] * 1000.)
    src = setup_volume_source_space(
        'sample', sphere=center + (sphere.radius * 1000.,), pos=70.)
    stc = _make_stc(raw, src)

    # Same inputs for both runs: first without cHPI, then with cHPI and the
    # head positions from ``pos_fname``.
    sim_args = (raw, stc, None, src, sphere)
    raw_sim = simulate_raw(*sim_args, cov=None, chpi=False)
    raw_chpi = simulate_raw(*sim_args, cov=None, chpi=True,
                            head_pos=pos_fname)

    # The cHPI status channel is zero without cHPI and equals the sum of the
    # per-coil "on" values when cHPI is simulated.
    hpi_freqs, _, hpi_pick, hpi_ons = _get_hpi_info(raw.info)[:4]
    assert_allclose(raw_sim[hpi_pick][0], 0.)
    assert_allclose(raw_chpi[hpi_pick][0], hpi_ons.sum())

    # Spectral check: MEG channels gain power at the cHPI frequencies, EEG
    # channels are left untouched.
    picks_meg = pick_types(raw.info, meg=True, eeg=False)
    picks_eeg = pick_types(raw.info, meg=False, eeg=True)
    for picks, expect_chpi in ((picks_meg, True), (picks_eeg, False)):
        psd_sim, freqs_sim = psd_welch(raw_sim, picks=picks)
        psd_chpi, freqs_chpi = psd_welch(raw_chpi, picks=picks)
        assert_array_equal(freqs_sim, freqs_chpi)

        # PSD bins closest to each cHPI frequency
        nearest_bins = [np.argmin(np.abs(freqs_sim - f)) for f in hpi_freqs]
        freq_idx = np.sort(nearest_bins)
        if not expect_chpi:
            assert_allclose(psd_sim, psd_chpi, atol=1e-20)
        else:
            with_chpi = psd_chpi[:, freq_idx]
            without_chpi = psd_sim[:, freq_idx]
            assert_true((with_chpi > 100 * without_chpi).all())

    # Head positions estimated from the simulated cHPI signals must match
    # those stored in the file (times shifted by the recording offset).
    quats_sim = _calculate_chpi_positions(raw_chpi)
    est_trans, est_rot, est_t = head_pos_to_trans_rot_t(quats_sim)
    quats_ref = read_head_pos(pos_fname)
    ref_trans, ref_rot, ref_t = head_pos_to_trans_rot_t(quats_ref)
    ref_t -= raw.first_samp / raw.info['sfreq']
    _compare_positions((ref_trans, ref_rot, ref_t),
                       (est_trans, est_rot, est_t),
                       max_dist=0.005)
Exemplo n.º 8
0
def _style_time_axis(ax, tvec, title, ylabel, xlabel=None, label_fontsize=10,
                     title_fontsize=10, tick_fontsize=10):
    """Apply the shared x-limits, labels, title and tick styling to *ax*."""
    ax.set_xlim([tvec.min(), tvec.max()])
    if xlabel is None:
        ax.set(ylabel=ylabel)
    else:
        ax.set(ylabel=ylabel, xlabel=xlabel)
        ax.xaxis.label.set_fontsize(label_fontsize)
    ax.yaxis.label.set_fontsize(label_fontsize)
    ax.set_title(title, fontsize=title_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)


def _shrink_axis_width(ax, fraction=0.8):
    """Narrow *ax* to *fraction* of its width, leaving room on the right."""
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * fraction, box.height])


def _add_sorted_legend(ax, lines, median_line, data, labels, leg_kwargs):
    """Add a legend with entries ordered by descending row mean of *data*."""
    order = np.argsort(data.mean(axis=1))[::-1]
    handles = [lines[i] for i in order]
    handles.append(median_line)
    names = [labels[i] for i in order]
    names.append('Median')
    ax.legend(handles, names, **leg_kwargs)


def plot_chpi_snr_raw(raw,
                      win_length,
                      n_harmonics=None,
                      show=True,
                      verbose=True):
    """Compute and plot cHPI SNR from raw data.

    Parameters
    ----------
    raw : object
        Raw data object (e.g. an ``mne.io.Raw`` instance). It must provide
        ``info`` (with ``'sfreq'``, ``'line_freq'`` and ``'lowpass'``),
        ``n_times`` and ``raw[picks, start:stop]`` indexing.
    win_length : float
        Length of window to use for SNR estimates (seconds). A longer window
        will naturally include more low frequency power, resulting in lower
        SNR.
    n_harmonics : int or None
        Number of line frequency harmonics to include in the model, in
        addition to the line frequency itself. If None, use all harmonics up
        to the MEG analog lowpass corner.
    show : bool
        Show figure if True.
    verbose : bool
        If True, print the model settings and the per-window progress.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        cHPI SNR as function of time, residual variance.

    Raises
    ------
    ValueError
        If the line frequency is not set in ``raw.info``, if the window
        length is not positive, or if the data are shorter than one window.

    Notes
    -----
    A general linear model including cHPI and line frequencies is fit into
    each data window. The cHPI power obtained from the model is then divided
    by the residual variance (variance of signal unexplained by the model) to
    obtain the SNR.

    Only complete windows are analyzed; trailing samples that do not fill a
    whole window are ignored.

    The SNR may decrease either due to decrease of cHPI amplitudes (e.g.
    head moving away from the helmet), or due to increase in the residual
    variance. In case of broadband interference that overlaps with the cHPI
    frequencies, the resulting decreased SNR accurately reflects the true
    situation. However, increased narrowband interference outside the cHPI
    and line frequencies would also cause an increase in the residual variance,
    even though it wouldn't necessarily affect estimation of the cHPI
    amplitudes. Thus, this method is intended for a rough overview of cHPI
    signal quality. A more accurate picture of cHPI quality (at an increased
    computational cost) can be obtained by examining the goodness-of-fit of
    the cHPI coil fits.
    """
    import matplotlib.pyplot as plt
    from mne.chpi import _get_hpi_info

    # plotting parameters
    legend_fontsize = 6

    # get some info from fiff
    sfreq = raw.info['sfreq']
    linefreq = raw.info['line_freq']
    if linefreq is None:
        # fail with a clear message instead of an opaque TypeError below
        raise ValueError("Line frequency is not set in raw.info['line_freq']")
    if n_harmonics is not None:
        # fundamental plus n_harmonics multiples of it
        linefreqs = (np.arange(n_harmonics + 1) + 1) * linefreq
    else:
        linefreqs = np.arange(linefreq, raw.info['lowpass'], linefreq)
    buflen = int(win_length * sfreq)
    if buflen <= 0:
        raise ValueError('Window length should be >0')
    # start samples of all complete windows; a final window that exactly
    # fits the data is kept, a partial one is dropped to avoid overrun
    bufs = np.arange(0, raw.n_times - buflen + 1, buflen)
    if len(bufs) == 0:
        raise ValueError('Window length (%s samples) exceeds the data length '
                         '(%s samples)' % (buflen, raw.n_times))
    cfreqs = _get_hpi_info(raw.info, verbose=False)[0]
    if verbose:
        print('Nominal cHPI frequencies: %s Hz' % cfreqs)
        print('Sampling frequency: %s Hz' % sfreq)
        print('Using line freqs: %s Hz' % linefreqs)
        print('Using buffers of %s samples = %s seconds\n' %
              (buflen, buflen / sfreq))

    pick_meg = pick_types(raw.info, meg=True, exclude=[])
    pick_mag = pick_types(raw.info, meg='mag', exclude=[])
    pick_grad = pick_types(raw.info, meg='grad', exclude=[])
    nchan = len(pick_meg)
    # grad and mag indices into an array that already has meg channels only
    pick_mag_ = np.in1d(pick_meg, pick_mag).nonzero()[0]
    pick_grad_ = np.in1d(pick_meg, pick_grad).nonzero()[0]

    # create general linear model for the data:
    # column 0 = linear trend, column 1 = constant, then one cosine and one
    # sine column per frequency (line frequencies first, then cHPI)
    t = np.arange(buflen) / float(sfreq)
    model = np.empty((len(t), 2 + 2 * (len(linefreqs) + len(cfreqs))))
    model[:, 0] = t
    model[:, 1] = np.ones(t.shape)
    allfreqs = np.concatenate([linefreqs, cfreqs])
    model[:, 2::2] = np.cos(2 * np.pi * t[:, np.newaxis] * allfreqs)
    model[:, 3::2] = np.sin(2 * np.pi * t[:, np.newaxis] * allfreqs)
    inv_model = linalg.pinv(model)

    tvec = bufs / sfreq
    snr_avg_grad = np.zeros([len(cfreqs), len(bufs)])
    hpi_pow_grad = np.zeros([len(cfreqs), len(bufs)])
    snr_avg_mag = np.zeros([len(cfreqs), len(bufs)])
    resid_vars = np.zeros([nchan, len(bufs)])
    for ind, buf0 in enumerate(bufs):
        if verbose:
            print('Buffer %s/%s' % (ind + 1, len(bufs)))
        megbuf = raw[pick_meg, buf0:buf0 + buflen][0].T
        coeffs = np.dot(inv_model, megbuf)
        # skip trend, constant and line-frequency terms
        coeffs_hpi = coeffs[2 + 2 * len(linefreqs):]
        resid_vars[:, ind] = np.var(megbuf - np.dot(model, coeffs), 0)
        # get total power by combining sine and cosine terms
        # sinusoidal of amplitude A has power of A**2/2
        hpi_pow = (coeffs_hpi[0::2, :]**2 + coeffs_hpi[1::2, :]**2) / 2
        hpi_pow_grad[:, ind] = hpi_pow[:, pick_grad_].mean(1)
        # divide average HPI power by average variance
        snr_avg_grad[:, ind] = (hpi_pow_grad[:, ind] /
                                resid_vars[pick_grad_, ind].mean())
        snr_avg_mag[:, ind] = (hpi_pow[:, pick_mag_].mean(1) /
                               resid_vars[pick_mag_, ind].mean())

    cfreqs_legend = ['%s Hz' % fre for fre in cfreqs]
    fig, axs = plt.subplots(4, 1, sharex=True)
    median_style = dict(lw=2, ls=':', color='k')

    # SNR plots for gradiometers and magnetometers; the median curve is the
    # median across cHPI frequencies of the linear SNR, converted to dB
    ax = axs[0]
    lines1 = ax.plot(tvec, 10 * np.log10(snr_avg_grad.T))
    lines1_med = ax.plot(tvec,
                         10 * np.log10(np.median(snr_avg_grad, axis=0)),
                         **median_style)
    _style_time_axis(
        ax, tvec, 'Mean cHPI power / mean residual variance, gradiometers',
        'SNR (dB)')
    ax = axs[1]
    lines2 = ax.plot(tvec, 10 * np.log10(snr_avg_mag.T))
    lines2_med = ax.plot(tvec,
                         10 * np.log10(np.median(snr_avg_mag, axis=0)),
                         **median_style)
    _style_time_axis(
        ax, tvec, 'Mean cHPI power / mean residual variance, magnetometers',
        'SNR (dB)')
    # cHPI power for gradiometers
    ax = axs[2]
    lines3 = ax.plot(tvec, hpi_pow_grad.T)
    lines3_med = ax.plot(tvec, np.median(hpi_pow_grad, axis=0),
                         **median_style)
    _style_time_axis(ax, tvec, 'Mean cHPI power, gradiometers',
                     'Power (T/m)$^2$')
    # residual (unexplained) variance as function of time
    ax = axs[3]
    cls = plt.get_cmap('plasma')(np.linspace(0., 0.7, len(pick_meg)))
    ax.set_prop_cycle(color=cls)
    ax.semilogy(tvec, resid_vars[pick_grad_, :].T, alpha=.4)
    _style_time_axis(
        ax, tvec, 'Residual (unexplained) variance, all gradiometer channels',
        'Var. (T/m)$^2$', xlabel='Time (s)')
    tight_layout(pad=.5, w_pad=.1, h_pad=.2)  # from mne.viz

    # tight_layout will screw these up, so the axes are narrowed and the
    # legends added only afterwards; curve legends are ordered according to
    # the mean of the data
    leg_kwargs = dict(
        prop={'size': legend_fontsize},
        bbox_to_anchor=(1.02, 0.5),
        loc='center left',
        borderpad=1,
        handlelength=1,
    )
    legend_specs = (
        (axs[0], lines1, lines1_med[0], snr_avg_grad),
        (axs[1], lines2, lines2_med[0], snr_avg_mag),
        (axs[2], lines3, lines3_med[0], hpi_pow_grad),
    )
    for ax, lines, median_line, data in legend_specs:
        _shrink_axis_width(ax)
        _add_sorted_legend(ax, lines, median_line, data, cfreqs_legend,
                           leg_kwargs)
    # keep the last axis aligned with the others (it has no legend)
    _shrink_axis_width(axs[3])
    if show:
        plt.show()

    return fig