def noise_reducer_4raw_data(fname_raw,
                            raw=None,
                            signals=[],
                            noiseref=[],
                            detrending=None,
                            tmin=None,
                            tmax=None,
                            reflp=None,
                            refhp=None,
                            refnotch=None,
                            exclude_artifacts=True,
                            checkresults=True,
                            fif_extention="-raw.fif",
                            fif_postfix="nr",
                            reject={
                                'grad': 4000e-13,
                                'mag': 4e-12,
                                'eeg': 40e-6,
                                'eog': 250e-6
                            },
                            complementary_signal=False,
                            fnout=None,
                            verbose=False,
                            save=True):
    """Apply noise reduction to signal channels using reference channels.
        
       !!! ONLY ONE RAW Obj Interface Version FB !!!
           
    Parameters
    ----------
    fname_raw : rawfile name

    raw     : fif raw object

    signals : list of string
              List of channels to compensate using noiseref.
              If empty use the meg signal channels.
    noiseref : list of string | str
              List of channels to use as noise reference.
              If empty use the magnetic reference channsls (default).
    signals and noiseref may contain regexp, which are resolved
    using mne.pick_channels_regexp(). All other channels are copied.
    tmin : lower latency bound for weight-calc [start of trace]
    tmax : upper latency bound for weight-calc [ end  of trace]
           Weights are calc'd for (tmin,tmax), but applied to entire data set
    refhp : high-pass frequency for reference signal filter [None]
    reflp :  low-pass frequency for reference signal filter [None]
            reflp < refhp: band-stop filter
            reflp > refhp: band-pass filter
            reflp is not None, refhp is None: low-pass filter
            reflp is None, refhp is not None: high-pass filter
    refnotch : (base) notch frequency for reference signal filter [None]
               use raw(ref)-notched(ref) as reference signal
    exclude_artifacts: filter signal-channels thru _is_good() [True]
                       (parameters are at present hard-coded!)
    complementary_signal : replaced signal by traces that would be subtracted [False]
                           (can be useful for debugging)
    checkresults : boolean to control internal checks and overall success [True]

    reject =  dict for rejection threshold 
              units:
              grad:    T / m (gradiometers)
              mag:     T (magnetometers)
              eeg/eog: uV (EEG channels)
              default=>{'grad':4000e-13,'mag':4e-12,'eeg':40e-6,'eog':250e-6}
              
    save : save data to fif file

    Outputfile:
    -------
    <wawa>,nr-raw.fif for input <wawa>-raw.fif

    Returns
    -------
    TBD

    Bugs
    ----
    - artifact checking is incomplete (and with arb. window of tstep=0.2s)
    - no accounting of channels used as signal/reference
    - non existing input file handled ungracefully
    """

    tc0 = time.clock()
    tw0 = time.time()

    if type(complementary_signal) != bool:
        raise ValueError("Argument complementary_signal must be of type bool")

    raw, fname_raw = jumeg_base.get_raw_obj(fname_raw, raw=raw)

    if detrending:
        raw = perform_detrending(None, raw=raw, save=False)

    tc1 = time.clock()
    tw1 = time.time()

    if verbose:
        print ">>> loading raw data took %.1f ms (%.2f s walltime)" % (
            1000. * (tc1 - tc0), (tw1 - tw0))

    # Time window selection
    # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
    # tstep is used in artifact detection
    # tmin,tmax variables must not be changed here!
    if tmin is None:
        itmin = 0
    else:
        itmin = int(floor(tmin * raw.info['sfreq']))
    if tmax is None:
        itmax = raw.last_samp
    else:
        itmax = int(ceil(tmax * raw.info['sfreq']))

    if itmax - itmin < 2:
        raise ValueError(
            "Time-window for noise compensation empty or too short")

    if verbose:
        print ">>> Set time-range to [%7.3f,%7.3f]" % \
              (raw.index_as_time(itmin)[0], raw.index_as_time(itmax)[0])

    if signals is None or len(signals) == 0:
        sigpick = jumeg_base.pick_meg_nobads(raw)
    else:
        sigpick = channel_indices_from_list(raw.info['ch_names'][:], signals,
                                            raw.info.get('bads'))
    nsig = len(sigpick)
    if nsig == 0:
        raise ValueError("No channel selected for noise compensation")

    if noiseref is None or len(noiseref) == 0:
        # References are not limited to 4D ref-chans, but can be anything,
        # incl. ECG or powerline monitor.
        if verbose:
            print ">>> Using all refchans."

        refexclude = "bads"
        refpick = jumeg_base.pick_ref_nobads(raw)
    else:
        refpick = channel_indices_from_list(raw.info['ch_names'][:], noiseref,
                                            raw.info.get('bads'))
    nref = len(refpick)
    if nref == 0:
        raise ValueError("No channel selected as noise reference")

    if verbose:
        print ">>> sigpick: %3d chans, refpick: %3d chans" % (nsig, nref)

    if reflp is None and refhp is None and refnotch is None:
        use_reffilter = False
        use_refantinotch = False
    else:
        use_reffilter = True
        if verbose:
            print "########## Filter reference channels:"

        use_refantinotch = False
        if refnotch is not None:
            if reflp is None and reflp is None:
                use_refantinotch = True
                freqlast = np.min([5.01 * refnotch, 0.5 * raw.info['sfreq']])
                if verbose:
                    print ">>> notches at freq %.1f and harmonics below %.1f" % (
                        refnotch, freqlast)
            else:
                raise ValueError("Cannot specify notch- and high-/low-pass"
                                 "reference filter together")
        else:
            if verbose:
                if reflp is not None:
                    print ">>>  low-pass with cutoff-freq %.1f" % reflp
                if refhp is not None:
                    print ">>> high-pass with cutoff-freq %.1f" % refhp

        # Adapt followg drop-chans cmd to use 'all-but-refpick'
        droplist = [
            raw.info['ch_names'][k] for k in xrange(raw.info['nchan'])
            if not k in refpick
        ]
        tct = time.clock()
        twt = time.time()
        fltref = raw.drop_channels(droplist, copy=True)
        if use_refantinotch:
            rawref = raw.drop_channels(droplist, copy=True)
            freqlast = np.min([5.01 * refnotch, 0.5 * raw.info['sfreq']])
            fltref.notch_filter(np.arange(refnotch, freqlast, refnotch),
                                picks=np.array(xrange(nref)),
                                method='iir')
            fltref._data = (rawref._data - fltref._data)
        else:
            fltref.filter(refhp,
                          reflp,
                          picks=np.array(xrange(nref)),
                          method='iir')
        tc1 = time.clock()
        tw1 = time.time()
        if verbose:
            print ">>> filtering ref-chans  took %.1f ms (%.2f s walltime)" % (
                1000. * (tc1 - tct), (tw1 - twt))

    if verbose:
        print "########## Calculating sig-ref/ref-ref-channel covariances:"
    # Calculate sig-ref/ref-ref-channel covariance:
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriat fct available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and infosig entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.

    #--- !!! FB put to kwargs

    #reject = dict(grad=4000e-13, # T / m (gradiometers)
    #              mag=4e-12,     # T (magnetometers)
    #              eeg=40e-6,     # uV (EEG channels)
    #              eog=250e-6)    # uV (EOG channels)

    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # Read data in chunks:
    tstep = 0.2
    itstep = int(ceil(tstep * raw.info['sfreq']))
    sigmean = 0
    refmean = 0
    sscovdata = 0
    srcovdata = 0
    rrcovdata = 0
    n_samples = 0

    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]

        if not exclude_artifacts or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
                    ignore_chs=raw.info['bads']):
            sigmean += raw_segmentsig.sum(axis=1)
            refmean += raw_segmentref.sum(axis=1)
            sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
            srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
            rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))
    if n_samples <= 1:
        raise ValueError('Too few samples to calculate weights')
    sigmean /= n_samples
    refmean /= n_samples
    sscovdata -= n_samples * sigmean[:] * sigmean[:]
    sscovdata /= (n_samples - 1)
    srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
    srcovdata /= (n_samples - 1)
    rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
    rrcovdata /= (n_samples - 1)
    sscovinit = np.copy(sscovdata)
    if verbose:
        print ">>> Normalize srcov..."

    rrslope = copy.copy(rrcovdata)
    for iref in xrange(nref):
        dtmp = rrcovdata[iref, iref]
        if dtmp > TINY:
            srcovdata[:, iref] /= dtmp
            rrslope[:, iref] /= dtmp
        else:
            srcovdata[:, iref] = 0.
            rrslope[:, iref] = 0.

    if verbose:
        print ">>> Number of samples used : %d" % n_samples
        tc1 = time.clock()
        tw1 = time.time()
        print ">>> sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (
            1000. * (tc1 - tct), (tw1 - twt))

    if checkresults:
        if verbose:
            print "########## Calculated initial signal channel covariance:"
            # Calculate initial signal channel covariance:
            # (only used as quality measure)
            print ">>> initl rt(avg sig pwr) = %12.5e" % np.sqrt(
                np.mean(sscovdata))
            for i in xrange(5):
                print ">>> initl signal-rms[%3d] = %12.5e" % (
                    i, np.sqrt(sscovdata.flatten()[i]))
            print ">>>"

    U, s, V = np.linalg.svd(rrslope, full_matrices=True)
    if verbose:
        print ">>> singular values:"
        print s
        print ">>> Applying cutoff for smallest SVs:"

    dtmp = s.max() * SVD_RELCUTOFF
    s *= (abs(s) >= dtmp)
    sinv = [1. / s[k] if s[k] != 0. else 0. for k in xrange(nref)]
    if verbose:
        print ">>> singular values (after cutoff):"
        print s

    stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
    if verbose:
        print ">>> Testing svd-result: %s" % stat
        if not stat:
            print "    (Maybe due to SV-cutoff?)"

    # Solve for inverse coefficients:
    # Set RRinv.tr=U diag(sinv) V
    RRinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
    if checkresults:
        stat = np.allclose(np.identity(nref), np.dot(RRinv, rrslope))
        if stat:
            if verbose:
                print ">>> Testing RRinv-result (should be unit-matrix): ok"
        else:
            print ">>> Testing RRinv-result (should be unit-matrix): failed"
            print np.transpose(np.dot(RRinv, rrslope))
            print ">>>"

    if verbose:
        print "########## Calc weight matrix..."

    # weights-matrix will be somewhat larger than necessary,
    # (to simplify indexing in compensation loop):
    weights = np.zeros((raw._data.shape[0], nref))
    for isig in xrange(nsig):
        for iref in xrange(nref):
            weights[sigpick[isig], iref] = np.dot(srcovdata[isig, :],
                                                  RRinv[:, iref])

    if verbose:
        print "########## Compensating signal channels:"
        if complementary_signal:
            print ">>> Caveat: REPLACING signal by compensation signal"

    tct = time.clock()
    twt = time.time()

    # Work on entire data stream:
    for isl in xrange(raw._data.shape[1]):
        slice = np.take(raw._data, [isl], axis=1)
        if use_reffilter:
            refslice = np.take(fltref._data, [isl], axis=1)
            refarr = refslice[:].flatten() - refmean
            # refarr = fltres[:,isl]-refmean
        else:
            refarr = slice[refpick].flatten() - refmean
        subrefarr = np.dot(weights[:], refarr)

        if not complementary_signal:
            raw._data[:, isl] -= subrefarr
        else:
            raw._data[:, isl] = subrefarr

        if (isl % 10000 == 0) and verbose:
            print "\rProcessed slice %6d" % isl

    if verbose:
        print "\nDone."
        tc1 = time.clock()
        tw1 = time.time()
        print ">>> compensation loop took %.1f ms (%.2f s walltime)" % (
            1000. * (tc1 - tct), (tw1 - twt))

    if checkresults:
        if verbose:
            print "########## Calculating final signal channel covariance:"
        # Calculate final signal channel covariance:
        # (only used as quality measure)
        tct = time.clock()
        twt = time.time()
        sigmean = 0
        sscovdata = 0
        n_samples = 0
        for first in range(itmin, itmax, itstep):
            last = first + itstep
            if last >= itmax:
                last = itmax
            raw_segmentsig, times = raw[sigpick, first:last]
            # Artifacts found here will probably differ from pre-noisered artifacts!
            if not exclude_artifacts or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                        flat=None, ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                n_samples += raw_segmentsig.shape[1]
        sigmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)

        if verbose:
            print ">>> no channel got worse: ", np.all(
                np.less_equal(sscovdata, sscovinit))
            print ">>> final rt(avg sig pwr) = %12.5e" % np.sqrt(
                np.mean(sscovdata))
            for i in xrange(5):
                print ">>> final signal-rms[%3d] = %12.5e" % (
                    i, np.sqrt(sscovdata.flatten()[i]))
            tc1 = time.clock()
            tw1 = time.time()
            print ">>> signal covar-calc took %.1f ms (%.2f s walltime)" % (
                1000. * (tc1 - tct), (tw1 - twt))
            print ">>>"

#--- fb update 21.07.2015
    fname_out = jumeg_base.get_fif_name(raw=raw,
                                        postfix=fif_postfix,
                                        extention=fif_extention)

    if save:
        jumeg_base.apply_save_mne_data(raw, fname=fname_out, overwrite=True)

    tc1 = time.clock()
    tw1 = time.time()
    if verbose:
        print ">>> Total run took %.1f ms (%.2f s walltime)" % (1000. *
                                                                (tc1 - tc0),
                                                                (tw1 - tw0))

    return raw, fname_out
def test_noise_reducer():
    """Manual end-to-end check of the noise reducer on an empty-room file.

    Steps, in order:
      1. call noise_reducer() on the file (complementary_signal=True),
      2. reload the same file and repeat the noise reduction step by step
         with diagnostic output,
      3. accumulate the sig-ref / ref-ref covariances twice (plain sums and
         Welford's running update) and print whether both agree,
      4. compensate the signal channels using the Welford-based weights,
      5. save the result next to the input as '<name>,nold-raw.fif'.

    Requires the environment variables SUBJECTS_DIR and SUBJECT (KeyError
    otherwise). Nothing is returned and nothing is asserted; results are
    printed.
    """

    data_path = os.environ['SUBJECTS_DIR']
    # NOTE: 'subject' and 'subjects_dir' are assigned but not used below.
    subject = os.environ['SUBJECT']

    dname = data_path + '/' + 'empty_room_files' + '/109925_empty_room_file-raw.fif'
    subjects_dir = data_path + '/subjects'
    #
    # Settings for the step-by-step recomputation:
    checkresults = True
    exclart = False  # skip _is_good() artifact rejection
    use_reffilter = True
    refflt_lpfreq = 52.
    refflt_hpfreq = 48.

    print "########## before of noisereducer call ##########"
    sigchanlist = ['MEG ..1', 'MEG ..3', 'MEG ..5', 'MEG ..7', 'MEG ..9']
    # The regexp list above is overridden: None selects the default picks.
    sigchanlist = None
    refchanlist = ['RFM 001', 'RFM 003', 'RFM 005', 'RFG ...']
    tmin = 15.
    noise_reducer(dname,
                  signals=sigchanlist,
                  noiseref=refchanlist,
                  tmin=tmin,
                  reflp=refflt_lpfreq,
                  refhp=refflt_hpfreq,
                  exclude_artifacts=exclart,
                  complementary_signal=True)
    print "########## behind of noisereducer call ##########"

    print "########## Read raw data:"
    tc0 = time.clock()
    tw0 = time.time()
    raw = mne.io.Raw(dname, preload=True)
    tc1 = time.clock()
    tw1 = time.time()
    print "loading raw data  took %.1f ms (%.2f s walltime)" % (1000. *
                                                                (tc1 - tc0),
                                                                (tw1 - tw0))

    # Time window selection
    # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
    # tstep is used in artifact detection
    tmax = raw.index_as_time(raw.last_samp)[0]
    tstep = 0.2
    itmin = int(floor(tmin * raw.info['sfreq']))
    itmax = int(ceil(tmax * raw.info['sfreq']))
    itstep = int(ceil(tstep * raw.info['sfreq']))
    print ">>> Set time-range to [%7.3f,%7.3f]" % (tmin, tmax)

    # --- select signal channels
    if sigchanlist is None:
        sigpick = mne.pick_types(raw.info,
                                 meg='mag',
                                 eeg=False,
                                 stim=False,
                                 eog=False,
                                 exclude='bads')
    else:
        sigpick = channel_indices_from_list(raw.info['ch_names'][:],
                                            sigchanlist)
    nsig = len(sigpick)
    print "sigpick: %3d chans" % nsig
    if nsig == 0:
        raise ValueError("No channel selected for noise compensation")

    # --- select reference channels
    if refchanlist is None:
        # References are not limited to 4D ref-chans, but can be anything,
        # incl. ECG or powerline monitor.
        print ">>> Using all refchans."
        refexclude = "bads"
        refpick = mne.pick_types(raw.info,
                                 ref_meg=True,
                                 meg=False,
                                 eeg=False,
                                 stim=False,
                                 eog=False,
                                 exclude=refexclude)
    else:
        refpick = channel_indices_from_list(raw.info['ch_names'][:],
                                            refchanlist)
        print "refpick = '%s'" % refpick
    nref = len(refpick)
    print "refpick: %3d chans" % nref
    if nref == 0:
        raise ValueError("No channel selected as noise reference")

    print "########## Refchan geo data:"
    # This is just for info to locate special 4D-refs.
    for iref in refpick:
        print raw.info['chs'][iref]['ch_name'], raw.info['chs'][iref]['loc'][
            0:3]
    print ""

    if use_reffilter:
        print "########## Filter reference channels:"
        if refflt_lpfreq is not None:
            print " low-pass with cutoff-freq %.1f" % refflt_lpfreq
        if refflt_hpfreq is not None:
            print "high-pass with cutoff-freq %.1f" % refflt_hpfreq
        # Drop 'all-but-refpick' so fltref holds only the reference channels
        droplist = [
            raw.info['ch_names'][k] for k in xrange(raw.info['nchan'])
            if not k in refpick
        ]
        fltref = raw.drop_channels(droplist, copy=True)
        tct = time.clock()
        twt = time.time()
        fltref.filter(refflt_hpfreq,
                      refflt_lpfreq,
                      picks=np.array(xrange(nref)),
                      method='iir')
        tc1 = time.clock()
        tw1 = time.time()
        print "filtering ref-chans  took %.1f ms (%.2f s walltime)" % (
            1000. * (tc1 - tct), (tw1 - twt))

    print "########## Calculating sig-ref/ref-ref-channel covariances:"
    # Calculate sig-ref/ref-ref-channel covariance:
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriate fct available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and info{sig,ref} entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(
        grad=4000e-13,  # T / m (gradiometers)
        mag=4e-12,  # T (magnetometers)
        eeg=40e-6,  # uV (EEG channels)
        eog=250e-6)  # uV (EOG channels)

    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # inforef not good w/ filtering, but anyway useless
    # (idx_by_typeref is only referenced in commented-out code below)
    inforef = copy.copy(raw.info)
    inforef['chs'] = [raw.info['chs'][k] for k in refpick]
    inforef['ch_names'] = [raw.info['ch_names'][k] for k in refpick]
    inforef['nchan'] = len(refpick)
    idx_by_typeref = channel_indices_by_type(inforef)

    # Read data in chunks and accumulate plain sums ('standard' method):
    sigmean = 0
    refmean = 0
    sscovdata = 0
    srcovdata = 0
    rrcovdata = 0
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]
        # if True:
        # if _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
        #            ignore_chs=raw.info['bads']) and _is_good(raw_segmentref,
        #              inforef['ch_names'], idx_by_typeref, reject, flat=None,
        #                ignore_chs=raw.info['bads']):
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            sigmean += raw_segmentsig.sum(axis=1)
            refmean += raw_segmentref.sum(axis=1)
            sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
            srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
            rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    #_check_n_samples(n_samples, len(picks))
    # Convert sums to means and covariances (divisor n_samples - 1).
    # NOTE: no guard against n_samples <= 1 here.
    sigmean /= n_samples
    refmean /= n_samples
    sscovdata -= n_samples * sigmean[:] * sigmean[:]
    sscovdata /= (n_samples - 1)
    srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
    srcovdata /= (n_samples - 1)
    rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
    rrcovdata /= (n_samples - 1)
    # Same array object, not a copy; this is fine because sscovdata is
    # rebound (not modified in place) before the final comparison.
    sscovinit = sscovdata
    print "Normalize srcov..."
    # Scale each reference column by that reference's variance; columns of
    # references with variance <= TINY are zeroed.
    rrslopedata = copy.copy(rrcovdata)
    for iref in xrange(nref):
        dtmp = rrcovdata[iref][iref]
        if dtmp > TINY:
            for isig in xrange(nsig):
                srcovdata[isig][iref] /= dtmp
            for jref in xrange(nref):
                rrslopedata[jref][iref] /= dtmp
        else:
            for isig in xrange(nsig):
                srcovdata[isig][iref] = 0.
            for jref in xrange(nref):
                rrslopedata[jref][iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    tc1 = time.clock()
    tw1 = time.time()
    print "sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. *
                                                                   (tc1 - tct),
                                                                   (tw1 - twt))

    print "########## Calculating sig-ref/ref-ref-channel covariances (robust):"
    # Calculate sig-ref/ref-ref-channel covariance:
    # (using B.P.Welford, "Note on a method for calculating corrected sums
    #                   of squares and products", Technometrics4 (1962) 419-420)
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriate fct available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and info{sig,ref} entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    # (reject/infosig/inforef are rebuilt with the same values as above)
    reject = dict(
        grad=4000e-13,  # T / m (gradiometers)
        mag=4e-12,  # T (magnetometers)
        eeg=40e-6,  # uV (EEG channels)
        eog=250e-6)  # uV (EOG channels)

    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # inforef not good w/ filtering, but anyway useless
    inforef = copy.copy(raw.info)
    inforef['chs'] = [raw.info['chs'][k] for k in refpick]
    inforef['ch_names'] = [raw.info['ch_names'][k] for k in refpick]
    inforef['nchan'] = len(refpick)
    idx_by_typeref = channel_indices_by_type(inforef)

    # Read data in chunks; s*/r* hold running means (current and previous)
    # and the running sums of products for the Welford update:
    smean = np.zeros(nsig)
    smold = np.zeros(nsig)
    rmean = np.zeros(nref)
    rmold = np.zeros(nref)
    sscov = 0
    srcov = 0
    rrcov = np.zeros((nref, nref))
    srcov = np.zeros((nsig, nref))
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]
        # if True:
        # if _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
        #            ignore_chs=raw.info['bads']) and _is_good(raw_segmentref,
        #              inforef['ch_names'], idx_by_typeref, reject, flat=None,
        #                ignore_chs=raw.info['bads']):
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            # sample-by-sample running update
            for isl in xrange(raw_segmentsig.shape[1]):
                # nsl: 1-based count of samples including the current one
                nsl = isl + n_samples + 1
                cnslm1dnsl = float((nsl - 1)) / float(nsl)
                # deviations from the means of the previous samples
                sslsubmean = (raw_segmentsig[:, isl] - smold)
                rslsubmean = (raw_segmentref[:, isl] - rmold)
                smean = smold + sslsubmean / nsl
                rmean = rmold + rslsubmean / nsl
                sscov += sslsubmean * (raw_segmentsig[:, isl] - smean)
                srcov += cnslm1dnsl * np.dot(sslsubmean.reshape(
                    (nsig, 1)), rslsubmean.reshape((1, nref)))
                rrcov += cnslm1dnsl * np.dot(rslsubmean.reshape(
                    (nref, 1)), rslsubmean.reshape((1, nref)))
                smold = smean
                rmold = rmean
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    #_check_n_samples(n_samples, len(picks))
    sscov /= (n_samples - 1)
    srcov /= (n_samples - 1)
    rrcov /= (n_samples - 1)
    print "Normalize srcov..."
    rrslope = copy.copy(rrcov)
    for iref in xrange(nref):
        dtmp = rrcov[iref][iref]
        if dtmp > TINY:
            srcov[:, iref] /= dtmp
            rrslope[:, iref] /= dtmp
        else:
            srcov[:, iref] = 0.
            rrslope[:, iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    # Cross-check the Welford results against the plain-sum results
    print "Compare results with 'standard' values:"
    print "cmp(sigmean,smean):", np.allclose(smean, sigmean, atol=0.)
    print "cmp(refmean,rmean):", np.allclose(rmean, refmean, atol=0.)
    print "cmp(sscovdata,sscov):", np.allclose(sscov, sscovdata, atol=0.)
    print "cmp(srcovdata,srcov):", np.allclose(srcov, srcovdata, atol=0.)
    print "cmp(rrcovdata,rrcov):", np.allclose(rrcov, rrcovdata, atol=0.)
    tc1 = time.clock()
    tw1 = time.time()
    print "sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. *
                                                                   (tc1 - tct),
                                                                   (tw1 - twt))

    if checkresults:
        print "########## Calculated initial signal channel covariance:"
        # Calculate initial signal channel covariance:
        # (only used as quality measure)
        print "initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscov))
        # NOTE: indexes the first five entries; needs nsig >= 5
        for i in xrange(5):
            print "initl signal-rms[%3d] = %12.5e" % (
                i, np.sqrt(sscov.flatten()[i]))
        print " "
    if nref < 6:
        print "rrslope-entries:"
        for i in xrange(nref):
            print rrslope[i][:]

    # --- invert the normalized ref-ref matrix via SVD
    U, s, V = np.linalg.svd(rrslope, full_matrices=True)
    print s

    print "Applying cutoff for smallest SVs:"
    # singular values below the relative cutoff are zeroed and get a zero
    # inverse
    dtmp = s.max() * SVD_RELCUTOFF
    sinv = np.zeros(nref)
    for i in xrange(nref):
        if abs(s[i]) >= dtmp:
            sinv[i] = 1. / s[i]
        else:
            s[i] = 0.
    # s *= (abs(s)>=dtmp)
    # sinv = ???
    print s
    stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
    print ">>> Testing svd-result: %s" % stat
    if not stat:
        print "    (Maybe due to SV-cutoff?)"

    # Solve for inverse coefficients:
    print ">>> Setting RRinvtr=U diag(sinv) V"
    # (the zeros array is immediately replaced by the product below)
    RRinvtr = np.zeros((nref, nref))
    RRinvtr = np.dot(U, np.dot(np.diag(sinv), V))
    if checkresults:
        # print ">>> RRinvtr-result:"
        # print RRinvtr
        stat = np.allclose(np.identity(nref),
                           np.dot(rrslope.transpose(), RRinvtr))
        if stat:
            print ">>> Testing RRinvtr-result (shld be unit-matrix): ok"
        else:
            print ">>> Testing RRinvtr-result (shld be unit-matrix): failed"
            print np.dot(rrslope.transpose(), RRinvtr)
            # np.less_equal(np.abs(np.dot(rrslope.transpose(),RRinvtr)-np.identity(nref)),0.01*np.ones((nref,nref)))
        print ""

    print "########## Calc weight matrix..."
    # weights-matrix will be somewhat larger than necessary,
    # (to simplify indexing in compensation loop):
    # one row per raw channel; rows of non-signal channels stay zero.
    weights = np.zeros((raw._data.shape[0], nref))
    for isig in xrange(nsig):
        for iref in xrange(nref):
            weights[sigpick[isig]][iref] = np.dot(srcov[isig][:],
                                                  RRinvtr[iref][:])

    # Report the largest weights (or that all of them are negligible)
    if np.allclose(np.zeros(weights.shape), np.abs(weights), atol=1.e-8):
        print ">>> all weights are small (<=1.e-8)."
    else:
        print ">>> largest weight %12.5e" % np.max(np.abs(weights))
        wlrg = np.where(np.abs(weights) >= 0.99 * np.max(np.abs(weights)))
        for iwlrg in xrange(len(wlrg[0])):
            print ">>> weights[%3d,%2d] = %12.5e" % \
                  (wlrg[0][iwlrg], wlrg[1][iwlrg], weights[wlrg[0][iwlrg], wlrg[1][iwlrg]])

    if nref < 5:
        print "weights-entries for first sigchans:"
        # NOTE: indexes sigpick[0..4]; needs nsig >= 5
        for i in xrange(5):
            print 'weights[sp(%2d)][r]=[' % i + ' '.join(
                [' %+10.7f' % val for val in weights[sigpick[i]][:]]) + ']'

    print "########## Compensating signal channels:"
    tct = time.clock()
    twt = time.time()
    # data,times = raw[:,raw.time_as_index(tmin)[0]:raw.time_as_index(tmax)[0]:]
    # Work on entire data stream, one sample (column) at a time; the
    # reference mean used is rmean from the Welford pass:
    for isl in xrange(raw._data.shape[1]):
        slice = np.take(raw._data, [isl], axis=1)
        if use_reffilter:
            refslice = np.take(fltref._data, [isl], axis=1)
            refarr = refslice[:].flatten() - rmean
            # refarr = fltres[:,isl]-rmean
        else:
            refarr = slice[refpick].flatten() - rmean
        subrefarr = np.dot(weights[:], refarr)
        # data[:,isl] -= subrefarr   will not modify raw._data?
        raw._data[:, isl] -= subrefarr
        if isl % 10000 == 0:
            print "\rProcessed slice %6d" % isl
    print "\nDone."
    tc1 = time.clock()
    tw1 = time.time()
    print "compensation loop took %.1f ms (%.2f s walltime)" % (1000. *
                                                                (tc1 - tct),
                                                                (tw1 - twt))

    if checkresults:
        print "########## Calculating final signal channel covariance:"
        # Calculate final signal channel covariance:
        # (only used as quality measure)
        tct = time.clock()
        twt = time.time()
        # rebinding (not in-place): sscovinit keeps the initial variances
        sigmean = 0
        sscovdata = 0
        n_samples = 0
        for first in range(itmin, itmax, itstep):
            last = first + itstep
            if last >= itmax:
                last = itmax
            raw_segmentsig, times = raw[sigpick, first:last]
            # Artifacts found here will probably differ from pre-noisered artifacts!
            if not exclart or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                        flat=None, ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                n_samples += raw_segmentsig.shape[1]
        sigmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        print ">>> no channel got worse: ", np.all(
            np.less_equal(sscovdata, sscovinit))
        print "final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata))
        for i in xrange(5):
            print "final signal-rms[%3d] = %12.5e" % (
                i, np.sqrt(sscovdata.flatten()[i]))
        tc1 = time.clock()
        tw1 = time.time()
        print "signal covar-calc took %.1f ms (%.2f s walltime)" % (
            1000. * (tc1 - tct), (tw1 - twt))
        print " "

    # Save next to the input file, replacing '-raw.fif' by ',nold-raw.fif'
    nrname = dname[:dname.rfind('-raw.fif')] + ',nold-raw.fif'
    print "Saving '%s'..." % nrname
    raw.save(nrname, overwrite=True)
    tc1 = time.clock()
    tw1 = time.time()
    print "Total run         took %.1f ms (%.2f s walltime)" % (1000. *
                                                                (tc1 - tc0),
                                                                (tw1 - tw0))
# Example #3
# 0
def noise_reducer(fname_raw, raw=None, signals=[], noiseref=[], detrending=None,
                  tmin=None, tmax=None, reflp=None, refhp=None, refnotch=None,
                  exclude_artifacts=True, checkresults=True, return_raw=False,
                  complementary_signal=False, fnout=None, verbose=False):
    """Apply noise reduction to signal channels using reference channels.

    For every input file the sig-ref and ref-ref covariances are accumulated
    over [tmin, tmax] in chunks of 0.2 s, regression weights are derived from
    them through an SVD based pseudo-inverse, and the weighted (optionally
    filtered) reference signals are subtracted from the complete data set.

    Parameters
    ----------
    fname_raw : (list of) rawfile names
    raw : mne Raw object
        Allows passing of raw object as well. Its data are modified in place.
        Cannot be combined with a list of file names or with detrending.
    signals : list of string
              List of channels to compensate using noiseref.
              If empty use the meg signal channels.
    noiseref : list of string | str
              List of channels to use as noise reference.
              If empty use the magnetic reference channels (default).
    signals and noiseref may contain regexp, which are resolved
    using mne.pick_channels_regexp(). All other channels are copied.
    tmin : lower latency bound for weight-calc [start of trace]
    tmax : upper latency bound for weight-calc [ end  of trace]
           Weights are calc'd for (tmin,tmax), but applied to entire data set
    refhp : high-pass frequency for reference signal filter [None]
    reflp :  low-pass frequency for reference signal filter [None]
            reflp < refhp: band-stop filter
            reflp > refhp: band-pass filter
            reflp is not None, refhp is None: low-pass filter
            reflp is None, refhp is not None: high-pass filter
    refnotch : (base) notch frequency for reference signal filter [None]
               use raw(ref)-notched(ref) as reference signal
               (cannot be combined with reflp/refhp)
    exclude_artifacts: filter signal-channels thru _is_good() [True]
                       (parameters are at present hard-coded!)
    return_raw : bool
        If return_raw is true, the raw object is returned and raw file
        is not written to disk. It is suggested that this option be used in cases
        where the noise_reducer is applied multiple times. [False]
    complementary_signal : replaced signal by traces that would be subtracted [False]
                           (can be useful for debugging)
    detrending: boolean to ctrl subtraction of linear trend from all magn. chans [False]
    checkresults : boolean to control internal checks and overall success [True]
    fnout : name of the output file [None]
            If None the name is derived from the input name (see Outputfile).
    verbose : boolean to control printing of progress and timing info [False]

    Outputfile
    ----------
    <wawa>,nr-raw.fif for input <wawa>-raw.fif

    Returns
    -------
    If return_raw is True, then mne.io.Raw instance is returned.

    Raises
    ------
    ValueError
        - complementary_signal is not a bool
        - a list of file names is combined with raw or with return_raw
        - raw is combined with detrending
        - the time window holds fewer than two samples
        - no signal or no reference channel is selected
        - refnotch is combined with reflp and/or refhp
        - fewer than two samples remain for the weight calculation

    Bugs
    ----
    - artifact checking is incomplete (and with arb. window of tstep=0.2s)
    - no accounting of channels used as signal/reference
    - non existing input file handled ungracefully
    """

    if not isinstance(complementary_signal, bool):
        raise ValueError("Argument complementary_signal must be of type bool")

    is_file_list = isinstance(fname_raw, list)

    # handle error if Raw object passed with file list
    if raw is not None and is_file_list:
        raise ValueError('List of file names cannot be combined with one Raw object')

    # handle error if return_raw is requested with file list
    if return_raw and is_file_list:
        raise ValueError('List of file names cannot be combined with return_raw. '
                         'Please pass one file at a time.')

    # handle error if Raw object is passed with detrending option
    # TODO include perform_detrending for Raw objects
    if raw is not None and detrending:
        raise ValueError('Please perform detrending on the raw file directly. '
                         'Cannot perform detrending on the raw object')

    # Keep the caller's object apart from the per-file working object:
    # 'raw' is rebound for every file, so each file is loaded on its own
    # instead of re-using the (already compensated) data of the first one.
    raw_given = raw

    fnraw = get_files_from_list(fname_raw)

    # The following reject entries are only used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(grad=4000e-13,  # T / m (gradiometers)
                  mag=4e-12,      # T (magnetometers)
                  eeg=40e-6,      # uV (EEG channels)
                  eog=250e-6)     # uV (EOG channels)

    # length of the data chunks (in s) used for covariance/artifact checks
    tstep = 0.2
    # number of samples compensated per block (also the progress interval)
    nblock = 10000

    # loop across all filenames
    for fname in fnraw:

        if verbose:
            print("########## Read raw data:")

        tc0 = time.clock()
        tw0 = time.time()

        if raw_given is None:
            if detrending:
                raw = perform_detrending(fname, save=False)
            else:
                raw = mne.io.Raw(fname, preload=True)
        else:
            raw = raw_given
            # perform sanity check to make sure Raw object and file are same
            if os.path.basename(fname) != os.path.basename(raw.info['filename']):
                warnings.warn('The file name within the Raw object and provided '
                              'fname are not the same. Please check again.')

        tc1 = time.clock()
        tw1 = time.time()

        if verbose:
            print(">>> loading raw data took %.1f ms (%.2f s walltime)"
                  % (1000. * (tc1 - tc0), (tw1 - tw0)))

        # Time window selection
        # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
        # tstep is used in artifact detection
        # tmin,tmax variables must not be changed here!
        sfreq = raw.info['sfreq']
        if tmin is None:
            itmin = 0
        else:
            itmin = int(floor(tmin * sfreq))
        if tmax is None:
            # NOTE(review): itmax is used as exclusive upper bound below, so
            # the sample at raw.last_samp itself is presumably left out of the
            # weight calculation - confirm that this is intended.
            itmax = raw.last_samp
        else:
            itmax = int(ceil(tmax * sfreq))

        if itmax - itmin < 2:
            raise ValueError("Time-window for noise compensation empty or too short")

        if verbose:
            print(">>> Set time-range to [%7.3f,%7.3f]"
                  % (raw.times[itmin], raw.times[itmax]))

        if signals is None or len(signals) == 0:
            sigpick = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False,
                                     eog=False, exclude='bads')
        else:
            sigpick = channel_indices_from_list(raw.info['ch_names'][:], signals,
                                                raw.info.get('bads'))
        nsig = len(sigpick)
        if nsig == 0:
            raise ValueError("No channel selected for noise compensation")

        if noiseref is None or len(noiseref) == 0:
            # References are not limited to 4D ref-chans, but can be anything,
            # incl. ECG or powerline monitor.
            if verbose:
                print(">>> Using all refchans.")
            refpick = mne.pick_types(raw.info, ref_meg=True, meg=False, eeg=False,
                                     stim=False, eog=False, exclude='bads')
        else:
            refpick = channel_indices_from_list(raw.info['ch_names'][:], noiseref,
                                                raw.info.get('bads'))
        nref = len(refpick)
        if nref == 0:
            raise ValueError("No channel selected as noise reference")

        if verbose:
            print(">>> sigpick: %3d chans, refpick: %3d chans" % (nsig, nref))

        use_reffilter = not (reflp is None and refhp is None and refnotch is None)
        use_refantinotch = False
        if use_reffilter:
            if verbose:
                print("########## Filter reference channels:")

            if refnotch is not None:
                # The anti-notch reference excludes any high-/low-pass setting.
                if reflp is not None or refhp is not None:
                    raise ValueError("Cannot specify notch- and high-/low-pass "
                                     "reference filter together")
                use_refantinotch = True
                # notch at the base frequency and its harmonics, limited to
                # 5 harmonics resp. the Nyquist frequency
                freqlast = np.min([5.01 * refnotch, 0.5 * sfreq])
                if verbose:
                    print(">>> notches at freq %.1f and harmonics below %.1f"
                          % (refnotch, freqlast))
            elif verbose:
                if reflp is not None:
                    print(">>>  low-pass with cutoff-freq %.1f" % reflp)
                if refhp is not None:
                    print(">>> high-pass with cutoff-freq %.1f" % refhp)

            # Adapt followg drop-chans cmd to use 'all-but-refpick'
            droplist = [raw.info['ch_names'][k] for k in range(raw.info['nchan'])
                        if k not in refpick]
            tct = time.clock()
            twt = time.time()
            fltref = raw.copy().drop_channels(droplist)
            fltpicks = np.arange(nref)
            if use_refantinotch:
                # reference = unfiltered(ref) - notched(ref)
                unfiltered = fltref._data.copy()
                fltref.notch_filter(np.arange(refnotch, freqlast, refnotch),
                                    picks=fltpicks, method='iir')
                fltref._data = (unfiltered - fltref._data)
            else:
                fltref.filter(refhp, reflp, picks=fltpicks, method='iir')
            tc1 = time.clock()
            tw1 = time.time()
            if verbose:
                print(">>> filtering ref-chans  took %.1f ms (%.2f s walltime)"
                      % (1000. * (tc1 - tct), (tw1 - twt)))

        if verbose:
            print("########## Calculating sig-ref/ref-ref-channel covariances:")
        # Calculate sig-ref/ref-ref-channel covariance:
        # (there is no need to calc inter-signal-chan cov,
        #  but there seems to be no appropriat fct available)
        # Here we copy the idea from compute_raw_data_covariance()
        # and truncate it as appropriate.
        tct = time.clock()
        twt = time.time()

        # infosig is only used in _is_good-calls.
        infosig = copy.copy(raw.info)
        infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
        infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
        infosig['nchan'] = len(sigpick)
        idx_by_typesig = channel_indices_by_type(infosig)

        # Read data in chunks:
        itstep = int(ceil(tstep * sfreq))
        sigmean = 0
        refmean = 0
        sscovdata = 0
        srcovdata = 0
        rrcovdata = 0
        n_samples = 0

        for first in range(itmin, itmax, itstep):
            last = min(first + itstep, itmax)
            raw_segmentsig, _ = raw[sigpick, first:last]
            if use_reffilter:
                raw_segmentref, _ = fltref[:, first:last]
            else:
                raw_segmentref, _ = raw[refpick, first:last]

            if not exclude_artifacts or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
                        ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                refmean += raw_segmentref.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
                rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
                n_samples += raw_segmentsig.shape[1]
            else:
                logger.info("Artefact detected in [%d, %d]" % (first, last))
        if n_samples <= 1:
            raise ValueError('Too few samples to calculate weights')
        # Turn the accumulated sums into means and (co)variances.
        sigmean /= n_samples
        refmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
        srcovdata /= (n_samples - 1)
        rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
        rrcovdata /= (n_samples - 1)
        sscovinit = np.copy(sscovdata)
        if verbose:
            print(">>> Normalize srcov...")

        # Scale every reference column by its variance; references without
        # variance (<= TINY) are switched off.
        rrslope = copy.copy(rrcovdata)
        for iref in range(nref):
            dtmp = rrcovdata[iref, iref]
            if dtmp > TINY:
                srcovdata[:, iref] /= dtmp
                rrslope[:, iref] /= dtmp
            else:
                srcovdata[:, iref] = 0.
                rrslope[:, iref] = 0.

        if verbose:
            print(">>> Number of samples used : %d" % n_samples)
            tc1 = time.clock()
            tw1 = time.time()
            print(">>> sigrefchn covar-calc took %.1f ms (%.2f s walltime)"
                  % (1000. * (tc1 - tct), (tw1 - twt)))

        # number of channels listed in the quality reports
        # (there may be fewer than 5 signal channels)
        nreport = min(5, nsig)

        if checkresults and verbose:
            print("########## Calculated initial signal channel covariance:")
            # Calculate initial signal channel covariance:
            # (only used as quality measure)
            print(">>> initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
            for i in range(nreport):
                print(">>> initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata[i])))
            print(">>>")

        U, s, V = np.linalg.svd(rrslope, full_matrices=True)
        if verbose:
            print(">>> singular values:")
            print(s)
            print(">>> Applying cutoff for smallest SVs:")

        # Discard singular values below the relative cutoff; their inverse
        # is set to zero (pseudo-inverse).
        dtmp = s.max() * SVD_RELCUTOFF
        s *= (abs(s) >= dtmp)
        sinv = [1. / sval if sval != 0. else 0. for sval in s]
        if verbose:
            print(">>> singular values (after cutoff):")
            print(s)

        stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
        if verbose:
            print(">>> Testing svd-result: %s" % stat)
            if not stat:
                print("    (Maybe due to SV-cutoff?)")

        # Solve for inverse coefficients:
        # Set RRinv.tr=U diag(sinv) V
        RRinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
        if checkresults:
            stat = np.allclose(np.identity(nref), np.dot(RRinv, rrslope))
            if stat:
                if verbose:
                    print(">>> Testing RRinv-result (should be unit-matrix): ok")
            else:
                print(">>> Testing RRinv-result (should be unit-matrix): failed")
                print(np.transpose(np.dot(RRinv, rrslope)))
                print(">>>")

        if verbose:
            print("########## Calc weight matrix...")

        # weights-matrix will be somewhat larger than necessary,
        # (to simplify indexing in compensation loop):
        # rows of channels that are not in sigpick stay zero.
        weights = np.zeros((raw._data.shape[0], nref))
        weights[sigpick, :] = np.dot(srcovdata, RRinv)

        if verbose:
            print("########## Compensating signal channels:")
            if complementary_signal:
                print(">>> Caveat: REPLACING signal by compensation signal")

        tct = time.clock()
        twt = time.time()

        # Work on entire data stream (in blocks of nblock samples):
        n_times = raw._data.shape[1]
        for start in range(0, n_times, nblock):
            stop = min(start + nblock, n_times)
            if use_reffilter:
                refblock = fltref._data[:, start:stop]
            else:
                refblock = raw._data[refpick, start:stop]
            # The subtraction creates a new array, so the reference values
            # are fixed before raw._data is modified below.
            subref = np.dot(weights, refblock - refmean[:, None])

            if not complementary_signal:
                raw._data[:, start:stop] -= subref
            else:
                # all rows are overwritten: channels outside sigpick have
                # zero weights and therefore become zero
                raw._data[:, start:stop] = subref

            if verbose:
                print("\rProcessed slice %6d" % start)

        if verbose:
            print("\nDone.")
            tc1 = time.clock()
            tw1 = time.time()
            print(">>> compensation loop took %.1f ms (%.2f s walltime)"
                  % (1000. * (tc1 - tct), (tw1 - twt)))

        if checkresults:
            if verbose:
                print("########## Calculating final signal channel covariance:")
            # Calculate final signal channel covariance:
            # (only used as quality measure)
            tct = time.clock()
            twt = time.time()
            sigmean = 0
            sscovdata = 0
            n_samples = 0
            for first in range(itmin, itmax, itstep):
                last = min(first + itstep, itmax)
                raw_segmentsig, _ = raw[sigpick, first:last]
                # Artifacts found here will probably differ from pre-noisered artifacts!
                if not exclude_artifacts or \
                   _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                            flat=None, ignore_chs=raw.info['bads']):
                    sigmean += raw_segmentsig.sum(axis=1)
                    sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                    n_samples += raw_segmentsig.shape[1]
            if n_samples <= 1:
                # Only a quality measure: do not abort the run if every
                # segment is rejected after the compensation.
                warnings.warn('Too few samples to calculate final signal covariance')
            else:
                sigmean /= n_samples
                sscovdata -= n_samples * sigmean[:] * sigmean[:]
                sscovdata /= (n_samples - 1)
                if verbose:
                    print(">>> no channel got worse:  %s"
                          % np.all(np.less_equal(sscovdata, sscovinit)))
                    print(">>> final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
                    for i in range(nreport):
                        print(">>> final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata[i])))
                    tc1 = time.clock()
                    tw1 = time.time()
                    print(">>> signal covar-calc took %.1f ms (%.2f s walltime)"
                          % (1000. * (tc1 - tct), (tw1 - twt)))
                    print(">>>")

        if return_raw:
            # nothing is written to disk in this case
            return raw

        if fnout is not None:
            fnoutloc = fnout
        else:
            # <wawa>-raw.fif -> <wawa>,nr-raw.fif; names without the
            # '-raw.fif' postfix only lose their extension.
            icut = fname.rfind('-raw.fif')
            if icut >= 0:
                fnoutloc = fname[:icut] + ',nr-raw.fif'
            else:
                fnoutloc = os.path.splitext(fname)[0] + ',nr-raw.fif'

        if verbose:
            print(">>> Saving '%s'..." % fnoutloc)

        raw.save(fnoutloc, overwrite=True)

        tc1 = time.clock()
        tw1 = time.time()
        if verbose:
            print(">>> Total run took %.1f ms (%.2f s walltime)"
                  % (1000. * (tc1 - tc0), (tw1 - tw0)))
예제 #4
0
def test_noise_reducer():
    """Run noise_reducer() and redo the compensation step by step for comparison.

    The empty room file below $SUBJECTS_DIR is first processed by
    noise_reducer() (with complementary_signal=True). Afterwards the same
    file is read again and the compensation is repeated explicitly:
    the sig-ref/ref-ref covariances are calculated twice (plain sums and
    the Welford update scheme), both results are compared, and the weights
    derived from the Welford values are applied to the data, which are
    finally saved as <name>,nold-raw.fif.

    Requires the environment variable SUBJECTS_DIR.

    Raises
    ------
    ValueError
        If no signal or no reference channel is selected or if fewer than
        two samples are available for the weight calculation.
    """

    data_path = os.environ['SUBJECTS_DIR']
    dname = os.path.join(data_path, 'empty_room_files',
                         '109925_empty_room_file-raw.fif')
    #
    checkresults = True
    exclart = False
    use_reffilter = True
    refflt_lpfreq = 52.
    refflt_hpfreq = 48.

    print("########## before of noisereducer call ##########")
    # Use e.g. ['MEG ..1', 'MEG ..3', 'MEG ..5', 'MEG ..7', 'MEG ..9']
    # to restrict the signal channels; None selects all magnetometers.
    sigchanlist = None
    refchanlist = ['RFM 001', 'RFM 003', 'RFM 005', 'RFG ...']
    tmin = 15.
    noise_reducer(dname, signals=sigchanlist, noiseref=refchanlist, tmin=tmin,
                  reflp=refflt_lpfreq, refhp=refflt_hpfreq,
                  exclude_artifacts=exclart, complementary_signal=True)
    print("########## behind of noisereducer call ##########")

    print("########## Read raw data:")
    tc0 = time.clock()
    tw0 = time.time()
    raw = mne.io.Raw(dname, preload=True)
    tc1 = time.clock()
    tw1 = time.time()
    print("loading raw data  took %.1f ms (%.2f s walltime)"
          % (1000. * (tc1 - tc0), (tw1 - tw0)))

    # Time window selection
    # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
    # tstep is used in artifact detection
    # NOTE(review): raw.last_samp is used as index into raw.times here;
    # presumably this requires raw.first_samp == 0 - confirm.
    tmax = raw.times[raw.last_samp]
    tstep = 0.2
    sfreq = raw.info['sfreq']
    itmin = int(floor(tmin * sfreq))
    itmax = int(ceil(tmax * sfreq))
    itstep = int(ceil(tstep * sfreq))
    print(">>> Set time-range to [%7.3f,%7.3f]" % (tmin, tmax))

    if sigchanlist is None:
        sigpick = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False,
                                 eog=False, exclude='bads')
    else:
        sigpick = channel_indices_from_list(raw.info['ch_names'][:], sigchanlist)
    nsig = len(sigpick)
    print("sigpick: %3d chans" % nsig)
    if nsig == 0:
        raise ValueError("No channel selected for noise compensation")

    if refchanlist is None:
        # References are not limited to 4D ref-chans, but can be anything,
        # incl. ECG or powerline monitor.
        print(">>> Using all refchans.")
        refpick = mne.pick_types(raw.info, ref_meg=True, meg=False, eeg=False,
                                 stim=False, eog=False, exclude="bads")
    else:
        refpick = channel_indices_from_list(raw.info['ch_names'][:], refchanlist)
        print("refpick = '%s'" % refpick)
    nref = len(refpick)
    print("refpick: %3d chans" % nref)
    if nref == 0:
        raise ValueError("No channel selected as noise reference")

    print("########## Refchan geo data:")
    # This is just for info to locate special 4D-refs.
    for iref in refpick:
        refchan = raw.info['chs'][iref]
        print("%s %s" % (refchan['ch_name'], refchan['loc'][0:3]))
    print("")

    if use_reffilter:
        print("########## Filter reference channels:")
        if refflt_lpfreq is not None:
            print(" low-pass with cutoff-freq %.1f" % refflt_lpfreq)
        if refflt_hpfreq is not None:
            print("high-pass with cutoff-freq %.1f" % refflt_hpfreq)
        # Adapt followg drop-chans cmd to use 'all-but-refpick'
        droplist = [raw.info['ch_names'][k] for k in range(raw.info['nchan'])
                    if k not in refpick]
        # same copy/drop idiom as used in noise_reducer()
        fltref = raw.copy().drop_channels(droplist)
        tct = time.clock()
        twt = time.time()
        fltref.filter(refflt_hpfreq, refflt_lpfreq, picks=np.arange(nref), method='iir')
        tc1 = time.clock()
        tw1 = time.time()
        print("filtering ref-chans  took %.1f ms (%.2f s walltime)"
              % (1000. * (tc1 - tct), (tw1 - twt)))

    # The following reject and infosig entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(grad=4000e-13,  # T / m (gradiometers)
                  mag=4e-12,      # T (magnetometers)
                  eeg=40e-6,      # uV (EEG channels)
                  eog=250e-6)     # uV (EOG channels)

    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    print("########## Calculating sig-ref/ref-ref-channel covariances:")
    # Calculate sig-ref/ref-ref-channel covariance:
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriat fct available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()

    # Read data in chunks:
    sigmean = 0
    refmean = 0
    sscovdata = 0
    srcovdata = 0
    rrcovdata = 0
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = min(first + itstep, itmax)
        raw_segmentsig, _ = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, _ = fltref[:, first:last]
        else:
            raw_segmentref, _ = raw[refpick, first:last]
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            sigmean += raw_segmentsig.sum(axis=1)
            refmean += raw_segmentref.sum(axis=1)
            sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
            srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
            rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    if n_samples <= 1:
        raise ValueError('Too few samples to calculate weights')
    # Turn the accumulated sums into means and (co)variances.
    sigmean /= n_samples
    refmean /= n_samples
    sscovdata -= n_samples * sigmean[:] * sigmean[:]
    sscovdata /= (n_samples - 1)
    srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
    srcovdata /= (n_samples - 1)
    rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
    rrcovdata /= (n_samples - 1)
    sscovinit = np.copy(sscovdata)
    print("Normalize srcov...")
    # Scale every reference column by its variance; references without
    # variance (<= TINY) are switched off.
    rrslopedata = copy.copy(rrcovdata)
    for iref in range(nref):
        dtmp = rrcovdata[iref][iref]
        if dtmp > TINY:
            srcovdata[:, iref] /= dtmp
            rrslopedata[:, iref] /= dtmp
        else:
            srcovdata[:, iref] = 0.
            rrslopedata[:, iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    tc1 = time.clock()
    tw1 = time.time()
    print("sigrefchn covar-calc took %.1f ms (%.2f s walltime)"
          % (1000. * (tc1 - tct), (tw1 - twt)))

    print("########## Calculating sig-ref/ref-ref-channel covariances (robust):")
    # Calculate sig-ref/ref-ref-channel covariance:
    # (usg B.P.Welford, "Note on a method for calculating corrected sums
    #                   of squares and products", Technometrics4 (1962) 419-420)
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriat fct available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()

    # Read data in chunks:
    smean = np.zeros(nsig)
    smold = np.zeros(nsig)
    rmean = np.zeros(nref)
    rmold = np.zeros(nref)
    sscov = np.zeros(nsig)
    srcov = np.zeros((nsig, nref))
    rrcov = np.zeros((nref, nref))
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = min(first + itstep, itmax)
        raw_segmentsig, _ = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, _ = fltref[:, first:last]
        else:
            raw_segmentref, _ = raw[refpick, first:last]
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            # sample-wise update of the running means and corrected sums
            for isl in range(raw_segmentsig.shape[1]):
                nsl = isl + n_samples + 1
                cnslm1dnsl = float((nsl - 1)) / float(nsl)
                sslsubmean = (raw_segmentsig[:, isl] - smold)
                rslsubmean = (raw_segmentref[:, isl] - rmold)
                smean = smold + sslsubmean / nsl
                rmean = rmold + rslsubmean / nsl
                sscov += sslsubmean * (raw_segmentsig[:, isl] - smean)
                srcov += cnslm1dnsl * np.outer(sslsubmean, rslsubmean)
                rrcov += cnslm1dnsl * np.outer(rslsubmean, rslsubmean)
                smold = smean
                rmold = rmean
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    if n_samples <= 1:
        raise ValueError('Too few samples to calculate weights')
    sscov /= (n_samples - 1)
    srcov /= (n_samples - 1)
    rrcov /= (n_samples - 1)
    print("Normalize srcov...")
    rrslope = copy.copy(rrcov)
    for iref in range(nref):
        dtmp = rrcov[iref][iref]
        if dtmp > TINY:
            srcov[:, iref] /= dtmp
            rrslope[:, iref] /= dtmp
        else:
            srcov[:, iref] = 0.
            rrslope[:, iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    print("Compare results with 'standard' values:")
    print("cmp(sigmean,smean): %s" % np.allclose(smean, sigmean, atol=0.))
    print("cmp(refmean,rmean): %s" % np.allclose(rmean, refmean, atol=0.))
    print("cmp(sscovdata,sscov): %s" % np.allclose(sscov, sscovdata, atol=0.))
    print("cmp(srcovdata,srcov): %s" % np.allclose(srcov, srcovdata, atol=0.))
    print("cmp(rrcovdata,rrcov): %s" % np.allclose(rrcov, rrcovdata, atol=0.))
    tc1 = time.clock()
    tw1 = time.time()
    print("sigrefchn covar-calc took %.1f ms (%.2f s walltime)"
          % (1000. * (tc1 - tct), (tw1 - twt)))

    # number of channels listed in the reports
    # (there may be fewer than 5 signal channels)
    nreport = min(5, nsig)

    if checkresults:
        print("########## Calculated initial signal channel covariance:")
        # Calculate initial signal channel covariance:
        # (only used as quality measure)
        print("initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscov)))
        for i in range(nreport):
            print("initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscov[i])))
        print(" ")
    if nref < 6:
        print("rrslope-entries:")
        for i in range(nref):
            print(rrslope[i][:])

    U, s, V = np.linalg.svd(rrslope, full_matrices=True)
    print(s)

    print("Applying cutoff for smallest SVs:")
    # Singular values below the relative cutoff are zeroed and their
    # inverse is left at zero (pseudo-inverse).
    dtmp = s.max() * SVD_RELCUTOFF
    sinv = np.zeros(nref)
    for i in range(nref):
        if abs(s[i]) >= dtmp:
            sinv[i] = 1. / s[i]
        else:
            s[i] = 0.
    print(s)
    stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
    print(">>> Testing svd-result: %s" % stat)
    if not stat:
        print("    (Maybe due to SV-cutoff?)")

    # Solve for inverse coefficients:
    print(">>> Setting RRinvtr=U diag(sinv) V")
    RRinvtr = np.dot(U, np.dot(np.diag(sinv), V))
    if checkresults:
        unitcheck = np.dot(rrslope.transpose(), RRinvtr)
        stat = np.allclose(np.identity(nref), unitcheck)
        if stat:
            print(">>> Testing RRinvtr-result (shld be unit-matrix): ok")
        else:
            print(">>> Testing RRinvtr-result (shld be unit-matrix): failed")
            print(unitcheck)
        print("")

    print("########## Calc weight matrix...")
    # weights-matrix will be somewhat larger than necessary,
    # (to simplify indexing in compensation loop):
    # weights[sigpick[isig], iref] = sum_k srcov[isig, k] * RRinvtr[iref, k]
    weights = np.zeros((raw._data.shape[0], nref))
    weights[sigpick, :] = np.dot(srcov, RRinvtr.T)

    absweights = np.abs(weights)
    if np.allclose(np.zeros(weights.shape), absweights, atol=1.e-8):
        print(">>> all weights are small (<=1.e-8).")
    else:
        wmax = np.max(absweights)
        print(">>> largest weight %12.5e" % wmax)
        wlrg = np.where(absweights >= 0.99 * wmax)
        for irow, icol in zip(wlrg[0], wlrg[1]):
            print(">>> weights[%3d,%2d] = %12.5e" % (irow, icol, weights[irow, icol]))

    if nref < 5:
        print("weights-entries for first sigchans:")
        for i in range(nreport):
            print('weights[sp(%2d)][r]=[' % i + ' '.join([' %+10.7f' %
                             val for val in weights[sigpick[i]][:]]) + ']')

    print("########## Compensating signal channels:")
    tct = time.clock()
    twt = time.time()
    # Work on entire data stream (in blocks of nblock samples):
    nblock = 10000
    n_times = raw._data.shape[1]
    for start in range(0, n_times, nblock):
        stop = min(start + nblock, n_times)
        if use_reffilter:
            refblock = fltref._data[:, start:stop]
        else:
            refblock = raw._data[refpick, start:stop]
        # The subtraction creates a new array, so the reference values
        # are fixed before raw._data is modified.
        raw._data[:, start:stop] -= np.dot(weights, refblock - rmean[:, None])
        print("\rProcessed slice %6d" % start)
    print("\nDone.")
    tc1 = time.clock()
    tw1 = time.time()
    print("compensation loop took %.1f ms (%.2f s walltime)"
          % (1000. * (tc1 - tct), (tw1 - twt)))

    if checkresults:
        print("########## Calculating final signal channel covariance:")
        # Calculate final signal channel covariance:
        # (only used as quality measure)
        tct = time.clock()
        twt = time.time()
        sigmean = 0
        sscovdata = 0
        n_samples = 0
        for first in range(itmin, itmax, itstep):
            last = min(first + itstep, itmax)
            raw_segmentsig, _ = raw[sigpick, first:last]
            # Artifacts found here will probably differ from pre-noisered artifacts!
            if not exclart or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                        flat=None, ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                n_samples += raw_segmentsig.shape[1]
        if n_samples <= 1:
            raise ValueError('Too few samples to calculate final signal covariance')
        sigmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        print(">>> no channel got worse:  %s" % np.all(np.less_equal(sscovdata, sscovinit)))
        print("final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
        for i in range(nreport):
            print("final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata[i])))
        tc1 = time.clock()
        tw1 = time.time()
        print("signal covar-calc took %.1f ms (%.2f s walltime)"
              % (1000. * (tc1 - tct), (tw1 - twt)))
        print(" ")

    nrname = dname[:dname.rfind('-raw.fif')] + ',nold-raw.fif'
    print("Saving '%s'..." % nrname)
    raw.save(nrname, overwrite=True)
    tc1 = time.clock()
    tw1 = time.time()
    print("Total run         took %.1f ms (%.2f s walltime)"
          % (1000. * (tc1 - tc0), (tw1 - tw0)))
예제 #5
0
def noise_reducer(fname_raw, raw=None, signals=[], noiseref=[], detrending=None,
                  tmin=None, tmax=None, reflp=None, refhp=None, refnotch=None,
                  exclude_artifacts=True, checkresults=True, return_raw=False,
                  complementary_signal=False, fnout=None, verbose=False):
    """
    Apply noise reduction to signal channels using reference channels.

    Regression weights of the signal channels onto the (optionally filtered)
    reference channels are estimated from covariances accumulated within
    [tmin, tmax]; the weighted reference signals are then subtracted from
    the signal channels over the entire data set.

    Parameters
    ----------
    fname_raw : (list of) rawfile name(s)
    raw : mne Raw objects
        Allows passing of (preloaded) raw object in addition to fname_raw
        or solely (use fname_raw=None in this case).
    signals : list of string
              List of channels to compensate using noiseref.
              If empty use the meg signal channels.
    noiseref : list of string | str
              List of channels to use as noise reference.
              If empty use the magnetic reference channels (default).
    signals and noiseref may contain regexp, which are resolved
    using mne.pick_channels_regexp(). All other channels are copied.
    tmin : lower latency bound for weight-calc [start of trace]
    tmax : upper latency bound for weight-calc [ end  of trace]
           Weights are calc'd for (tmin,tmax), but applied to entire data set
    refhp : high-pass frequency for reference signal filter [None]
    reflp :  low-pass frequency for reference signal filter [None]
            reflp < refhp: band-stop filter
            reflp > refhp: band-pass filter
            reflp is not None, refhp is None: low-pass filter
            reflp is None, refhp is not None: high-pass filter
    refnotch : (list of) notch frequencies for reference signal filter [None]
               use raw(ref)-notched(ref) as reference signal
               (cannot be combined with reflp/refhp)
    exclude_artifacts: filter signal-channels thru _is_good() [True]
                       (parameters are at present hard-coded!)
    return_raw : bool
        If return_raw is true, the raw object is returned and raw file
        is not written to disk unless fnout is explicitly specified.
        It is suggested that this option be used in cases where the
        noise_reducer is applied multiple times. [False]
    fnout : explicit specification for an output file name [None]
        Automatic filenames replace '-raw.fif' by ',nr-raw.fif'.
    complementary_signal : replaced signal by traces that would be
                           subtracted [False]
                           (can be useful for debugging)
    detrending: boolean to ctrl subtraction of linear trend from all
                magn. chans [False]
    checkresults : boolean to control internal checks and overall success
                   [True]
    verbose : boolean to control progress and timing output [False]

    Outputfile
    ----------
    <wawa>,nr-raw.fif for input <wawa>-raw.fif

    Returns
    -------
    If return_raw is True, then mne.io.Raw instance is returned.

    Raises
    ------
    ValueError
        On inconsistent argument combinations, an empty/too short time
        window, empty channel selections, illegal notch frequencies or
        too few artifact-free samples.
    Warning
        If the signal and reference channel selections overlap.

    Bugs
    ----
    - artifact checking is incomplete (and with arb. window of tstep=0.2s)
    - no accounting of channels used as signal/reference
    - non existing input file handled ungracefully
    """

    # time.clock() was removed in Python 3.8; use time.process_time() where
    # it exists and only fall back to time.clock() on older interpreters.
    cpu_clock = getattr(time, 'process_time', None) or time.clock

    if not isinstance(complementary_signal, bool):
        raise ValueError("Argument complementary_signal must be of type bool")

    # handle error if Raw object passed with file list
    if raw and isinstance(fname_raw, list):
        raise ValueError('List of file names cannot be combined with'
                         'one Raw object')

    # handle error if return_raw is requested with file list
    if return_raw and isinstance(fname_raw, list):
        raise ValueError('List of file names cannot be combined return_raw.'
                         'Please pass one file at a time.')

    # handle error if Raw object is passed with detrending option
    # TODO include perform_detrending for Raw objects
    if raw and detrending:
        raise ValueError('Please perform detrending on the raw file directly.'
                         'Cannot perform detrending on the raw object')

    # Handle combinations of fname_raw and raw object:
    if fname_raw is not None:
        fnraw = get_files_from_list(fname_raw)
        have_input_file = True
    elif raw is not None:
        # fnraw must always be a list, otherwise the loop below would
        # iterate over the characters of the file name.
        if 'filename' in raw.info:
            fnraw = [os.path.basename(raw.filenames[0])]
        else:
            fnraw = [raw._filenames[0]]
        warnings.warn('Setting file name from Raw object')
        have_input_file = False
        if fnout is None and not return_raw:
            raise ValueError('Refusing to waste resources without result')
    else:
        raise ValueError('Refusing Creatio ex nihilo')

    # Remember the caller-supplied object: inside the loop `raw` is rebound
    # per file, so every file of a list gets loaded (and not only the first).
    raw_input = raw

    # loop across all filenames
    for fname in fnraw:

        if verbose:
            print("########## Read raw data:")

        tc0 = cpu_clock()
        tw0 = time.time()

        if raw_input is None:
            if detrending:
                raw = perform_detrending(fname, save=False)
            else:
                raw = mne.io.Raw(fname, preload=True)
        else:
            raw = raw_input
            # perform sanity check to make sure Raw object and file are same
            if 'filename' in raw.info:
                fnintern = os.path.basename(raw.filenames[0])
            else:
                fnintern = raw._filenames[0]
            if os.path.basename(fname) != os.path.basename(fnintern):
                warnings.warn('The file name within the Raw object and provided\n   '
                              'fname are not the same. Please check again.')

        tc1 = cpu_clock()
        tw1 = time.time()

        if verbose:
            print(">>> loading raw data took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tc0)), (tw1 - tw0)))

        # Time window selection
        # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
        # tstep is used in artifact detection
        # tmin,tmax variables must not be changed here!
        if tmin is None:
            itmin = 0
        else:
            itmin = int(floor(tmin * raw.info['sfreq']))
        if tmax is None:
            itmax = raw.last_samp - raw.first_samp
        else:
            itmax = int(ceil(tmax * raw.info['sfreq']))

        if itmax - itmin < 2:
            raise ValueError("Time-window for noise compensation empty or too short")

        if verbose:
            print(">>> Set time-range to [%7.3f,%7.3f]" % \
                  (raw.times[itmin], raw.times[itmax]))

        if signals is None or len(signals) == 0:
            sigpick = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False,
                                     eog=False, exclude='bads')
        else:
            sigpick = channel_indices_from_list(raw.info['ch_names'][:], signals,
                                                raw.info.get('bads'))
        nsig = len(sigpick)
        if nsig == 0:
            raise ValueError("No channel selected for noise compensation")

        if noiseref is None or len(noiseref) == 0:
            # References are not limited to 4D ref-chans, but can be anything,
            # incl. ECG or powerline monitor.
            if verbose:
                print(">>> Using all refchans.")
            refpick = mne.pick_types(raw.info, ref_meg=True, meg=False,
                                     eeg=False, stim=False,
                                     eog=False, exclude='bads')
        else:
            refpick = channel_indices_from_list(raw.info['ch_names'][:],
                                                noiseref, raw.info.get('bads'))
        nref = len(refpick)
        if nref == 0:
            raise ValueError("No channel selected as noise reference")

        if verbose:
            print(">>> sigpick: %3d chans, refpick: %3d chans" % (nsig, nref))
        badpick = np.intersect1d(sigpick, refpick, assume_unique=False)
        if len(badpick) > 0:
            raise Warning("Intersection of signal and reference channels not empty")

        if reflp is None and refhp is None and refnotch is None:
            use_reffilter = False
            use_refantinotch = False
        else:
            use_reffilter = True
            if verbose:
                print("########## Filter reference channels:")

            use_refantinotch = False
            if refnotch is not None:
                # Notch filtering excludes both the low-pass and the
                # high-pass option (the original tested reflp twice).
                if reflp is not None or refhp is not None:
                    raise ValueError("Cannot specify notch- and high-/low-pass"
                                     "reference filter together")
                nyquist = (0.5 * raw.info['sfreq'])
                if isinstance(refnotch, list):
                    notchfrqs = refnotch
                else:
                    notchfrqs = [refnotch]
                notchfrqscln = []
                for nfrq in notchfrqs:
                    if not isinstance(nfrq, float) and not isinstance(nfrq, int):
                        raise ValueError("Illegal entry for notch-frequency (%s)" % (nfrq,))
                    if nfrq >= nyquist:
                        warnings.warn('Ignoring notch frequency > 0.5*sample_rate=%.1fHz' % nyquist)
                    else:
                        notchfrqscln.append(nfrq)
                if len(notchfrqscln) == 0:
                    raise ValueError("Notch frequency list is (now) empty")
                use_refantinotch = True
                if verbose:
                    print(">>> notches at freq ", end=' ')
                    print(notchfrqscln)
            else:
                if verbose:
                    if reflp is not None:
                        print(">>>  low-pass with cutoff-freq %.1f" % reflp)
                    if refhp is not None:
                        print(">>> high-pass with cutoff-freq %.1f" % refhp)

            # Adapt followg drop-chans cmd to use 'all-but-refpick'
            droplist = [raw.info['ch_names'][k] for k in range(raw.info['nchan']) if k not in refpick]
            tct = cpu_clock()
            twt = time.time()
            # fltref holds only the reference channels (filtered in place below)
            fltref = raw.copy().drop_channels(droplist)
            if use_refantinotch:
                rawref = raw.copy().drop_channels(droplist)
                fltref.notch_filter(notchfrqscln, fir_design='firwin',
                                    fir_window='hann', phase='zero',
                                    picks=np.array(list(range(nref))),
                                    method='fir')
                # "anti-notch": keep only what the notch filter removed
                fltref._data = (rawref._data - fltref._data)
            else:
                fltref.filter(refhp, reflp, fir_design='firwin',
                              fir_window='hann', phase='zero',
                              picks=np.array(list(range(nref))),
                              method='fir')
            tc1 = cpu_clock()
            tw1 = time.time()
            if verbose:
                print(">>> filtering ref-chans  took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tct)),
                                                                                           (tw1 - twt)))

        if verbose:
            print("########## Calculating sig-ref/ref-ref-channel covariances:")
        # Calculate sig-ref/ref-ref-channel covariance:
        # (there is no need to calc inter-signal-chan cov,
        #  but there seems to be no appropriat fct available)
        # Here we copy the idea from compute_raw_data_covariance()
        # and truncate it as appropriate.
        tct = cpu_clock()
        twt = time.time()
        # The following reject and infosig entries are only
        # used in _is_good-calls.
        # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
        # ignore ref-channels (not covered by dict) and checks individual
        # data segments - artifacts across a buffer boundary are not found.
        reject = dict(grad=4000e-13,  # T / m (gradiometers)
                      mag=4e-12,  # T (magnetometers)
                      eeg=40e-6,  # uV (EEG channels)
                      eog=250e-6)  # uV (EOG channels)

        infosig = copy.copy(raw.info)
        infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
        # the below fields are *NOT* (190103) updated automatically when 'chs' is updated
        infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
        infosig['nchan'] = len(sigpick)
        idx_by_typesig = channel_indices_by_type(infosig)

        # Read data in chunks:
        tstep = 0.2
        itstep = int(ceil(tstep * raw.info['sfreq']))
        sigmean = 0
        refmean = 0
        sscovdata = 0
        srcovdata = 0
        rrcovdata = 0
        n_samples = 0

        for first in range(itmin, itmax, itstep):
            last = first + itstep
            if last >= itmax:
                last = itmax
            raw_segmentsig, times = raw[sigpick, first:last]
            if use_reffilter:
                raw_segmentref, times = fltref[:, first:last]
            else:
                raw_segmentref, times = raw[refpick, first:last]

            if not exclude_artifacts or \
                    _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
                             ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                refmean += raw_segmentref.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
                rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
                n_samples += raw_segmentsig.shape[1]
            else:
                logger.info("Artefact detected in [%d, %d]" % (first, last))
        if n_samples <= 1:
            raise ValueError('Too few samples to calculate weights')
        # Turn the accumulated sums into means and unbiased (co)variances
        sigmean /= n_samples
        refmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
        srcovdata /= (n_samples - 1)
        rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
        rrcovdata /= (n_samples - 1)
        sscovinit = np.copy(sscovdata)
        if verbose:
            print(">>> Normalize srcov...")

        # Scale each reference column by that reference's variance;
        # references with (near) zero variance are zeroed out.
        rrslope = copy.copy(rrcovdata)
        for iref in range(nref):
            dtmp = rrcovdata[iref, iref]
            if dtmp > TINY:
                srcovdata[:, iref] /= dtmp
                rrslope[:, iref] /= dtmp
            else:
                srcovdata[:, iref] = 0.
                rrslope[:, iref] = 0.

        if verbose:
            print(">>> Number of samples used : %d" % n_samples)
            tc1 = cpu_clock()
            tw1 = time.time()
            print(">>> sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt)))

        if checkresults:
            if verbose:
                print("########## Calculated initial signal channel covariance:")
                # Calculate initial signal channel covariance:
                # (only used as quality measure)
                print(">>> initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
                for i in range(min(5, nsig)):
                    print(">>> initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i])))
                for i in range(max(0, nsig - 5), nsig):
                    print(">>> initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i])))
                print(">>>")

        U, s, V = np.linalg.svd(rrslope, full_matrices=True)
        if verbose:
            print(">>> singular values:")
            print(s)
            print(">>> Applying cutoff for smallest SVs:")

        # Zero singular values below the relative cutoff; their inverse
        # is set to zero as well.
        dtmp = s.max() * SVD_RELCUTOFF
        s *= (abs(s) >= dtmp)
        sinv = [1. / s[k] if s[k] != 0. else 0. for k in range(nref)]
        if verbose:
            print(">>> singular values (after cutoff):")
            print(s)

        stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
        if verbose:
            print(">>> Testing svd-result: %s" % stat)
            if not stat:
                print("    (Maybe due to SV-cutoff?)")

        # Solve for inverse coefficients:
        # Set RRinv.tr=U diag(sinv) V
        RRinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
        if checkresults:
            stat = np.allclose(np.identity(nref), np.dot(RRinv, rrslope))
            if stat:
                if verbose:
                    print(">>> Testing RRinv-result (should be unit-matrix): ok")
            else:
                print(">>> Testing RRinv-result (should be unit-matrix): failed")
                print(np.transpose(np.dot(RRinv, rrslope)))
                print(">>>")

        if verbose:
            print("########## Calc weight matrix...")

        # weights-matrix will be somewhat larger than necessary,
        # (to simplify indexing in compensation loop):
        weights = np.zeros((raw._data.shape[0], nref))
        for isig in range(nsig):
            for iref in range(nref):
                weights[sigpick[isig], iref] = np.dot(srcovdata[isig, :], RRinv[:, iref])

        if verbose:
            print("########## Compensating signal channels:")
            if complementary_signal:
                print(">>> Caveat: REPLACING signal by compensation signal")

        tct = cpu_clock()
        twt = time.time()

        # Work on entire data stream:
        n_slices = raw._data.shape[1]
        for isl in range(n_slices):
            # Mean-free reference values of this sample (read before the
            # sample is modified below).
            if use_reffilter:
                refarr = fltref._data[:, isl] - refmean
            else:
                refarr = raw._data[refpick, isl] - refmean
            subrefarr = np.dot(weights[:], refarr)

            if not complementary_signal:
                raw._data[:, isl] -= subrefarr
            else:
                raw._data[:, isl] = subrefarr

            if (isl % 10000 == 0 or isl + 1 == n_slices) and verbose:
                print("\rProcessed slice %6d" % isl, end=" ")
                sys.stdout.flush()

        if verbose:
            print("\nDone.")
            tc1 = cpu_clock()
            tw1 = time.time()
            print(">>> compensation loop took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tct)), (tw1 - twt)))

        if checkresults:
            if verbose:
                print("########## Calculating final signal channel covariance:")
            # Calculate final signal channel covariance:
            # (only used as quality measure)
            tct = cpu_clock()
            twt = time.time()
            sigmean = 0
            sscovdata = 0
            n_samples = 0
            for first in range(itmin, itmax, itstep):
                last = first + itstep
                if last >= itmax:
                    last = itmax
                raw_segmentsig, times = raw[sigpick, first:last]
                # Artifacts found here will probably differ from pre-noisered artifacts!
                if not exclude_artifacts or \
                        _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                                 flat=None, ignore_chs=raw.info['bads']):
                    sigmean += raw_segmentsig.sum(axis=1)
                    sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                    n_samples += raw_segmentsig.shape[1]
            if n_samples <= 1:
                raise ValueError('Too few samples to calculate final signal channel covariance')
            sigmean /= n_samples
            sscovdata -= n_samples * sigmean[:] * sigmean[:]
            sscovdata /= (n_samples - 1)
            if verbose:
                print(">>> no channel got worse: %s" % str(np.all(np.less_equal(sscovdata, sscovinit))))
                print(">>> final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
                for i in range(min(5, nsig)):
                    print(">>> final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i])))
                for i in range(max(0, nsig - 5), nsig):
                    print(">>> final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i])))
                tc1 = cpu_clock()
                tw1 = time.time()
                print(">>> signal covar-calc took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tct)),
                                                                                        (tw1 - twt)))
                print(">>>")

        # Output name: explicit fnout wins; nothing is written when only the
        # raw object is requested or no input file name is known.
        if fnout is not None:
            fnoutloc = fnout
        elif return_raw:
            fnoutloc = None
        elif have_input_file:
            fnoutloc = fname[:fname.rfind('-raw.fif')] + ',nr-raw.fif'
        else:
            fnoutloc = None

        if fnoutloc is not None:
            if verbose:
                print(">>> Saving '%s'..." % fnoutloc)
            raw.save(fnoutloc, overwrite=True)

        tc1 = cpu_clock()
        tw1 = time.time()
        if verbose:
            print(">>> Total run took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tc0)), (tw1 - tw0)))

        if return_raw:
            if verbose:
                print(">>> Returning raw object...")
            return raw
예제 #6
0
def preproc1epoch(eeg,
                  info,
                  projs=[],
                  SSP=True,
                  reject=None,
                  mne_reject=1,
                  reject_ch=None,
                  flat=None,
                  bad_channels=[],
                  opt_detrend=1,
                  HP=0,
                  LP=40,
                  phase='zero-double'):
    '''
    Preprocess a single EEG epoch and return it as a NumPy array.

    # Arguments
        eeg: NumPy array
            One EEG epoch laid out as [time samples, channels].

        info: MNE info structure
            Measurement info handed to mne.EpochsArray.

        projs: list
            SSP projector objects; only added and applied when SSP == True.

        SSP: boolean
            Whether the projectors in `projs` are applied to the epoch.

        reject:
            None skips the bad-channel stage entirely. Otherwise the value
            is forwarded to epochs._is_good as its `reject` argument
            (presumably a dict of per-channel-type thresholds - confirm
            against the MNE version in use).

        mne_reject: int
            1 selects the MNE-based check (_is_good + interpolation);
            any other value drops `bad_channels` instead.

        reject_ch: boolean
            Whether to drop nine hard-coded channels before filtering.

        flat:
            Forwarded to epochs._is_good as its `flat` argument.

        bad_channels: list
            Channels dropped when `reject` is set and mne_reject != 1.

        opt_detrend: int
            1 enables linear detrending along the time axis.

        HP: int
            High-pass filter cut-off, default 0 Hz.

        LP: int
            Low-pass filter cut-off, default 40 Hz.

        phase: string
            Filter phase passed to the MNE filter call, default 'zero-double'.

    # Processing order
        1. Optional linear detrending
        2. Optional drop of the predefined channels
        3. Filtering with (HP, LP, phase)
        4. Resampling to 100 Hz
        5. Baseline correction
        6. Optional SSP projection
        7. Optional bad-channel handling (interpolate or drop)
        8. EEG re-referencing followed by a second baseline correction

    # Returns
        epoch: NumPy array
            First (and only) epoch of the processed data.
    '''

    baseline_start = -0.1  # epoch starts 100 ms before stimulus onset

    # Bring [samples, channels] into the (1, channels, samples) epochs layout
    n_times, n_chans = eeg.shape[0], eeg.shape[1]
    data = np.reshape(eeg.T, (1, n_chans, n_times))

    if opt_detrend == 1:
        data = detrend(data, axis=2, type='linear')

    epoch = mne.EpochsArray(data, info, tmin=baseline_start, baseline=None, verbose=False)

    # Channels known to be problematic are removed up front
    if reject_ch == True:
        epoch.drop_channels(['Fp1', 'Fp2', 'Fz', 'AF3', 'AF4', 'T7', 'T8', 'F7', 'F8'])

    # Filter, downsample, baseline-correct
    epoch.filter(HP, LP, fir_design='firwin', phase=phase, verbose=False)
    epoch.resample(100, npad='auto', verbose=False)
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    if SSP == True:
        epoch.add_proj(projs)
        epoch.apply_proj()

    if reject is not None and mne_reject == 1:
        # MNE-based detection; channels reported as bad are interpolated
        from mne.epochs import _is_good
        from mne.io.pick import channel_indices_by_type
        type_idx = channel_indices_by_type(epoch.info)
        is_clean, bad_channels = _is_good(epoch.get_data()[0],
                                          epoch.ch_names,
                                          channel_type_idx=type_idx,
                                          reject=reject,
                                          flat=flat,
                                          full_report=True)
        print(is_clean)
        if is_clean == False:
            epoch.info['bads'] = bad_channels
            epoch.interpolate_bads(reset_bads=True, verbose=False)
    elif reject is not None:
        # Caller-supplied list of bad channels
        epoch.drop_channels(bad_channels)

    # Re-reference, then baseline-correct once more
    epoch.set_eeg_reference(verbose=False)
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    return epoch.get_data()[0]
예제 #7
0
def preproc1epoch(eeg,
                  info,
                  projs=[],
                  SSP=True,
                  reject=None,
                  mne_reject=1,
                  reject_ch=None,
                  flat=None,
                  bad_channels=[],
                  opt_detrend=1,
                  HP=0,
                  LP=40,
                  phase='zero-double'):
    '''
    Preprocesses a single EEG epoch.
    
    # Input
    - eeg: NumPy array holding one EEG epoch in the format (time samples, channels).
    - info: predefined MNE info containing channels etc.
    - projs: SSP projectors; used if SSP == True.
    - SSP: whether to add and apply the projectors in projs.
    - reject: None skips bad-channel handling; otherwise forwarded to
      epochs._is_good as its reject argument (presumably a dict of
      per-channel-type thresholds - confirm against the MNE version in use).
    - mne_reject: 1 uses epochs._is_good + interpolation, any other value
      drops bad_channels instead.
    - reject_ch: whether to drop nine predefined channels before filtering.
    - flat: forwarded to epochs._is_good as its flat argument.
    - bad_channels: channels dropped when reject is set and mne_reject != 1.
    - opt_detrend: 1 enables linear temporal detrending.
    - HP: high-pass filter cut-off, default 0 Hz.
    - LP: low-pass filter cut-off, default 40 Hz.
    - phase: filter phase passed to the MNE filter call, default 'zero-double'.
    
    # Preprocessing
    - Linear detrending (if opt_detrend == 1)
    - EpochsArray format in MNE
    - Drop predefined channels (if reject_ch == True)
    - Bandpass filter (HP-LP, default 0-40Hz)
    - Resample to 100Hz
    - Baseline correction
    - SSP (if True)
    - Reject bad channels
        - interpolate bad channels
    - Rereference to average
    - Baseline correction
    
    # Output
    - Preprocessed EEG epoch as NumPy array.
    
    '''

    # Bring (time samples, channels) into the (1, channels, samples) layout
    n_samples = eeg.shape[0]
    n_channels = eeg.shape[1]
    eeg = np.reshape(eeg.T, (1, n_channels, n_samples))
    tmin = -0.1  # start baseline at

    # Temporal detrending:
    if opt_detrend == 1:
        eeg = detrend(eeg, axis=2, type='linear')

    epoch = mne.EpochsArray(eeg, info, tmin=tmin, baseline=None, verbose=False)

    # Drop list of channels known to be problematic:
    if reject_ch == True:
        bads = ['Fp1', 'Fp2', 'Fz', 'AF3', 'AF4', 'T7', 'T8', 'F7', 'F8']
        epoch.drop_channels(bads)

    # Filter; HP, LP and phase are parameters now (they used to be read as
    # undefined free names, which fails unless module globals exist).
    epoch.filter(HP, LP, fir_design='firwin', phase=phase, verbose=False)

    # Downsample
    epoch.resample(100, npad='auto', verbose=False)

    # Apply baseline correction
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    # Apply SSP projectors
    if SSP == True:
        # Apply projection to the epochs already defined
        epoch.add_proj(projs)
        epoch.apply_proj()

    if reject is not None:  # currently not used
        if mne_reject == 1:  # use mne method to reject+interpolate bad channels
            from mne.epochs import _is_good
            from mne.io.pick import channel_indices_by_type
            idx_by_type = channel_indices_by_type(epoch.info)
            A, bad_channels = _is_good(epoch.get_data()[0],
                                       epoch.ch_names,
                                       channel_type_idx=idx_by_type,
                                       reject=reject,
                                       flat=flat,
                                       full_report=True)
            print(A)
            if A == False:
                epoch.info['bads'] = bad_channels
                epoch.interpolate_bads(reset_bads=True, verbose=False)
        else:  # bad_channels is predefined
            epoch.drop_channels(bad_channels)

    # Rereferencing
    epoch.set_eeg_reference(verbose=False)
    # Apply baseline after rereference
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    epoch = epoch.get_data()[0]
    return epoch
예제 #8
0
def preproc1epoch(eeg, info, projs=[], SSP=True, reject=None, mne_reject=1,
                  reject_ch=None, flat=None, bad_channels=[], opt_detrend=1,
                  HP=0, LP=40, phase='zero-double'):

    '''
    Preprocesses a single EEG epoch.
    
    # Input
    - eeg: numPy array. EEG epoch in the following format: (time samples, channels).
    - info: MNE info data structure. Predefined info containing channels etc. Can be generated using create_info_mne function.
    - projs: MNE SSP projector objects. Used if SSP = True. 
    - SSP: bool. Whether to add and apply the projectors in projs.
    - reject: None skips bad-channel handling; otherwise forwarded to
      epochs._is_good as its reject argument (presumably a dict of
      per-channel-type thresholds - confirm against the MNE version in use).
    - reject_ch: bool. Whether to reject nine predefined channels.
    - mne_reject: 1 uses MNE rejection based on epochs._is_good; any other
      value drops bad_channels instead.
    - flat: forwarded to epochs._is_good as its flat argument.
    - bad_channels: list. Manual rejection of channels.
    - opt_detrend: 1 applies temporal EEG detrending (linear).
    - HP: high-pass filter cut-off, default 0 Hz.
    - LP: low-pass filter cut-off, default 40 Hz.
    - phase: filter phase passed to the MNE filter call, default 'zero-double'.
    
    # Preprocessing steps - based on inputs 
    - Linear temporal detrending
    - Rejection of initial, predefined channels 
    - Bandpass filter (HP-LP, default 0-40Hz)
    - Resample to 100Hz
    - Baseline correction
    - SSP correction 
    - Rejection of bad channels
        - Interpolation of bad channels
    - Rereference to average
    - Baseline correction
    
    # Output
    - Preprocessed EEG epoch in numPy array.
    
    '''

    # Bring (time samples, channels) into the (1, channels, samples) layout
    n_samples = eeg.shape[0]
    n_channels = eeg.shape[1]
    eeg = np.reshape(eeg.T, (1, n_channels, n_samples))
    tmin = -0.1  # Baseline start

    # Temporal detrending:
    if opt_detrend == 1:
        eeg = detrend(eeg, axis=2, type='linear')

    epoch = mne.EpochsArray(eeg, info, tmin=tmin, baseline=None, verbose=False)

    # Drop list of channels known to be problematic:
    if reject_ch == True:
        bads = ['Fp1', 'Fp2', 'Fz', 'AF3', 'AF4', 'T7', 'T8', 'F7', 'F8']
        epoch.drop_channels(bads)

    # Filter; HP, LP and phase are parameters now (they used to be read as
    # undefined free names, which fails unless module globals exist).
    epoch.filter(HP, LP, fir_design='firwin', phase=phase, verbose=False)

    # Downsample
    epoch.resample(100, npad='auto', verbose=False)

    # Apply baseline correction
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    # Apply SSP projectors
    if SSP == True:
        # Apply projection to the epochs already defined
        epoch.add_proj(projs)
        epoch.apply_proj()

    if reject is not None:  # Rejection of channels, either manually defined or based on MNE analysis. Currently not used.
        if mne_reject == 1:  # Use MNE method to reject+interpolate bad channels
            from mne.epochs import _is_good
            from mne.io.pick import channel_indices_by_type
            idx_by_type = channel_indices_by_type(epoch.info)
            A, bad_channels = _is_good(epoch.get_data()[0], epoch.ch_names,
                                       channel_type_idx=idx_by_type,
                                       reject=reject, flat=flat,
                                       full_report=True)
            print(A)
            if A == False:
                epoch.info['bads'] = bad_channels
                epoch.interpolate_bads(reset_bads=True, verbose=False)
        else:  # Predefined bad_channels
            epoch.drop_channels(bad_channels)

    # Rereferencing
    epoch.set_eeg_reference(verbose=False)

    # Apply baseline after rereference
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    epoch = epoch.get_data()[0]

    return epoch
def preprocessEpoch(eeg, info, downsample, tmin, reject=None, mne_reject=1, reject_ch=None, flat=None, bad_channels=[],
                    opt_detrend=1, HP=0, LP=40, phase='zero-double'):
    """Preprocess one EEG epoch and return it as a 2-D array.

    The samples are wrapped in a single-epoch ``mne.EpochsArray`` and run
    through, in order: optional linear detrending, optional removal of
    auxiliary channels, filtering, resampling, baseline correction,
    optional bad-channel handling, re-referencing and a second baseline
    correction.

    Parameters
    ----------
    eeg : array-like, indexed as (n_samples, n_channels)
        Raw samples of one epoch; transposed internally to channels x time.
    info : mne.Info
        Measurement info handed to ``mne.EpochsArray``.
    downsample : float
        Target sampling rate passed to ``epoch.resample``.
    tmin : float
        Start time of the epoch passed to ``mne.EpochsArray``. The baseline
        window used below is ``(None, 0)``, i.e. from the epoch start to 0.
    reject : dict | None
        When ``None`` no bad-channel handling is done. Otherwise it is
        forwarded to MNE's ``_is_good`` (if ``mne_reject == 1``) or merely
        enables dropping of ``bad_channels``.
    mne_reject : int
        ``1`` selects detection via ``_is_good`` followed by interpolation;
        any other value drops the predefined ``bad_channels`` instead.
    reject_ch : bool | None
        If equal to ``True``, drop the channels 'RAW_CQ', 'GYROX', 'GYROY'
        and 'TIMESTAMP' when they are present.
    flat : dict | None
        Forwarded to ``_is_good``.
    bad_channels : list of str
        Channels to drop when ``reject`` is given and ``mne_reject != 1``.
    opt_detrend : int
        ``1`` enables linear detrending along the time axis.
    HP, LP : float
        First and second positional arguments of ``epoch.filter``
        (presumably the lower and upper pass-band edges in Hz -- confirm
        against the MNE version in use).
    phase : str
        FIR phase option forwarded to ``epoch.filter``.

    Returns
    -------
    ndarray
        Data of the single processed epoch, ``epoch.get_data()[0]``
        (channels x time after channel removal and resampling).
    """
    n_samples, n_channels = eeg.shape[0], eeg.shape[1]
    # MNE wants (n_epochs, n_channels, n_times); there is exactly one epoch.
    data = np.reshape(eeg.T, (1, n_channels, n_samples))

    # Remove a linear trend along the time axis before building the epoch.
    if opt_detrend == 1:
        data = detrend(data, axis=2, type='linear')

    epoch = mne.EpochsArray(data, info, tmin=tmin, baseline=None, verbose=False)

    # Discard auxiliary (non-EEG) channels, but only those actually present
    # so that drop_channels is never asked for a missing name.
    if reject_ch == True:
        auxiliary = {'RAW_CQ', 'GYROX', 'GYROY', 'TIMESTAMP'}
        present_aux = [name for name in epoch.ch_names if name in auxiliary]
        epoch.drop_channels(present_aux)

    # Filter with the HP/LP settings, then resample and baseline-correct.
    epoch.filter(HP, LP, fir_design='firwin', phase=phase, verbose=False)
    epoch.resample(downsample, npad='auto', verbose=False)
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    # Bad-channel handling is active only when a reject spec was supplied.
    if reject is not None and mne_reject == 1:
        # Detect bad channels with MNE internals and interpolate them.
        # The imports stay local so they are only resolved on this path.
        from mne.epochs import _is_good
        from mne.io.pick import channel_indices_by_type
        type_indices = channel_indices_by_type(epoch.info)
        is_ok, flagged = _is_good(epoch.get_data()[0], epoch.ch_names,
                                  channel_type_idx=type_indices,
                                  reject=reject, flat=flat,
                                  full_report=True)
        print(is_ok)
        if not is_ok:
            epoch.info['bads'] = flagged
            epoch.interpolate_bads(reset_bads=True, verbose=False)
    elif reject is not None:
        # Caller supplied the list of bad channels: drop them outright.
        epoch.drop_channels(bad_channels)

    # Re-reference using MNE's default reference, then redo the baseline
    # because re-referencing shifts every channel.
    epoch.set_eeg_reference(verbose=False)
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    # Only one epoch exists; hand back its 2-D data.
    return epoch.get_data()[0]