from math import ceil
from types import SimpleNamespace
from typing import Optional, Tuple, Union

import numpy
import torch
from torch import Tensor

# Helper used by the label-aggregation snippets below (import path as in ESPnet).
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask


def _chunking(self, hs: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int], Tuple[int]]:
    """Chunk a sequence into fixed-size, overlapping segments.

    Args:
        hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).

    Returns:
        Tensor: Batch of hidden state segments (B * segment_len, Chunk_size, adim).
        Tuple[int]: Pad information (0, 0, pad_left, pad_right).
        Tuple[int]: Segmented shape information (B, segment_len, Chunk_size, adim).
    """
    batch_size, max_length, feat_dim = hs.size()

    # padding interpretation
    # [overlap_size] valid_value [overlap_size]
    # ---------------- segments ---------------
    valid_value = self.chunk_size - 2 * self.overlap_size
    if max_length % valid_value == 0:
        pad_right = self.overlap_size
    else:
        pad_right = valid_value - max_length % valid_value + self.overlap_size
    pad = (0, 0, self.overlap_size, pad_right)
    segment_len = ceil(max_length / valid_value)
    hs = torch.nn.functional.pad(hs, pad, "constant", 0)

    # Derive the batch stride from the *padded* tensor: after padding the time
    # dimension has grown, so using the original max_length here would make
    # as_strided read across batch boundaries whenever batch_size > 1.
    segmented_hs = hs.as_strided(
        (batch_size, segment_len, self.chunk_size, feat_dim),
        (hs.size(1) * feat_dim, valid_value * feat_dim, feat_dim, 1),
    )
    segmented_hs = segmented_hs.reshape(batch_size * segment_len, self.chunk_size, feat_dim)
    return segmented_hs, pad, (batch_size, segment_len, self.chunk_size, feat_dim)
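# Hedged usage sketch (not from the original source): exercises _chunking with a
# hypothetical stand-in config object. chunk_size=8 / overlap_size=2 give
# valid_value=4, so a length-10 input needs pad_right=4 and yields
# ceil(10 / 4) = 3 segments per batch item.
def _demo_chunking() -> None:
    cfg = SimpleNamespace(chunk_size=8, overlap_size=2)
    hs = torch.arange(2 * 10 * 3, dtype=torch.float32).reshape(2, 10, 3)
    segments, pad, shape = _chunking(cfg, hs)
    assert segments.shape == (2 * 3, 8, 3)
    assert pad == (0, 0, 2, 4)
    assert shape == (2, 3, 8, 3)
    # The first chunk of each batch item starts with overlap_size padded zeros.
    assert (segments[0, :2] == 0).all() and (segments[3, :2] == 0).all()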
def forward(
    self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """LabelAggregate forward function.

    Args:
        input: (Batch, Nsamples, Label_dim)
        ilens: (Batch)
    Returns:
        output: (Batch, Frames, Label_dim)
    """
    bs = input.size(0)
    max_length = input.size(1)
    label_dim = input.size(2)

    # NOTE(jiatong):
    # The default behaviour of label aggregation is compatible with
    # torch.stft about framing and padding.

    # Step1: center padding
    if self.center:
        pad = self.win_length // 2
        max_length = max_length + 2 * pad
        input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
        input[:, :pad, :] = input[:, pad : (2 * pad), :]
        input[:, (max_length - pad) : max_length, :] = input[
            :, (max_length - 2 * pad) : (max_length - pad), :
        ]
    nframe = (max_length - self.win_length) // self.hop_length + 1

    # Step2: framing
    output = input.as_strided(
        (bs, nframe, self.win_length, label_dim),
        (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
    )

    # Step3: aggregate label
    output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
    output = output.float()

    # Step4: process lengths
    if ilens is not None:
        if self.center:
            pad = self.win_length // 2
            ilens = ilens + 2 * pad
        olens = (ilens - self.win_length) // self.hop_length + 1
        output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
    else:
        olens = None

    return output, olens
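_label_aggregate_fwd = forward  # alias: later snippets also define `forward`


# Hedged usage sketch (not from the original source): runs the aggregation on
# random binary labels with a hypothetical stand-in config (center=True,
# win_length=8, hop_length=4) and checks the frame count against torch.stft's
# centered-framing formula, 1 + Nsamples // hop_length.
def _demo_label_aggregate_forward() -> None:
    cfg = SimpleNamespace(center=True, win_length=8, hop_length=4)
    labels = (torch.rand(2, 32, 3) > 0.5).float()
    ilens = torch.tensor([32, 20])
    output, olens = _label_aggregate_fwd(cfg, labels, ilens)
    assert output.shape == (2, 1 + 32 // 4, 3)
    assert olens.tolist() == [1 + 32 // 4, 1 + 20 // 4]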
def extract_C(
    from_tns: torch.Tensor,
    size: Union[Tuple[int, int], Tuple[int]],
    offset: int,
) -> torch.Tensor:
    # View a region of the underlying storage as a row-major (C-order) 1-D or
    # 2-D tensor starting at the given storage offset.
    if len(size) == 1:
        stride = (1,)
    elif len(size) == 2:
        stride = (size[1], 1)
    else:
        raise ValueError("extract_C can only extract 1 or 2D tensors.")
    return from_tns.as_strided(size=size, stride=stride, storage_offset=offset)
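# Hedged usage sketch (not from the original source): pulls a 2x3 sub-matrix
# out of a flat parameter buffer starting at storage offset 4.
def _demo_extract_C() -> None:
    buf = torch.arange(12.0)
    block = extract_C(buf, (2, 3), offset=4)
    assert torch.equal(block, torch.tensor([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]))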
def forward(self, in_data: torch.Tensor) -> torch.Tensor:
    assert in_data.is_cuda
    b, d = in_data.shape
    assert d % 2 == 0
    # Pair feature j with feature j + d // 2 via a strided view, then take the
    # elementwise max of each pair (max-feature-map / two-piece maxout).
    s = in_data.stride()
    new_stride = (s[0], s[1], s[1] * d // 2)
    data = in_data.as_strided(size=(b, d // 2, 2), stride=new_stride)
    out = data.max(dim=2)[0]
    # Remember which element of each pair won, for the backward pass.
    self.mask = data == out.reshape((b, d // 2, 1))
    return out
def backward(self, eta: torch.Tensor) -> torch.Tensor:
    assert eta.is_cuda
    b, d2 = eta.shape
    d = d2 * 2
    # Broadcast the incoming gradient over both pair members (stride 0 on the
    # last axis), then route it through the stored argmax mask.
    s = eta.stride()
    new_stride = (s[0], s[1], 0)
    eta = eta.as_strided(size=(b, d // 2, 2), stride=new_stride)
    next_eta = eta * self.mask
    next_eta = torch.cat((next_eta[:, :, 0], next_eta[:, :, 1]), dim=1)
    assert next_eta.shape == (b, d)
    return next_eta
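_mfm1d_fwd, _mfm1d_bwd = forward, backward  # bind before later same-named defs shadow them


# Hedged usage sketch (not from the original source): checks the strided
# pairing against the direct torch.maximum of the two feature halves, using a
# SimpleNamespace as a stand-in for the layer object. Skipped when no GPU is
# available, since the snippets assert CUDA inputs.
def _demo_maxout_1d() -> None:
    if not torch.cuda.is_available():
        return
    state = SimpleNamespace()
    x = torch.randn(4, 6, device="cuda")
    out = _mfm1d_fwd(state, x)
    assert torch.equal(out, torch.maximum(x[:, :3], x[:, 3:]))
    grad = _mfm1d_bwd(state, torch.ones_like(out))
    assert grad.shape == x.shape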
def forward(self, in_data: torch.Tensor) -> torch.Tensor:
    assert in_data.is_cuda
    b, h, w, c = in_data.shape
    assert c % 2 == 0
    # Same max-feature-map trick as above, applied channel-wise on NHWC data:
    # channel j is paired with channel j + c // 2.
    s = in_data.stride()
    new_stride = (s[0], s[1], s[2], s[3], s[3] * c // 2)
    data = in_data.as_strided(size=(b, h, w, c // 2, 2), stride=new_stride)
    out = data.max(dim=4)[0]
    # Sanity check:
    # assert (out == torch.max(in_data[:, :, :, :c // 2], in_data[:, :, :, c // 2:])).all()
    self.mask = data == out.reshape((b, h, w, c // 2, 1))
    return out
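_mfm2d_fwd = forward  # bind before later same-named defs shadow it


# Hedged usage sketch (not from the original source): the 4-D variant reduces
# to torch.maximum over the two channel halves, mirroring the commented-out
# sanity check above. Skipped when no GPU is available.
def _demo_maxout_2d() -> None:
    if not torch.cuda.is_available():
        return
    state = SimpleNamespace()
    x = torch.randn(2, 5, 5, 8, device="cuda")
    out = _mfm2d_fwd(state, x)
    assert torch.equal(out, torch.maximum(x[..., :4], x[..., 4:]))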
def im2col_enhanced(im: torch.Tensor, kernel_size, stride, inner_stride=(1, 1)) -> torch.Tensor:
    # im2col as a zero-copy strided view over an NHWC image batch.
    # inner_stride spaces the samples *within* each kernel window (a
    # dilation-like sampling), while stride moves the window itself.
    kh, kw = kernel_size
    sh, sw = stride
    ish, isw = inner_stride
    b, h, w, c = im.shape
    assert (h - kh * ish) % sh == 0
    assert (w - kw * isw) % sw == 0
    out_h = (h - kh * ish) // sh + 1
    out_w = (w - kw * isw) // sw + 1
    out_size = (b, out_h, out_w, kh, kw, c)
    s = im.stride()
    out_stride = (s[0], s[1] * sh, s[2] * sw, s[1] * ish, s[2] * isw, s[3])
    col_img = im.as_strided(size=out_size, stride=out_stride)
    return col_img
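# Hedged usage sketch (not from the original source): im2col on a 1x4x4x1
# image with 2x2 kernels and stride 2 yields four non-overlapping patches.
def _demo_im2col() -> None:
    im = torch.arange(16.0).reshape(1, 4, 4, 1)
    col = im2col_enhanced(im, kernel_size=(2, 2), stride=(2, 2))
    assert col.shape == (1, 2, 2, 2, 2, 1)
    # Second patch row, first patch column: rows 2-3, cols 0-1 of the image.
    assert torch.equal(col[0, 1, 0, :, :, 0], torch.tensor([[8.0, 9.0], [12.0, 13.0]]))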
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
    representing how the window is shifted along the waveform. Each row is a frame.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length.  If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)
    strides = (window_shift * waveform.stride(0), waveform.stride(0))

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        else:
            m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
            # but we want [2, 1, 0, 0, 1, 2]
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
        else:
            # pad is negative so we want to trim the waveform at the front
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)
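# Hedged usage sketch (not from the original source): with snip_edges=True,
# framing a 10-sample waveform with window_size=4 and shift=2 gives
# 1 + (10 - 4) // 2 = 4 frames, each a strided view into the waveform.
def _demo_get_strided() -> None:
    wave = torch.arange(10.0)
    frames = _get_strided(wave, window_size=4, window_shift=2, snip_edges=True)
    assert frames.shape == (4, 4)
    assert torch.equal(frames[1], wave[2:6])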
def _get_strided_batch(
    waveform: torch.Tensor, window_length: int, window_shift: int, snip_edges: bool
) -> torch.Tensor:
    r"""Given a waveform (2D tensor of size ``(batch_size, num_samples)``), it returns a 3D tensor
    ``(batch_size, num_frames, window_length)`` representing how the window is shifted along the
    waveform. Each row is a frame.

    Args:
        waveform (torch.Tensor): Tensor of size ``(batch_size, num_samples)``
        window_length (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length.  If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Returns:
        torch.Tensor: 3D tensor of size ``(batch_size, num_frames, window_length)`` where each row is a frame
    """
    assert waveform.dim() == 2
    batch_size = waveform.size(0)
    num_samples = waveform.size(-1)

    if snip_edges:
        if num_samples < window_length:
            return torch.empty((0, 0, 0), dtype=waveform.dtype, device=waveform.device)
        else:
            num_frames = 1 + (num_samples - window_length) // window_shift
    else:
        num_frames = (num_samples + (window_shift // 2)) // window_shift
        new_num_samples = (num_frames - 1) * window_shift + window_length
        npad = new_num_samples - num_samples
        npad_left = int((window_length - window_shift) // 2)
        npad_right = npad - npad_left
        # waveform = nn.functional.pad(waveform, (npad_left, npad_right), mode='reflect')
        pad_left = torch.flip(waveform[:, :npad_left], (1,))
        if npad_right > 0:
            pad_right = torch.flip(waveform[:, -npad_right:], (1,))
        else:
            # npad_right <= 0 means no right padding is needed; use an empty
            # (batch_size, 0) chunk so torch.cat keeps working (a 1-D
            # torch.zeros(0) would fail to concatenate along dim=1).
            pad_right = waveform.new_empty((batch_size, 0))
        waveform = torch.cat((pad_left, waveform, pad_right), dim=1)

    strides = (
        waveform.stride(0),
        window_shift * waveform.stride(1),
        waveform.stride(1),
    )
    sizes = [batch_size, num_frames, window_length]
    return waveform.as_strided(sizes, strides)
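# Hedged usage sketch (not from the original source): with snip_edges=False
# the frame count per item is (num_samples + window_shift // 2) // window_shift.
def _demo_get_strided_batch() -> None:
    waves = torch.randn(3, 100)
    frames = _get_strided_batch(waves, window_length=25, window_shift=10, snip_edges=False)
    assert frames.shape == (3, (100 + 5) // 10, 25)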
def generate_mask(
    self,
    silence: Tensor,
):
    """
    :param silence: bool (batch_size, length)
    :return: output: bool (batch_size, length - window_length + 1)
    """
    # Receptive field of the dilated stack: 1 + sum of dilations 2^1 .. 2^layer_num.
    window_length = int(1 + numpy.sum(2 ** numpy.arange(1, self.layer_num + 1)))
    silence = silence.unsqueeze(2)
    silence = silence.as_strided(
        size=(
            silence.shape[0],
            silence.shape[1] - (window_length - 1),
            window_length,
        ),
        # Derive strides from the tensor itself; the hard-coded (1, 1, 1) of
        # the original only indexes correctly when batch_size == 1.
        stride=(silence.stride(0), silence.stride(1), silence.stride(1)),
    )
    return ~(silence.all(dim=2))
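# Hedged usage sketch (not from the original source): with layer_num=1 the
# window length is 1 + 2 = 3, so a frame is masked out only when the whole
# 3-sample window is silent.
def _demo_generate_mask() -> None:
    cfg = SimpleNamespace(layer_num=1)
    silence = torch.tensor([[True, True, True, False, True]])
    mask = generate_mask(cfg, silence)
    assert mask.tolist() == [[False, True, True]]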
def rel_shift(self, x: Tensor) -> Tensor:
    """Compute relative positional encoding.

    Args:
        x: Input tensor (batch, head, time1, 2*time1-1).
            time1 means the length of query vector.

    Returns:
        Tensor: tensor of shape (batch, head, time1, time2)
            (note: time2 has the same value as time1, but it is for
            the key, while time1 is for the query).
    """
    (batch_size, num_heads, time1, n) = x.shape
    assert n == 2 * time1 - 1
    (batch_stride, head_stride, time1_stride, n_stride) = x.stride()
    return x.as_strided(
        (batch_size, num_heads, time1, time1),
        (batch_stride, head_stride, time1_stride - n_stride, n_stride),
        storage_offset=n_stride * (time1 - 1),
    )
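_rel_shift_v1 = rel_shift  # bind before the left-context variant below shadows it


# Hedged usage sketch (not from the original source): the strided view
# implements out[b, h, i, j] == x[b, h, i, j - i + time1 - 1], i.e. relative
# position j - i re-indexed from the 2*time1-1 axis, without copying.
def _demo_rel_shift() -> None:
    b, h, t1 = 1, 1, 4
    x = torch.arange(b * h * t1 * (2 * t1 - 1), dtype=torch.float32).reshape(
        b, h, t1, 2 * t1 - 1
    )
    out = _rel_shift_v1(None, x)  # `self` is unused by the snippet
    for i in range(t1):
        for j in range(t1):
            assert out[0, 0, i, j] == x[0, 0, i, j - i + t1 - 1]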
def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
    """Compute relative positional encoding.

    Args:
        x: Input sequence. (B, H, T_1, 2 * T_1 - 1 + left_context)
        left_context: Number of frames in left context.

    Returns:
        x: Output sequence. (B, H, T_1, T_2)
    """
    batch_size, n_heads, time1, n = x.shape
    time2 = time1 + left_context

    batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

    return x.as_strided(
        (batch_size, n_heads, time1, time2),
        (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
        storage_offset=(n_stride * (time1 - 1)),
    )
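_rel_shift_v2 = rel_shift  # alias for the left-context variant


# Hedged usage sketch (not from the original source): with left_context > 0
# the key axis widens to T_1 + left_context, and the same identity holds:
# out[b, h, i, j] == x[b, h, i, j - i + time1 - 1].
def _demo_rel_shift_left_context() -> None:
    b, h, t1, lc = 1, 2, 3, 2
    n = 2 * t1 - 1 + lc
    x = torch.arange(b * h * t1 * n, dtype=torch.float32).reshape(b, h, t1, n)
    out = _rel_shift_v2(None, x, left_context=lc)
    assert out.shape == (b, h, t1, t1 + lc)
    assert out[0, 0, 1, 0] == x[0, 0, 1, t1 - 2]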
def label_aggregate(
    self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """label_aggregate function.

    Args:
        input: (Batch, Nsamples, Label_dim)
        input_lengths: (Batch)
    Returns:
        output: (Batch, Frames)
    """
    bs = input.size(0)
    max_length = input.size(1)
    label_dim = input.size(2)

    # NOTE(jiatong):
    # The default behaviour of label aggregation is compatible with
    # torch.stft about framing and padding.

    # Step1: center padding
    if self.center:
        pad = self.win_length // 2
        max_length = max_length + 2 * pad
        input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
        input[:, :pad, :] = input[:, pad : (2 * pad), :]
        input[:, (max_length - pad) : max_length, :] = input[
            :, (max_length - 2 * pad) : (max_length - pad), :
        ]
    nframe = (max_length - self.win_length) // self.hop_length + 1

    # Step2: framing
    output = input.as_strided(
        (bs, nframe, self.win_length, label_dim),
        (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
    )

    # Step3: aggregate label
    # (bs, nframe, self.win_length, label_dim) => (bs, nframe)
    _tmp = output.sum(dim=-1, keepdim=False).float()
    output = _tmp[:, :, self.win_length // 2]

    # Step4: process lengths
    if input_lengths is not None:
        if self.center:
            pad = self.win_length // 2
            input_lengths = input_lengths + 2 * pad
        # olens = (input_lengths - win_length) // hop_length + 1, matching the
        # framing formula above; torch.div truncates the integer division.
        olens = (
            torch.div(
                input_lengths - self.win_length,
                self.hop_length,
                rounding_mode="trunc",
            )
            + 1
        )
        output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
    else:
        olens = None

    return output, olens
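# Hedged usage sketch (not from the original source): same framing as
# LabelAggregate.forward above, but the aggregate is the summed label vector
# at each window centre, so the result is 2-D (Batch, Frames).
def _demo_label_aggregate_fn() -> None:
    cfg = SimpleNamespace(center=True, win_length=8, hop_length=4)
    labels = (torch.rand(2, 32, 3) > 0.5).float()
    out, olens = label_aggregate(cfg, labels, torch.tensor([32, 24]))
    assert out.shape == (2, 1 + 32 // 4)
    assert olens.tolist() == [1 + 32 // 4, 1 + 24 // 4]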
def _get_strided_batch_streaming(
    waveform: torch.Tensor,
    window_shift: int,
    window_length: int,
    prev_remainder: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """A variant of _get_strided_batch that creates short frames of a batch of audio signals
    in a way suitable for streaming. It accepts a waveform, window size parameters, and
    an optional buffer of previously unused samples. It returns a pair of waveform windows tensor,
    and the unused part of the waveform to be passed as ``prev_remainder`` in the next call
    to this function.

    Example usage::

        >>> # get the first buffer of audio and make frames
        >>> waveform = get_incoming_audio_from_mic()
        >>> frames, remainder = _get_strided_batch_streaming(
        ...     waveform,
        ...     window_shift=160,
        ...     window_length=200,
        ... )
        >>>
        >>> process(frames)  # do sth with the frames
        >>>
        >>> # get the next buffer and use the previous remainder to make frames
        >>> waveform = get_incoming_audio_from_mic()
        >>> frames, remainder = _get_strided_batch_streaming(
        ...     waveform,
        ...     window_shift=160,
        ...     window_length=200,
        ...     prev_remainder=remainder,
        ... )

    .. caution::
        This windowing mechanism only supports ``snip_edges=False``.

    :param waveform: A waveform tensor of shape ``(batch_size, num_samples)``.
    :param window_shift: The shift between frames measured in the number of samples.
    :param window_length: The number of samples in each window (frame).
    :param prev_remainder: An optional waveform tensor of shape ``(batch_size, num_samples)``.
        Can be ``None`` which indicates the start of a recording.
    :return: a pair of tensors with shapes ``(batch_size, num_frames, window_length)``
        and ``(batch_size, remainder_len)``.
    """
    assert window_shift <= window_length
    assert waveform.dim() == 2
    batch_size = waveform.size(0)

    if prev_remainder is None:
        npad_left = int((window_length - window_shift) // 2)
        pad_left = torch.flip(waveform[:, :npad_left], (1,))
        waveform = torch.cat((pad_left, waveform), dim=1)
    else:
        assert prev_remainder.dim() == 2
        assert prev_remainder.size(0) == batch_size
        waveform = torch.cat((prev_remainder, waveform), dim=1)

    num_samples = waveform.size(-1)
    window_remainder = window_length - window_shift
    num_frames = (num_samples - window_remainder) // window_shift
    remainder = waveform[:, num_frames * window_shift:]

    strides = (
        waveform.stride(0),
        window_shift * waveform.stride(1),
        waveform.stride(1),
    )
    sizes = [batch_size, num_frames, window_length]
    return waveform.as_strided(sizes, strides), remainder
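# Hedged usage sketch (not from the original source): framing a signal in two
# streaming chunks, carrying the remainder across calls, reproduces the frames
# obtained by framing the whole signal at once.
def _demo_streaming_frames() -> None:
    wave = torch.randn(1, 1000)
    frames_a, rem = _get_strided_batch_streaming(
        wave[:, :600], window_shift=160, window_length=200
    )
    frames_b, rem = _get_strided_batch_streaming(
        wave[:, 600:], window_shift=160, window_length=200, prev_remainder=rem
    )
    frames_all, _ = _get_strided_batch_streaming(
        wave, window_shift=160, window_length=200
    )
    streamed = torch.cat((frames_a, frames_b), dim=1)
    assert torch.equal(streamed, frames_all[:, : streamed.size(1)])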