Ejemplo n.º 1
0
    def __iter__(self):
        """Yield fixed-width chunks cut along the time axis of each item.

        Every item of ``self.dataloader`` must be a batch of size 1 of
        ``(ks, xs, ts, ilens)`` where ``xs[0]`` and ``ts[0]`` share the shape
        (C, T, F). Starting at frame 0, the valid region (``ilens[0]`` frames)
        is cut into non-overlapping windows of ``self.width`` frames; a
        trailing remainder shorter than ``self.width`` is skipped.

        Each input window is widened by ``self.lcontext`` frames before and
        ``self.rcontext`` frames after it. Context that falls outside
        ``xs`` is zero-padded so that frame ``self.lcontext`` of the input
        window lines up with the first frame of the target window.

        Chunks are accumulated across items and emitted in groups of
        ``self.batch_size``.

        Yields:
            tuple: ``(keys, xs, ts, ilens)`` where
                - ``keys`` is a tuple of ``self.batch_size`` keys (``ks[0]``
                  of the item each chunk came from),
                - ``xs`` is the stacked input windows,
                  (B, C, lcontext + width + rcontext, F),
                - ``ts`` is the stacked target windows, (B, C, width, F),
                - ``ilens`` is a tensor (on the device of the inputs) holding
                  the input window length for every chunk.

        Note:
            Chunks left over when the dataloader is exhausted that do not
            fill a complete batch are not yielded.
        """
        keys, inputs, targets, lengths = [], [], [], []
        # Length of every (padded) input window; identical for all chunks.
        window = self.width + self.lcontext + self.rcontext

        for ks, xs, ts, ilens in self.dataloader:
            assert len(ks) == 1, \
                f'batch-size of dataloder is not 1: {len(ks)} != 1'
            # Check shape:
            assert isinstance(ks[0], str), type(ks[0])
            # Expected: x: (C, T, F), t: (C, T, F)
            assert xs[0].dim() == 3, xs[0].shape
            assert xs[0].shape == ts[0].shape, (xs[0].shape, ts[0].shape)

            key = ks[0]
            ilen = ilens[0]
            total_frames = xs.size(2)

            offset = 0
            while offset + self.width <= ilen:
                start = max(offset - self.lcontext, 0)
                stop = offset + self.width + self.rcontext
                # _x: (C, <= window, F); shorter than `window` only when the
                # context runs past either edge of xs.
                _x = xs[0, :, start:stop, :]
                if _x.shape[1] < window:
                    # Missing frames before frame 0 / after the last frame.
                    lp = max(0, self.lcontext - offset)
                    rp = max(stop - total_frames, 0)
                    # The pad spec runs from the last dim backwards:
                    # (F_before, F_after, T_before, T_after, C_before, C_after).
                    # Left-context padding must precede the window so that the
                    # target frames stay at index `lcontext` of _x.
                    _x = FC.pad(_x, (0, 0, lp, rp, 0, 0), mode='constant')

                # _t: (C, width, F)
                _t = ts[0, :, offset:offset + self.width, :]

                keys.append(key)
                inputs.append(_x)
                targets.append(_t)
                lengths.append(window)
                offset += self.width

                if len(keys) == self.batch_size:
                    # xs: (B, C, window, F), ts: (B, C, width, F)
                    yield tuple(keys), FC.stack(inputs), \
                        FC.stack(targets), \
                        torch.tensor(lengths, device=inputs[0].device)
                    keys, inputs, targets, lengths = [], [], [], []
Ejemplo n.º 2
0
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    """Stack a sequence of tensors, dispatching on the element type.

    If the first element of ``seq`` is a ``ComplexTensor``, the sequence is
    stacked with ``FC.stack``; otherwise it is passed to ``torch.stack``.
    Extra positional and keyword arguments (e.g. ``dim``) are forwarded
    unchanged to whichever function is used.

    Args:
        seq: list or tuple of ``ComplexTensor`` / ``torch.Tensor``.

    Returns:
        The stacked tensor produced by ``FC.stack`` or ``torch.stack``.

    Raises:
        TypeError: if ``seq`` is not a list or tuple.
    """
    if not isinstance(seq, (list, tuple)):
        # Report the actual type of `seq` instead of always saying "Tensor".
        raise TypeError(
            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
            f"not {type(seq).__name__}")
    # An empty sequence is handed to torch.stack, which raises its own
    # descriptive error, instead of failing with an IndexError on seq[0].
    if seq and isinstance(seq[0], ComplexTensor):
        return FC.stack(seq, *args, **kwargs)
    return torch.stack(seq, *args, **kwargs)
Ejemplo n.º 3
0
def perform_filter_operation_v2(Y: ComplexTensor,
                                filter_matrix_conj: ComplexTensor,
                                taps, delay) -> ComplexTensor:
    """perform_filter_operation_v2

    Subtract a filtered, delayed version of ``Y`` (the reverberation tail)
    from ``Y``.

    For each tap ``i`` in ``range(taps)``, ``Y`` is shifted along the time
    axis by ``delay + i`` frames, with zeros filled in at the start. The
    shifted copies are combined with ``filter_matrix_conj`` and the result
    is subtracted from ``Y``. A shift of ``T`` frames or more contributes
    only zeros.

    Args:
        Y : Complex-valued STFT signal of shape (F, C, T)
        filter_matrix_conj : Filter matrix of shape (F, taps, C, C)
        taps (int): Number of filter taps (delayed copies of ``Y``).
        delay (int): Shift in frames applied to the first tap.

    Returns:
        ComplexTensor: ``Y`` minus the filtered tail, shape (F, C, T).
    """
    T = Y.size(-1)
    shifted = []
    for i in range(taps):
        # Clamp the shift to T: a negative slice stop would count from the
        # end of the time axis and produce more than T frames after padding.
        shift = min(delay + i, T)
        shifted.append(FC.pad(Y[:, :, :T - shift], (shift, 0),
                              mode='constant', value=0))
    # Y_tilde: (taps, F, C, T)
    Y_tilde = FC.stack(shifted, dim=0)
    # Sum over taps (p) and input channels (d) -> (F, C, T)
    reverb_tail = FC.einsum('fpde,pfdt->fet', (filter_matrix_conj, Y_tilde))
    return Y - reverb_tail