Example #1
    def _chunking(
            self,
            hs: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int], Tuple[int]]:
        """Chunking sequence into segments
        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
        Returns:
            Tensor: Batch of hidden state segments (B * segment_len, Chunk_size, adim)
            Tuple[Int]: Pad infomration (0, 0, pad_left, pad_right)
            Tuple[Int]: Segmentd shape information (B, segment_len, Chunk_size, adim)
        """
        batch_size, max_length, feat_dim = hs.size()

        # padding interpretation
        # [overlap_size] valid_value [overlap_size]
        # ---------------- segments ---------------
        valid_value = self.chunk_size - 2 * self.overlap_size
        if max_length % valid_value == 0:
            pad_right = self.overlap_size
        else:
            pad_right = valid_value - max_length % valid_value + self.overlap_size
        pad = (0, 0, self.overlap_size, pad_right)
        segment_len = ceil(max_length / valid_value)
        hs = torch.nn.functional.pad(hs, pad, "constant", 0)
        padded_length = hs.size(1)  # max_length + overlap_size + pad_right

        # The batch stride must use the padded length; with the original
        # max_length, segments of different batch items would alias.
        segmented_hs = hs.as_strided(
            (batch_size, segment_len, self.chunk_size, feat_dim),
            (padded_length * feat_dim, valid_value * feat_dim, feat_dim, 1),
        )
        segmented_hs = segmented_hs.reshape(batch_size * segment_len,
                                            self.chunk_size, feat_dim)
        return segmented_hs, pad, (batch_size, segment_len, self.chunk_size,
                                   feat_dim)
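A minimal usage sketch (not part of the original module) that replays the same strided chunking on a toy tensor; the chunk_size and overlap_size values here are illustrative:

import torch
from math import ceil

chunk_size, overlap_size = 8, 2
hs = torch.arange(2 * 10 * 3, dtype=torch.float32).view(2, 10, 3)  # (B, Tmax, adim)

valid_value = chunk_size - 2 * overlap_size  # new frames contributed per chunk
pad_right = (valid_value - hs.size(1) % valid_value) % valid_value + overlap_size
hs_p = torch.nn.functional.pad(hs, (0, 0, overlap_size, pad_right))
segment_len = ceil(hs.size(1) / valid_value)
segments = hs_p.as_strided(
    (2, segment_len, chunk_size, 3),
    (hs_p.size(1) * 3, valid_value * 3, 3, 1),
)
print(segments.shape)  # torch.Size([2, 3, 8, 3])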
Example #2
    def forward(
        self,
        input: torch.Tensor,
        ilens: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Label_dim)
            olens: (Batch)

        """
        bs = input.size(0)
        max_length = input.size(1)
        label_dim = input.size(2)

        # NOTE(jiatong):
        #   The default behaviour of label aggregation matches torch.stft
        #   with respect to framing and padding.

        # Step1: center padding
        if self.center:
            pad = self.win_length // 2
            max_length = max_length + 2 * pad
            input = torch.nn.functional.pad(input, (0, 0, pad, pad),
                                            "constant", 0)
            # Fill the padded regions with copies of the edge frames.
            input[:, :pad, :] = input[:, pad:(2 * pad), :]
            input[:, (max_length - pad):max_length, :] = (
                input[:, (max_length - 2 * pad):(max_length - pad), :])

        # Compute the frame count outside the branch so that center=False
        # does not leave nframe undefined.
        nframe = (max_length - self.win_length) // self.hop_length + 1

        # Step2: framing
        output = input.as_strided(
            (bs, nframe, self.win_length, label_dim),
            (max_length * label_dim, self.hop_length * label_dim, label_dim,
             1),
        )

        # Step3: aggregate label
        output = torch.gt(output.sum(dim=2, keepdim=False),
                          self.win_length // 2)
        output = output.float()

        # Step4: process lengths
        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad

            olens = (ilens - self.win_length) // self.hop_length + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens
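A quick self-contained check (toy sizes, assumed here) that the framing step above is equivalent to Tensor.unfold on the time axis:

import torch

x = torch.randn(2, 12, 5)  # (Batch, Nsamples, Label_dim)
win, hop = 4, 2
nframe = (12 - win) // hop + 1
frames = x.as_strided((2, nframe, win, 5), (12 * 5, hop * 5, 5, 1))
assert torch.equal(frames, x.unfold(1, win, hop).transpose(2, 3))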
Example #3
def extract_C(from_tns: torch.Tensor, size: Union[Tuple[int, int], Tuple[int]],
              offset: int) -> torch.Tensor:
    # Strides assume a C-contiguous (row-major) underlying storage.
    if len(size) == 1:
        stride = (1, )
    elif len(size) == 2:
        stride = (size[1], 1)
    else:
        raise ValueError("extract_C can only extract 1- or 2-D tensors.")
    return from_tns.as_strided(size=size, stride=stride, storage_offset=offset)
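A small usage sketch, assuming a flat contiguous buffer:

import torch

buf = torch.arange(12)
view = extract_C(buf, size=(2, 3), offset=4)
print(view)
# tensor([[4, 5, 6],
#         [7, 8, 9]])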
Example #4
 def forward(self, in_data: torch.Tensor) -> torch.Tensor:
     assert in_data.is_cuda
     b, d = in_data.shape
     assert d % 2 == 0
     s = in_data.stride()
     # View the feature dim as (d // 2) pairs: element j is paired with
     # element j + d // 2 (max-feature-map style), without copying.
     new_stride = (s[0], s[1], s[1] * d // 2)
     data = in_data.as_strided(size=(b, d // 2, 2), stride=new_stride)
     out = data.max(dim=2)[0]
     # Remember which element of each pair won, for the backward pass.
     self.mask = data == out.reshape((b, d // 2, 1))
     return out
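A CPU toy check (the strided view is inlined here, since the method above asserts CUDA) that the pairing equals an element-wise max over the two halves of the feature dimension:

import torch

x = torch.randn(4, 6)
b, d = x.shape
pairs = x.as_strided((b, d // 2, 2), (x.stride(0), x.stride(1), x.stride(1) * d // 2))
assert torch.equal(pairs.max(dim=2)[0], torch.max(x[:, :d // 2], x[:, d // 2:]))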
Example #5
 def backward(self, eta: torch.Tensor) -> torch.Tensor:
     assert eta.is_cuda
     b, d2 = eta.shape
     d = d2 * 2
     s = eta.stride()
     # Stride 0 on the last dim broadcasts each gradient to both pair members.
     new_stride = (s[0], s[1], 0)
     eta = eta.as_strided(size=(b, d // 2, 2), stride=new_stride)
     # Route the gradient only to the winning element of each pair.
     next_eta = eta * self.mask
     next_eta = torch.cat((next_eta[:, :, 0], next_eta[:, :, 1]), dim=1)
     assert next_eta.shape == (b, d)
     return next_eta
Example #6
 def forward(self, in_data: torch.Tensor) -> torch.Tensor:
     assert in_data.is_cuda
     b, h, w, c = in_data.shape
     assert c % 2 == 0
     s = in_data.stride()
     new_stride = (s[0], s[1], s[2], s[3], s[3] * c // 2)
     data = in_data.as_strided(size=(b, h, w, c // 2, 2), stride=new_stride)
     out = data.max(dim=4)[0]
     # sanity check
     # assert (out == torch.max(in_data[:, :, :, :c // 2], in_data[:, :, :, c // 2:])).all()
     self.mask = data == out.reshape((b, h, w, c // 2, 1))
     return out
Example #7
def im2col_enhanced(im: torch.Tensor, kernel_size, stride,
                    inner_stride=(1, 1)) -> torch.Tensor:
    # Zero-copy im2col for NHWC tensors: each output position (y, x) exposes
    # a (kh, kw, c) patch as a strided view; inner_stride spaces the kernel
    # taps (dilation-style).
    kh, kw = kernel_size
    sh, sw = stride
    ish, isw = inner_stride
    b, h, w, c = im.shape
    assert (h - kh * ish) % sh == 0
    assert (w - kw * isw) % sw == 0
    out_h = (h - kh * ish) // sh + 1
    out_w = (w - kw * isw) // sw + 1
    out_size = (b, out_h, out_w, kh, kw, c)
    s = im.stride()
    out_stride = (s[0], s[1] * sh, s[2] * sw, s[1] * ish, s[2] * isw, s[3])
    col_img = im.as_strided(size=out_size, stride=out_stride)
    return col_img
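A hedged sanity check against torch.nn.functional.unfold (which expects NCHW, hence the permutes); sizes are illustrative:

import torch

im = torch.randn(2, 6, 6, 3)  # NHWC
col = im2col_enhanced(im, kernel_size=(3, 3), stride=(1, 1))  # (2, 4, 4, 3, 3, 3)
ref = torch.nn.functional.unfold(im.permute(0, 3, 1, 2), kernel_size=3)
col2 = col.permute(0, 5, 3, 4, 1, 2).reshape(2, 3 * 3 * 3, -1)
assert torch.equal(col2, ref)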
Example #8
def _get_strided(waveform: Tensor, window_size: int, window_shift: int,
                 snip_edges: bool) -> Tensor:
    r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
    representing how the window is shifted along the waveform. Each row is a frame.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length.  If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
    """
    assert waveform.dim() == 1

    num_samples = waveform.size(0)
    strides = (window_shift * waveform.stride(0), waveform.stride(0))

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0))

        else:
            m = 1 + (num_samples - window_size) // window_shift

    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift

        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform

        if pad > 0:
            # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
            # but we want [2, 1, 0, 0, 1, 2]
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)

        else:
            # pad is negative so we want to trim the waveform at the front
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)
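A minimal usage sketch for the snip_edges=True branch, checked against Tensor.unfold:

import torch

wav = torch.arange(10.0)
frames = _get_strided(wav, window_size=4, window_shift=2, snip_edges=True)
assert torch.equal(frames, wav.unfold(0, 4, 2))  # shape (4, 4)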
Example #9
def _get_strided_batch(waveform: torch.Tensor, window_length: int,
                       window_shift: int, snip_edges: bool) -> torch.Tensor:
    r"""Given a waveform (2D tensor of size ``(batch_size, num_samples)``,
    it returns a 2D tensor ``(batch_size, num_frames, window_length)``
    representing how the window is shifted along the waveform. Each row is a frame.
    Args:
        waveform (torch.Tensor): Tensor of size ``(batch_size, num_samples)``
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length.  If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.
    Returns:
        torch.Tensor: 3D tensor of size (m, ``window_size``) where each row is a frame
    """
    assert waveform.dim() == 2
    batch_size = waveform.size(0)
    num_samples = waveform.size(-1)

    if snip_edges:
        if num_samples < window_length:
            return torch.empty((0, 0, 0))
        else:
            num_frames = 1 + (num_samples - window_length) // window_shift
    else:
        num_frames = (num_samples + (window_shift // 2)) // window_shift
        new_num_samples = (num_frames - 1) * window_shift + window_length
        npad = new_num_samples - num_samples
        npad_left = int((window_length - window_shift) // 2)
        npad_right = npad - npad_left
        # waveform = nn.functional.pad(waveform, (npad_left, npad_right), mode='reflect')
        pad_left = torch.flip(waveform[:, :npad_left], (1, ))
        if npad_right > 0:
            pad_right = torch.flip(waveform[:, -npad_right:], (1, ))
        else:
            # npad_right == 0 must not take the flip branch:
            # waveform[:, -0:] would select the whole waveform.
            pad_right = torch.zeros(0, dtype=waveform.dtype, device=waveform.device)
        waveform = torch.cat((pad_left, waveform, pad_right), dim=1)

    strides = (
        waveform.stride(0),
        window_shift * waveform.stride(1),
        waveform.stride(1),
    )
    sizes = [batch_size, num_frames, window_length]
    return waveform.as_strided(sizes, strides)
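The batched analogue, again exercising the simple snip_edges=True path:

import torch

wav = torch.randn(3, 100)
frames = _get_strided_batch(wav, window_length=25, window_shift=10, snip_edges=True)
assert torch.equal(frames, wav.unfold(1, 25, 10))  # shape (3, 8, 25)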
Example #10
    def generate_mask(
        self,
        silence: Tensor,
    ):
        """
        :param silence: bool (batch_size, length)
        :return:
            output: bool (batch_size, ?)
        """
        window_length = 1 + numpy.sum(2**numpy.arange(1, self.layer_num + 1))

        # Sliding window over the time axis; the batch stride must be the
        # tensor's own stride(0), not 1, or batch items would alias.
        silence = silence.as_strided(
            size=(silence.shape[0], silence.shape[1] - (window_length - 1),
                  window_length),
            stride=(silence.stride(0), silence.stride(1), silence.stride(1)),
        )
        return ~(silence.all(dim=2))
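A standalone sketch of the sliding all-silent window over a (batch, time) bool tensor; the window size 5 is illustrative:

import torch

silence = torch.zeros(2, 12, dtype=torch.bool)
silence[0, 3:9] = True
w = 5
windows = silence.as_strided(
    (2, 12 - (w - 1), w),
    (silence.stride(0), silence.stride(1), silence.stride(1)),
)
mask = ~windows.all(dim=2)
print(mask[0])  # False exactly where the whole 5-sample window is silent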
Example #11
    def rel_shift(self, x: Tensor) -> Tensor:
        """Compute relative positional encoding.

        Args:
            x: Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            Tensor: tensor of shape (batch, head, time1, time2)
          (note: time2 has the same value as time1, but it is for
          the key, while time1 is for the query).
        """
        (batch_size, num_heads, time1, n) = x.shape
        assert n == 2 * time1 - 1
        (batch_stride, head_stride, time1_stride, n_stride) = x.stride()
        return x.as_strided(
            (batch_size, num_heads, time1, time1),
            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
            storage_offset=n_stride * (time1 - 1))
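A brute-force check (toy sizes) that the strided view selects x[..., i, (time1 - 1) - i + j], i.e. the usual relative-position shift:

import torch

B, H, T = 2, 3, 4
x = torch.randn(B, H, T, 2 * T - 1)
bs, hs, ts, ns = x.stride()
shifted = x.as_strided((B, H, T, T), (bs, hs, ts - ns, ns),
                       storage_offset=ns * (T - 1))
ref = torch.stack([x[:, :, i, T - 1 - i:2 * T - 1 - i] for i in range(T)],
                  dim=2)
assert torch.equal(shifted, ref)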
Example #12
    def rel_shift(self,
                  x: torch.Tensor,
                  left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.

        Returns:
            x: Output sequence. (B, H, T_1, T_2)

        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context

        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )
Example #13
    def label_aggregate(
        self,
        input: torch.Tensor,
        input_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Label aggregate function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            input_lengths: (Batch)
        Returns:
            output: (Batch, Frames)
            olens: (Batch)
        """
        bs = input.size(0)
        max_length = input.size(1)
        label_dim = input.size(2)

        # NOTE(jiatong):
        #   The default behaviour of label aggregation matches torch.stft
        #   with respect to framing and padding.

        # Step1: center padding
        if self.center:
            pad = self.win_length // 2
            max_length = max_length + 2 * pad
            input = torch.nn.functional.pad(input, (0, 0, pad, pad),
                                            "constant", 0)
            # Fill the padded regions with copies of the edge frames.
            input[:, :pad, :] = input[:, pad:(2 * pad), :]
            input[:, (max_length - pad):max_length, :] = (
                input[:, (max_length - 2 * pad):(max_length - pad), :])

        # Compute the frame count outside the branch so that center=False
        # does not leave nframe undefined.
        nframe = (max_length - self.win_length) // self.hop_length + 1

        # Step2: framing
        output = input.as_strided(
            (bs, nframe, self.win_length, label_dim),
            (max_length * label_dim, self.hop_length * label_dim, label_dim,
             1),
        )

        # Step3: aggregate label
        # (bs, nframe, win_length, label_dim) => (bs, nframe, win_length)
        # => (bs, nframe), keeping the center frame of each window
        _tmp = output.sum(dim=-1, keepdim=False).float()
        output = _tmp[:, :, self.win_length // 2]

        # Step4: process lengths
        if input_lengths is not None:
            if self.center:
                pad = self.win_length // 2
                input_lengths = input_lengths + 2 * pad

            # olens = (input_lengths - self.win_length) // self.hop_length + 1
            olens = torch.div(input_lengths - self.win_length,
                              self.hop_length,
                              rounding_mode="trunc") + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens
Example #14
def _get_strided_batch_streaming(
    waveform: torch.Tensor,
    window_shift: int,
    window_length: int,
    prev_remainder: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A variant of _get_strided_batch that creates short frames of a batch of audio signals
    in a way suitable for streaming. It accepts a waveform, window size parameters, and
    an optional buffer of previously unused samples. It returns a pair of waveform windows tensor,
    and unused part of the waveform to be passed as ``prev_remainder`` in the next call to this
    function.

    Example usage::

        >>> # get the first buffer of audio and make frames
        >>> waveform = get_incoming_audio_from_mic()
        >>> frames, remainder = _get_strided_batch_streaming(
        ...     waveform,
        ...     window_shift=160,
        ...     window_length=200,
        ... )
        >>>
        >>> process(frames)  # do something with the frames
        >>>
        >>> # get the next buffer and use previous remainder to make frames
        >>> waveform = get_incoming_audio_from_mic()
        >>> frames, remainder = _get_strided_batch_streaming(
        ...     waveform,
        ...     window_shift=160,
        ...     window_length=200,
        ...     prev_remainder=remainder,
        ... )

    .. caution:: This windowing mechanism only supports ``snip_edges=False``.

    :param waveform: A waveform tensor of shape ``(batch_size, num_samples)``.
    :param window_shift: The shift between frames measured in the number of samples.
    :param window_length: The number of samples in each window (frame).
    :param prev_remainder: An optional waveform tensor of shape ``(batch_size, num_samples)``.
        Can be ``None`` which indicates the start of a recording.
    :return: a pair of tensors with shapes ``(batch_size, num_frames, window_length)`` and
        ``(batch_size, remainder_len)``.
    """

    assert window_shift <= window_length
    assert waveform.dim() == 2
    batch_size = waveform.size(0)

    if prev_remainder is None:
        npad_left = int((window_length - window_shift) // 2)
        pad_left = torch.flip(waveform[:, :npad_left], (1, ))
        waveform = torch.cat((pad_left, waveform), dim=1)
    else:
        assert prev_remainder.dim() == 2
        assert prev_remainder.size(0) == batch_size
        waveform = torch.cat((prev_remainder, waveform), dim=1)

    num_samples = waveform.size(-1)

    window_remainder = window_length - window_shift
    num_frames = (num_samples - window_remainder) // window_shift

    remainder = waveform[:, num_frames * window_shift:]

    strides = (
        waveform.stride(0),
        window_shift * waveform.stride(1),
        waveform.stride(1),
    )

    sizes = [batch_size, num_frames, window_length]

    return waveform.as_strided(sizes, strides), remainder
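A hedged end-to-end sketch: streaming over a split signal yields the same frames as a single pass over the whole signal (toy sizes; the 480-sample split point is arbitrary):

import torch

wav = torch.randn(1, 1000)
f_all, _ = _get_strided_batch_streaming(wav, window_shift=160, window_length=200)
f1, rem = _get_strided_batch_streaming(wav[:, :480], window_shift=160,
                                       window_length=200)
f2, _ = _get_strided_batch_streaming(wav[:, 480:], window_shift=160,
                                     window_length=200, prev_remainder=rem)
assert torch.equal(torch.cat([f1, f2], dim=1), f_all)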