Example #1
0
def test_labeling():
    "Test cluster labeling"
    shape = (4, 20)
    # ``np.float_`` was removed in NumPy 2.0; ``np.float64`` is the same type
    pmap = np.empty(shape, np.float64)
    # graph dimension: 4 nodes connected in a ring (0-1-2-3-0)
    edges = np.array([(0, 1), (0, 3), (1, 2), (2, 3)], np.uint32)
    conn = Connectivity((Scalar('graph', range(4),
                                connectivity=edges), UTS(0, 0.01, 20)))
    criteria = None

    # some clusters
    pmap[:] = [[3, 3, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 4, 4, 0, 0, 0, 0],
               [0, 1, 0, 0, 0, 0, 8, 0, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 4, 0],
               [0, 3, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4, 4],
               [0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 4, 4, 4, 0, 0, 0, 0, 0]]
    cmap, cids = label_clusters(pmap, 2, 0, conn, criteria)
    assert_equal(len(cids), 6)
    # every supra-threshold sample must belong to some cluster
    assert_array_equal(cmap > 0, np.abs(pmap) > 2)

    # some other clusters
    pmap[:] = [[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0],
               [0, 4, 0, 0, 0, 0, 0, 4, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 4, 4, 0, 4, 4, 0, 4, 0, 0, 0, 4, 4, 1, 0, 4, 4, 0, 0],
               [0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 0, 0, 0, 0]]
    cmap, cids = label_clusters(pmap, 2, 0, conn, criteria)
    assert_equal(len(cids), 6)
    assert_array_equal(cmap > 0, np.abs(pmap) > 2)
Example #2
0
def test_cast_to_ndvar():
    "Test table.cast_to_ndvar()"
    long_ds = datasets.get_uv()
    long_ds['scalar'] = long_ds['A'] == 'a2'
    # one time value per A x B cell
    time_by_cell = {
        ('a1', 'b1'): 0.,
        ('a1', 'b2'): 0.1,
        ('a2', 'b1'): 0.2,
        ('a2', 'b2'): 0.3,
    }
    long_ds['time'] = long_ds.eval('A%B').as_var(time_by_cell)

    # categorial: cells of A become a Categorial dimension
    cat_ds = table.cast_to_ndvar('fltvar', 'A', 'B%rm', ds=long_ds, name='new')
    assert cat_ds.n_cases == long_ds.n_cases / 2
    assert cat_ds['new'].A == Categorial('A', ('a1', 'a2'))

    # scalar: boolean values become a Scalar dimension
    scalar_ds = table.cast_to_ndvar('fltvar', 'scalar', 'B%rm', ds=long_ds, dim='newdim', name='new')
    assert scalar_ds.n_cases == long_ds.n_cases / 2
    assert scalar_ds['new'].newdim == Scalar('newdim', [False, True])
    assert_array_equal(cat_ds['new'].x, scalar_ds['new'].x)

    # time: regularly spaced values become a UTS dimension
    time_ds = table.cast_to_ndvar('fltvar', 'time', 'rm', ds=long_ds, dim='uts', name='y')
    assert time_ds.n_cases == long_ds.n_cases / 4
    assert time_ds['y'].time == UTS(0, 0.1, 4)
def test_find_intervals():
    "Test find_intervals() on boolean time series, including both edges"
    time = UTS(-5, 1, 10)
    cases = (
        ([0, 1, 0, 1, 1, 0, 1, 1, 1, 0], ((-4, -3), (-2, 0), (1, 4))),
        # interval running up to the end of the time axis
        ([0, 1, 0, 1, 1, 0, 1, 1, 1, 1], ((-4, -3), (-2, 0), (1, 5))),
        # interval starting at the beginning of the time axis
        ([1, 1, 0, 1, 1, 0, 1, 1, 1, 1], ((-5, -3), (-2, 0), (1, 5))),
    )
    for data, intervals in cases:
        eq_(find_intervals(NDVar(data, (time,))), intervals)
Example #4
0
    def _load(self, path, tmin, tstep, n_samples, code, seed):
        """Load a predictor and bring it onto ``UTS(tmin, tstep, n_samples)``

        Parameters
        ----------
        path
            File read with ``load.unpickle`` (presumably pickle-based, so only
            trusted files should be loaded). It may contain an NDVar
            (continuous UTS), a Dataset of events (NUTS) with ``time`` and
            ``value`` columns, or a list of pre-computed versions with
            different ``tstep``.
        tmin, tstep, n_samples
            Target time axis.
        code
            Predictor code. ``code.shuffle`` selects an optional permutation,
            and ``code.code_with_rand`` names the NDVar built from events.
        seed
            Seed for the random state used by the ``'permute'`` and
            ``'relocate'`` shuffles of event values.

        Raises
        ------
        IOError
            If a list contains no entry with a matching ``tstep``.
        RuntimeError
            If an NDVar has a different ``tstep`` and ``self.resample`` is
            ``None`` or an unrecognized value.
        TypeError
            If the file contains an unsupported object.
        """
        x = load.unpickle(path)
        # allow for pre-computed resampled versions
        if isinstance(x, list):
            xs = x
            # for/else: ``x`` stays bound to the matching entry on ``break``
            for x in xs:
                if x.time.tstep == tstep:
                    break
            else:
                raise IOError(
                    f"{os.path.basename(path)} does not contain tstep={tstep!r}"
                )
        # continuous UTS
        if isinstance(x, NDVar):
            if x.time.tstep == tstep:
                # already at the target sampling rate
                pass
            elif self.resample == 'bin':
                x = x.bin(tstep, label='start')
            elif self.resample == 'resample':
                srate = 1 / tstep
                int_srate = int(round(srate))
                # use an exact int sampling rate when 1 / tstep is within float error of one
                srate = int_srate if abs(int_srate - srate) < .001 else srate
                x = resample(x, srate)
            elif self.resample is None:
                raise RuntimeError(
                    f"{os.path.basename(path)} has tstep={x.time.tstep}, not {tstep}"
                )
            else:
                raise RuntimeError(f"resample={self.resample!r}")
            x = pad(x, tmin, nsamples=n_samples)
        # NUTS
        elif isinstance(x, Dataset):
            ds = x
            if code.shuffle in ('permute', 'relocate'):
                rng = numpy.random.RandomState(seed)
                if code.shuffle == 'permute':
                    # shuffle only the values of events flagged in the boolean 'permute' column
                    index = ds['permute'].x
                    assert index.dtype.kind == 'b'
                    values = ds[index, 'value'].x
                    rng.shuffle(values)
                    ds[index, 'value'] = values
                else:
                    # 'relocate': shuffle values across all events, in place
                    rng.shuffle(ds['value'].x)
                code.register_shuffle()
            x = NDVar(numpy.zeros(n_samples),
                      UTS(tmin, tstep, n_samples),
                      name=code.code_with_rand)
            # drop events at or after the end of the output time axis
            ds = ds[ds['time'] < x.time.tstop]
            # NOTE(review): values are assigned, not added, so presumably events
            # falling on the same sample overwrite each other; confirm intended
            for t, v in ds.zip('time', 'value'):
                x[t] = v
        else:
            raise TypeError(f'{x!r} at {path}')

        if code.shuffle in NDVAR_SHUFFLE_METHODS:
            x = shuffle(x, code.shuffle, code.shuffle_band, code.shuffle_angle)
            code.register_shuffle()
        return x
Example #5
0
def test_resample():
    "Test that resample() carries the mask of a masked NDVar over"
    x = NDVar([0.0, 1.0, 1.4, 1.0, 0.0],
              UTS(0, 0.1, 5)).mask([True, False, False, False, True])
    target_mask = [True] + [False] * 7 + [True, True]
    # default padding and no padding must give the same mask
    for kwargs in ({}, {'npad': 0}):
        y = resample(x, 20, **kwargs)
        assert_array_equal(y.x.mask, target_mask)
def test_find_peaks():
    "Test find_peaks() on a single smooth 2d bump"
    scalar = Scalar('scalar', range(9))
    time = UTS(0, .1, 12)
    v = NDVar(np.zeros((9, 12)), (scalar, time))
    # stack Hamming windows around time index 5, widest in the middle row
    half_widths = [0, 0, 1, 2, 3, 2, 1, 0, 0]
    for row, half_width in enumerate(half_widths):
        if not half_width:
            continue
        v.x[row, 5 - half_width:5 + half_width] += np.hamming(2 * half_width)

    rows, cols = np.nonzero(find_peaks(v).x)
    assert_array_equal(rows, [4])
    assert_array_equal(cols, [5])
Example #7
0
def test_frequency_response():
    "Test frequency_response() against scipy.signal.freqz"
    tstep = 0.01
    b_array = signal.firwin(80, 0.5, window=('kaiser', 8))
    freqs_array, fresp_array = signal.freqz(b_array)
    # freqz reports frequencies in radians/sample
    hz_to_rad = 2 * np.pi * tstep

    # 1d filter
    b = NDVar(b_array, (UTS(0, tstep, 80), ))
    fresp = frequency_response(b)
    assert_array_equal(fresp.x, fresp_array)
    assert_array_equal(fresp.frequency.values * hz_to_rad, freqs_array)

    # 2d input: one response per case
    fresp = frequency_response(concatenate((b, b), Case))
    for i_case in (0, 1):
        assert_array_equal(fresp.x[i_case], fresp_array)
    assert_array_equal(fresp.frequency.values * hz_to_rad, freqs_array)
 def _generate_continuous(
     self,
     uts: UTS,  # time axis for the output
     ds: Dataset,  # events
     stim_var: str,
     code: Code,
     directory: Path,
 ):
     """Combine per-stimulus predictors into one continuous predictor

     For each event in ``ds``, the predictor of that event's stimulus is
     placed at the event time ``T_relative`` on ``uts``.

     Parameters
     ----------
     uts
         Time axis for the output.
     ds
         Events, with a ``T_relative`` column (event time on ``uts``) and a
         ``stim_var`` column.
     stim_var
         Column in ``ds`` naming each event's stimulus.
     code
         Predictor code; ``code.with_stim(stim)`` locates each stimulus' file.
     directory
         Directory containing the predictor files.

     Returns
     -------
     NDVar
         Predictor on ``uts``.

     Raises
     ------
     RuntimeError
         If the stimulus predictors are not all of a single type, or are of a
         type other than Dataset or NDVar.
     ValueError
         If a stimulus predictor extends outside ``uts``.
     """
     # place multiple input files into a continuous predictor; each distinct
     # stimulus is loaded only once
     cache = {
         stim: self._load(uts.tstep,
                          code.with_stim(stim).nuts_file_name(self.columns),
                          directory)
         for stim in ds[stim_var].cells
     }
     # determine type (all stimuli must share a single type)
     stim_types = {type(s) for s in cache.values()}
     if len(stim_types) != 1:
         raise RuntimeError(f"Expected one predictor type for all stimuli, got {stim_types}")
     stim_type = stim_types.pop()
     # generate x
     if stim_type is Dataset:
         dss = []
         for t, stim in ds.zip('T_relative', stim_var):
             x = cache[stim].copy()
             # Shift times by assigning a new Var rather than with ``+=``:
             # ``Dataset.copy()`` is a shallow copy (per eelbrain docs), so an
             # in-place ``+=`` would also shift the cached Dataset's times and
             # offset every later event with the same stimulus.
             x['time'] = x['time'] + t
             dss.append(x)
             if code.nuts_method:
                 # presumably marks the stimulus end for the NUTS method — confirm
                 x_stop_ds = t_stop_ds(x, t)
                 dss.append(x_stop_ds)
         x = self._ds_to_ndvar(combine(dss), uts, code)
     elif stim_type is NDVar:
         v = cache[ds[0, stim_var]]
         dimnames = v.get_dimnames(first='time')
         dims = (uts, *v.get_dims(dimnames[1:]))
         x = NDVar.zeros(dims, code.key)
         for t, stim in ds.zip('T_relative', stim_var):
             x_stim = cache[stim]
             i_start = uts._array_index(t + x_stim.time.tmin)
             i_stop = i_start + len(x_stim.time)
             # a negative index would be interpreted relative to the end of
             # the array instead of failing
             if i_start < 0:
                 raise ValueError(
                     f"{code.string_without_rand} for {stim} starts before the data"
                 )
             if i_stop > len(uts):
                 raise ValueError(
                     f"{code.string_without_rand} for {stim} is longer than the data"
                 )
             # NOTE(review): overlapping stimuli overwrite rather than sum — confirm intended
             x.x[i_start:i_stop] = x_stim.get_data(dimnames)
     else:
         raise RuntimeError(f"stim_type={stim_type!r}")
     return x
    def _generate(self, tmin: float, tstep: float, n_samples: int, code: Code,
                  directory: Path):
        """Generate the predictor for one input file on ``UTS(tmin, tstep, n_samples)``

        Parameters
        ----------
        tmin, tstep
            Start and step of the output time axis.
        n_samples
            Number of output samples. If ``None`` and the file contains events
            (a Dataset), it is derived from ``x.info['tstop']``.
        code
            Predictor code; determines the file name and optional shuffling.
        directory
            Directory containing the predictor file.

        Raises
        ------
        RuntimeError
            If the file contains neither a Dataset nor an NDVar.
        """
        from math import floor

        # predictor for one input file
        file_name = code.nuts_file_name(self.columns)
        x = self._load(tstep, file_name, directory)
        if isinstance(x, Dataset):
            if n_samples is None:
                # ``(tstop - tmin) // tstep`` drops a sample when float error puts
                # the quotient just below an integer (``1.0 // 0.1 == 9.0``).
                # Round quotients that are integers up to float error, and take
                # the floor of the others.
                n_steps = (x.info['tstop'] - tmin) / tstep
                n_nearest = round(n_steps)
                if abs(n_steps - n_nearest) < 1e-6:
                    n_samples = int(n_nearest)
                else:
                    n_samples = int(floor(n_steps))
            uts = UTS(tmin, tstep, n_samples)
            x = self._ds_to_ndvar(x, uts, code)
        elif isinstance(x, NDVar):
            x = pad(x, tmin, nsamples=n_samples, set_tmin=True)
        else:
            raise RuntimeError(x)

        if code.shuffle in NDVAR_SHUFFLE_METHODS:
            x = shuffle(x, code.shuffle, code.shuffle_index,
                        code.shuffle_angle)
            code.register_shuffle(index=True)
        return x
Example #10
0
    def h(self):
        """The spatio-temporal response function as (list of) NDVar

        Returns a single NDVar when ``self._stim_is_single``. Otherwise it
        returns a list with one NDVar per stimulus, in the order of
        ``self._stim_names``.
        """
        n_vars = sum(len(dim) if dim else 1 for dim in self._stim_dims)
        if n_vars <= 1:
            # a single variable: just add a leading variable axis
            trf = self.theta[np.newaxis, :]
        else:
            # split the columns of theta into one chunk per variable, and move
            # the variable axis to the front
            trf = self.theta.reshape((self.theta.shape[0], n_vars, -1)).swapaxes(1, 0)

        # expand basis coefficients; the last axis becomes time
        trf = np.dot(trf, self._basis.T) / self.lead_field_scaling

        time = UTS(self.tstart, self.tstep, trf.shape[-1])
        shared_dims = (self.source, self.space, time) if self.space else (self.source, time)
        trf = trf.reshape((-1, *[len(dim) for dim in shared_dims]))

        # peel one NDVar per stimulus off the leading variable axis
        hs = []
        start = 0
        for dim, name in zip(self._stim_dims, self._stim_names):
            if dim:
                stop = start + len(dim)
                hs.append(NDVar(trf[start:stop], (dim, *shared_dims), name=name))
            else:
                stop = start + 1
                hs.append(NDVar(trf[start], shared_dims, name=name))
            start = stop
        return hs[0] if self._stim_is_single else hs
Example #11
0
def test_clusterdist():
    "Test _ClusterDist class"
    shape = (10, 6, 6, 4)
    locs = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]]
    x = np.random.normal(0, 1, shape)
    sensor = Sensor(locs, ['0', '1', '2', '3'])
    sensor.set_connectivity(connect_dist=1.1)
    dims = ('case', UTS(-0.1, 0.1, 6), Scalar('dim2', range(6),
                                              'unit'), sensor)
    y = NDVar(x, dims)

    # test connecting sensors
    logging.info("TEST:  connecting sensors")
    # ``np.bool8`` was removed in NumPy 2.0; ``bool`` maps to ``np.bool_``
    bin_map = np.zeros(shape[1:], dtype=bool)
    bin_map[:3, :3, :2] = True
    # background noise is clipped to [-1, 1], so only bin_map exceeds the threshold
    pmap = np.random.normal(0, 1, shape[1:])
    np.clip(pmap, -1, 1, pmap)
    pmap[bin_map] = 2
    cdist = _ClusterDist(y, 0, 1.5)
    print(repr(cdist))
    cdist.add_original(pmap)
    print(repr(cdist))
    assert_equal(cdist.n_clusters, 1)
    assert_array_equal(cdist._original_cluster_map == cdist._cids[0],
                       cdist._crop(bin_map).swapaxes(0, cdist._nad_ax))
    assert_equal(cdist.parameter_map.dims, y.dims[1:])

    # test connecting many sensors
    logging.info("TEST:  connecting many sensors")
    bin_map = np.zeros(shape[1:], dtype=bool)
    bin_map[:3, :3] = True
    pmap = np.random.normal(0, 1, shape[1:])
    np.clip(pmap, -1, 1, pmap)
    pmap[bin_map] = 2
    cdist = _ClusterDist(y, 0, 1.5)
    cdist.add_original(pmap)
    assert_equal(cdist.n_clusters, 1)
    assert_array_equal(cdist._original_cluster_map == cdist._cids[0],
                       cdist._crop(bin_map).swapaxes(0, cdist._nad_ax))

    # test keeping sensors separate
    logging.info("TEST:  keeping sensors separate")
    bin_map = np.zeros(shape[1:], dtype=bool)
    bin_map[:3, :3, 0] = True
    bin_map[:3, :3, 2] = True
    pmap = np.random.normal(0, 1, shape[1:])
    np.clip(pmap, -1, 1, pmap)
    pmap[bin_map] = 2
    cdist = _ClusterDist(y, 1, 1.5)
    cdist.add_original(pmap)
    assert_equal(cdist.n_clusters, 2)

    # criteria
    ds = datasets.get_uts(True)
    res = testnd.ttest_rel('utsnd',
                           'A',
                           match='rm',
                           ds=ds,
                           samples=0,
                           pmin=0.05)
    assert_less(res.clusters['duration'].min(), 0.01)
    eq_(res.clusters['n_sensors'].min(), 1)
    res = testnd.ttest_rel('utsnd',
                           'A',
                           match='rm',
                           ds=ds,
                           samples=0,
                           pmin=0.05,
                           mintime=0.02,
                           minsensor=2)
    assert_greater_equal(res.clusters['duration'].min(), 0.02)
    eq_(res.clusters['n_sensors'].min(), 2)

    # 1d
    res1d = testnd.ttest_rel('utsnd.sub(time=0.1)',
                             'A',
                             match='rm',
                             ds=ds,
                             samples=0,
                             pmin=0.05)
    assert_dataobj_equal(res1d.p_uncorrected, res.p_uncorrected.sub(time=0.1))

    # TFCE
    logging.info("TEST:  TFCE")
    sensor = Sensor(locs, ['0', '1', '2', '3'])
    sensor.set_connectivity(connect_dist=1.1)
    time = UTS(-0.1, 0.1, 4)
    scalar = Scalar('scalar', range(10), 'unit')
    dims = ('case', time, sensor, scalar)
    np.random.seed(0)
    y = NDVar(np.random.normal(0, 1, (10, 4, 4, 10)), dims)
    cdist = _ClusterDist(y, 3, None)
    cdist.add_original(y.x[0])
    cdist.finalize()
    assert_equal(cdist.dist.shape, (3, ))
    # I/O
    string = pickle.dumps(cdist, pickle.HIGHEST_PROTOCOL)
    cdist_ = pickle.loads(string)
    assert_equal(repr(cdist_), repr(cdist))
    # find peaks
    x = np.array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [7, 7, 0, 0, 0, 0, 0, 0, 0, 0],
                   [0, 7, 0, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
                  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [5, 7, 0, 0, 0, 0, 0, 0, 0, 0],
                   [0, 6, 0, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
                  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 7, 5, 5, 0, 0],
                   [0, 0, 0, 0, 5, 4, 4, 4, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
                  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 4, 0, 0],
                   [0, 0, 0, 0, 7, 0, 0, 3, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])
    tgt = np.equal(x, 7)
    peaks = find_peaks(x, cdist._connectivity)
    logging.debug(' detected: \n%s' % (peaks.astype(int)))
    logging.debug(' target: \n%s' % (tgt.astype(int)))
    assert_array_equal(peaks, tgt)
    # testnd permutation result
    res = testnd.ttest_1samp(y, tfce=True, samples=3)
    assert_allclose(np.sort(res._cdist.dist),
                    [77.5852307, 119.1976153, 217.6270428])

    # parc with TFCE on unconnected dimension
    configure(False)
    x = np.random.normal(0, 1, (10, 5, 2, 4))
    time = UTS(-0.1, 0.1, 5)
    categorial = Categorial('categorial', ('a', 'b'))
    y = NDVar(x, ('case', time, categorial, sensor))
    y0 = NDVar(x[:, :, 0], ('case', time, sensor))
    y1 = NDVar(x[:, :, 1], ('case', time, sensor))
    res = testnd.ttest_1samp(y, tfce=True, samples=3)
    res_parc = testnd.ttest_1samp(y, tfce=True, samples=3, parc='categorial')
    res0 = testnd.ttest_1samp(y0, tfce=True, samples=3)
    res1 = testnd.ttest_1samp(y1, tfce=True, samples=3)
    # cdist
    eq_(res._cdist.shape, (4, 2, 5))
    # T-maps don't depend on connectivity
    assert_array_equal(res.t.x[:, 0], res0.t.x)
    assert_array_equal(res.t.x[:, 1], res1.t.x)
    assert_array_equal(res_parc.t.x[:, 0], res0.t.x)
    assert_array_equal(res_parc.t.x[:, 1], res1.t.x)
    # TFCE-maps should always be the same because they're unconnected
    assert_array_equal(res.tfce_map.x[:, 0], res0.tfce_map.x)
    assert_array_equal(res.tfce_map.x[:, 1], res1.tfce_map.x)
    assert_array_equal(res_parc.tfce_map.x[:, 0], res0.tfce_map.x)
    assert_array_equal(res_parc.tfce_map.x[:, 1], res1.tfce_map.x)
    # Probability-maps should depend on what is taken into account
    p_a = res0.compute_probability_map().x
    p_b = res1.compute_probability_map().x
    assert_array_equal(res_parc.compute_probability_map(categorial='a').x, p_a)
    assert_array_equal(res_parc.compute_probability_map(categorial='b').x, p_b)
    p_parc = res_parc.compute_probability_map()
    assert_array_equal(p_parc.x, res.compute_probability_map().x)
    ok_(np.all(p_parc.sub(categorial='a').x >= p_a))
    ok_(np.all(p_parc.sub(categorial='b').x >= p_b))
    configure(True)
Example #12
0
def gammatone_bank(
        wav: NDVar,
        f_min: float,
        f_max: float,
        n: int,
        integration_window: float = 0.010,
        tstep: float = None,
        location: str = 'right',
        pad: bool = True,
        name: str = None,
) -> NDVar:
    """Gammatone filterbank response

    Parameters
    ----------
    wav : NDVar
        Sound input.
    f_min : scalar
        Lower frequency cutoff.
    f_max : scalar
        Upper frequency cutoff.
    n : int
        Number of filter channels.
    integration_window : scalar
        Integration time window in seconds (default 10 ms).
    tstep : scalar
        Time step size in the output (default is same as ``wav``).
    location : str
        Location of the output relative to the input time axis:

        - ``right``: gammatone sample at end of integration window (default)
        - ``left``: gammatone sample at beginning of integration window
        - ``center``: gammatone sample at center of integration window

        Since gammatone filter response depends on ``integration_window``, the
        filter response will be delayed relative to the analytic envelope. To
        ignore this delay, use ``location='left'``.
    pad : bool
        Pad output to match time axis of input.
    name : str
        NDVar name (default is ``wav.name``).

    Returns
    -------
    NDVar
        Filterbank response with dimensions ``(frequency, time)``.

    Raises
    ------
    ValueError
        If ``location`` is not one of ``'left'``, ``'right'``, ``'center'``.

    Notes
    -----
    Requires the ``fmax`` branch of the gammatone library to be installed:

        $ pip install https://github.com/christianbrodbeck/gammatone/archive/fmax.zip
    """
    from gammatone.filters import centre_freqs, erb_filterbank
    from gammatone.gtgram import make_erb_filters

    # pad the input so that the output windows line up with the input time axis
    wav_ = wav
    if location == 'left':
        if pad:
            wav_ = _pad_func(wav, wav.time.tmin - integration_window)
    elif location == 'right':
        # tmin += window_time
        if pad:
            wav_ = _pad_func(wav, tstop=wav.time.tstop + integration_window)
    elif location == 'center':
        dt = integration_window / 2
        # tmin += dt
        if pad:
            wav_ = _pad_func(wav, wav.time.tmin - dt, wav.time.tstop + dt)
    else:
        raise ValueError(f"location={location!r}")
    fs = 1 / wav.time.tstep
    if tstep is None:
        tstep = wav.time.tstep
    wave = wav_.get_data('time')
    # based on gammatone library, rewritten to reduce memory footprint:
    # filter one band at a time and keep only its aggregated output
    cfs = centre_freqs(fs, n, f_min, f_max)
    integration_window_len = int(round(integration_window * fs))
    output_n_samples = floor((len(wave) - integration_window_len) * wav.time.tstep / tstep)
    output_step = tstep / wav.time.tstep
    results = []
    for cf in tqdm(reversed(cfs), "Gammatone spectrogram", total=len(cfs), unit='band'):
        fcoefs = np.flipud(make_erb_filters(fs, cf))
        xf = erb_filterbank(wave, fcoefs)
        results.append(aggregate(xf[0], output_n_samples, output_step, integration_window_len))
    result = np.sqrt(results)
    # package output; bands were processed in reverse order of ``cfs``
    freq_dim = Scalar('frequency', cfs[::-1], 'Hz')
    time_dim = UTS(wav.time.tmin, tstep, output_n_samples)
    if name is None:
        name = wav.name
    return NDVar(result, (freq_dim, time_dim), name)
Example #13
0
def gammatone_bank(wav: NDVar,
                   f_min: float,
                   f_max: float,
                   n: int,
                   integration_window: float = 0.010,
                   tstep: float = None,
                   location: str = 'right',
                   pad: bool = True,
                   name: str = None) -> NDVar:
    """Gammatone filterbank response

    Parameters
    ----------
    wav : NDVar
        Sound input.
    f_min : scalar
        Lower frequency cutoff.
    f_max : scalar
        Upper frequency cutoff.
    n : int
        Number of filter channels.
    integration_window : scalar
        Integration time window in seconds (default 10 ms).
    tstep : scalar
        Time step size in the output (default is same as ``wav``).
    location : str
        Location of the output relative to the input time axis:

        - ``right``: gammatone sample at end of integration window (default)
        - ``left``: gammatone sample at beginning of integration window
        - ``center``: gammatone sample at center of integration window

        Since gammatone filter response depends on ``integration_window``, the
        filter response will be delayed relative to the analytic envelope. To
        ignore this delay, use ``location='left'``.
    pad : bool
        Pad output to match time axis of input.
    name : str
        NDVar name (default is ``wav.name``).

    Returns
    -------
    NDVar
        Filterbank response with dimensions ``(frequency, time)``.

    Raises
    ------
    ValueError
        If ``location`` is not one of ``'left'``, ``'right'``, ``'center'``.

    Notes
    -----
    Requires the ``fmax`` branch of the gammatone library to be installed:

        $ pip install https://github.com/christianbrodbeck/gammatone/archive/fmax.zip
    """
    from gammatone.filters import centre_freqs
    from gammatone.gtgram import gtgram

    tmin = wav.time.tmin
    # pad the input so that the output windows line up with the input time axis
    wav_ = wav
    if location == 'left':
        if pad:
            wav_ = _pad_func(wav, wav.time.tmin - integration_window)
    elif location == 'right':
        # tmin += window_time
        if pad:
            wav_ = _pad_func(wav, tstop=wav.time.tstop + integration_window)
    elif location == 'center':
        dt = integration_window / 2
        # tmin += dt
        if pad:
            wav_ = _pad_func(wav, wav.time.tmin - dt, wav.time.tstop + dt)
    else:
        raise ValueError(f"location={location!r}")
    sfreq = 1 / wav.time.tstep
    if tstep is None:
        tstep = wav.time.tstep
    x = gtgram(wav_.get_data('time'), sfreq, integration_window, tstep, n,
               f_min, f_max)
    freqs = centre_freqs(sfreq, n, f_min, f_max)
    # frequency labels are the reverse of the order returned by ``centre_freqs``
    freq_dim = Scalar('frequency', freqs[::-1], 'Hz')
    time_dim = UTS(tmin, tstep, x.shape[1])
    return NDVar(x, (freq_dim, time_dim), name or wav.name)
Example #14
0
def _pad_n_steps(delta: float, tstep: float, round_up: bool) -> int:
    """Number of ``tstep`` steps spanning ``delta``, robust to float error

    ``delta / tstep`` is rounded to the nearest integer when it is within
    float error of one. Otherwise it is rounded up (``round_up``) or down.
    Plain ``ceil`` / ``//`` miss by one when float error pushes the quotient
    just past an integer (e.g. ``0.3 // 0.1 == 2.0``).
    """
    n = delta / tstep
    n_nearest = round(n)
    if abs(n - n_nearest) < 1e-6:
        return int(n_nearest)
    return int(np.ceil(n)) if round_up else int(np.floor(n))


def pad(
        ndvar: NDVar,
        tstart: float = None,
        tstop: float = None,
        nsamples: int = None,
        set_tmin: bool = False,
        name: str = None,
) -> NDVar:
    """Pad (or crop) an NDVar in time

    Parameters
    ----------
    ndvar
        NDVar to pad.
    tstart
        New tstart. Earlier than ``ndvar``'s ``tmin`` pads with zeros; later
        crops.
    tstop
        New tstop (mutually exclusive with ``nsamples``).
    nsamples
        New number of samples (requires ``tstart``).
    set_tmin
        Reset ``tmin`` to be exactly equal to ``tstart`` (requires ``tstart``).
    name
        Name for the new NDVar (default ``ndvar.name``).

    Returns
    -------
    NDVar
        ``ndvar`` itself if nothing changes, otherwise a new NDVar carrying
        ``ndvar.info``.

    Raises
    ------
    ValueError
        ``set_tmin`` without ``tstart``.
    NotImplementedError
        ``nsamples`` without ``tstart``.
    TypeError
        Both ``tstop`` and ``nsamples`` are specified.
    """
    axis = ndvar.get_axis('time')
    time: UTS = ndvar.dims[axis]
    if name is None:
        name = ndvar.name
    # start
    if tstart is None:
        if set_tmin:
            raise ValueError("set_tmin without defining tstart")
        if nsamples is not None:
            raise NotImplementedError("nsamples without tstart")
        n_add_start = 0
    elif tstart < time.tmin:
        n_add_start = _pad_n_steps(time.tmin - tstart, time.tstep, round_up=True)
    elif tstart > time.tmin:
        # negative: number of samples to crop at the start
        n_add_start = -time._array_index(tstart)
    else:
        n_add_start = 0

    # end
    if nsamples is None and tstop is None:
        n_add_end = 0
    elif nsamples is None:
        n_add_end = _pad_n_steps(tstop - time.tstop, time.tstep, round_up=False)
    elif tstop is None:
        n_add_end = nsamples - n_add_start - time.nsamples
    else:
        raise TypeError("Can only specify one of tstop and nsamples")
    # Return the input unchanged only if neither data, time axis nor name
    # change. A requested exact tmin or a new name still need a new NDVar.
    retime = set_tmin and tstart != time.tmin
    if not n_add_start and not n_add_end and not retime and name == ndvar.name:
        return ndvar
    # construct padded data
    xs = [ndvar.x]
    shape = ndvar.x.shape
    # start
    if n_add_start > 0:
        shape_start = shape[:axis] + (n_add_start,) + shape[axis + 1:]
        xs.insert(0, np.zeros(shape_start))
    elif n_add_start < 0:
        xs[0] = xs[0][index(slice(-n_add_start, None), axis)]
    # end
    if n_add_end > 0:
        shape_end = shape[:axis] + (n_add_end,) + shape[axis + 1:]
        xs.append(np.zeros(shape_end))
    elif n_add_end < 0:
        # the last element is the (possibly start-cropped) original data
        xs[-1] = xs[-1][index(slice(None, n_add_end), axis)]
    x = np.concatenate(xs, axis)
    if set_tmin:
        new_tmin = tstart
    else:
        new_tmin = time.tmin - (time.tstep * n_add_start)
    new_time = UTS(new_tmin, time.tstep, x.shape[axis])
    dims = (*ndvar.dims[:axis], new_time, *ndvar.dims[axis + 1:])
    return NDVar(x, dims, name, ndvar.info)