def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    left_mask: int = -1,
    right_mask: int = -1,
    prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Encode a padded batch, optionally combining an attention mask.

    Args:
        xs_pad: Padded input tensor, documented by the original author
            as (B, L, D).
        ilens: Input lengths (B); used to build the padding mask.
        left_mask: Forwarded to ``make_attention_mask``. Presumably the
            left context size -- confirm against that helper.
        right_mask: Forwarded to ``make_attention_mask``. Presumably the
            right context size -- confirm against that helper.
        prev_states: Not used.

    Returns:
        A 3-tuple ``(encoded, olens, None)`` where ``olens`` is the sum
        of the final mask over its last axis after ``squeeze(1)``.

    Raises:
        TooShortUttError: If a Conv2d subsampling front-end is in use and
            ``check_short_utt`` reports the input as too short.
    """
    # Padding mask with a singleton middle axis, placed on the input device.
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)  # (B, 1, L)

    # TODO: check attention mask (note carried over from the original).
    # The attention mask is only built when at least one bound is >= 0.
    if left_mask >= 0 or right_mask >= 0:
        window = make_attention_mask(xs_pad, left_mask, right_mask)
        # A leading singleton axis lets it broadcast against the pad mask.
        masks = masks & ~window[None, :, :]

    conv_frontends = (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
    if not isinstance(self.embed, conv_frontends):
        xs_pad = self.embed(xs_pad)
    else:
        # Subsampling front-ends need a minimum number of input frames.
        too_short, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if too_short:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)

    encoded, masks = self.encoders(xs_pad, masks)

    # The encoder stack may hand back (xs_pad, pos_emb); keep the tensor.
    if isinstance(encoded, tuple):
        encoded = encoded[0]

    if self.normalize_before:
        encoded = self.after_norm(encoded)

    return encoded, masks.squeeze(1).sum(1), None
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Run the encoder over a padded batch.

    Args:
        xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
        ilens (torch.Tensor): Input length (#batch).
        prev_states (torch.Tensor): Not used.

    Returns:
        torch.Tensor: Encoder output after ``after_norm``; the original
            docstring gives its shape as (#batch, L, output_size).
        torch.Tensor: Output length (#batch), taken from the final mask.
        None: Placeholder; always ``None``.

    Raises:
        TooShortUttError: If a Conv2d subsampling front-end is in use and
            ``check_short_utt`` reports the input as too short.
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

    embed = self.embed
    # The embedding front-end is optional in this variant.
    if embed is not None:
        conv_frontends = (
            Conv2dSubsampling,
            Conv2dSubsampling2,
            Conv2dSubsampling6,
            Conv2dSubsampling8,
        )
        if isinstance(embed, conv_frontends):
            # Subsampling front-ends need a minimum number of input frames.
            too_short, limit_size = check_short_utt(embed, xs_pad.size(1))
            if too_short:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = embed(xs_pad, masks)
        else:
            xs_pad = embed(xs_pad)

    encoded, masks = self.encoders(xs_pad, masks)

    # The encoder stack may return a tuple; only its first item is kept.
    if isinstance(encoded, tuple):
        encoded = encoded[0]

    # Unlike sibling encoders, the final norm is applied unconditionally.
    encoded = self.after_norm(encoded)

    return encoded, masks.squeeze(1).sum(1), None
def forward_one_step(
    self,
    xs_pad: torch.Tensor,
    left_mask: int = -1,
    right_mask: int = -1,
    prev_states: torch.Tensor = None,
):
    """Encode the input without a padding mask.

    Unlike ``forward``, no input lengths are taken here: the only mask
    that can be in play is the attention mask, and only when one of the
    bounds is non-negative.

    Args:
        xs_pad: Input tensor.
        left_mask: Forwarded to ``make_attention_mask``. Presumably the
            left context size -- confirm against that helper.
        right_mask: Forwarded to ``make_attention_mask``. Presumably the
            right context size -- confirm against that helper.
        prev_states: Not used.

    Returns:
        A 3-tuple ``(encoded, olens, None)``. ``olens`` is ``None`` when
        no mask is available after the encoder stack; otherwise it is
        ``masks.squeeze(1).sum(1)``.
        NOTE(review): here that sum is taken over an attention mask, not
        a padding mask -- confirm callers expect this.

    Raises:
        TooShortUttError: If a Conv2d subsampling front-end is in use and
            ``check_short_utt`` reports the input as too short.
    """
    # TODO: check attention mask (note carried over from the original).
    masks = None
    if left_mask >= 0 or right_mask >= 0:
        window = make_attention_mask(xs_pad, left_mask, right_mask)
        masks = ~window[None, :, :]  # (1, L, L) per the original comment

    conv_frontends = (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
    if not isinstance(self.embed, conv_frontends):
        xs_pad = self.embed(xs_pad)
    else:
        # Subsampling front-ends need a minimum number of input frames.
        too_short, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if too_short:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)

    encoded, masks = self.encoders(xs_pad, masks)

    # The encoder stack may hand back (xs_pad, pos_emb); keep the tensor.
    if isinstance(encoded, tuple):
        encoded = encoded[0]

    if self.normalize_before:
        encoded = self.after_norm(encoded)

    olens = None if masks is None else masks.squeeze(1).sum(1)
    return encoded, olens, None
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Embed the input and run it through the encoder stack.

    Args:
        xs_pad: Padded input tensor, documented by the original author
            as (B, L, D).
        ilens: Input lengths (B); used to build the padding mask.
        prev_states: Not used.

    Returns:
        A 3-tuple ``(encoded, olens, None)`` where ``olens`` is the sum
        of the final mask over its last axis after ``squeeze(1)``.

    Raises:
        TooShortUttError: If a Conv2d subsampling front-end is in use and
            ``check_short_utt`` reports the input as too short.
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

    conv_frontends = (
        Conv2dSubsampling,
        Conv2dSubsampling2,
        Conv2dSubsampling6,
        Conv2dSubsampling8,
    )
    if not isinstance(self.embed, conv_frontends):
        xs_pad = self.embed(xs_pad)
    else:
        # Subsampling front-ends need a minimum number of input frames.
        too_short, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if too_short:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)

    encoded, masks = self.encoders(xs_pad, masks)

    if self.normalize_before:
        encoded = self.after_norm(encoded)

    return encoded, masks.squeeze(1).sum(1), None
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: Optional[torch.Tensor] = None,
    ctc: Optional[CTC] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Calculate forward propagation, optionally collecting intermediate outputs.

    When ``self.interctc_layer_idx`` is empty the whole encoder stack is
    run in one call. Otherwise the layers are run one at a time so that
    the output of each listed layer (1-based) can be recorded and, if
    ``self.interctc_use_conditioning`` is set, fed back into the stream
    through ``ctc.softmax`` and ``self.conditioning_layer``.

    Args:
        xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
        ilens (torch.Tensor): Input length (#batch).
        prev_states (torch.Tensor): Not used.
        ctc (CTC): Module whose ``softmax`` is applied to intermediate
            outputs. Only needed when conditioning is enabled.

    Returns:
        torch.Tensor: Output tensor; the original docstring gives its
            shape as (#batch, L, output_size). If any intermediate
            outputs were collected, this item is instead the pair
            ``(output, intermediate_outs)``, where ``intermediate_outs``
            is a list of ``(layer_index, tensor)`` tuples.
        torch.Tensor: Output length (#batch), taken from the final mask.
        None: Placeholder; always ``None``.

    Raises:
        TooShortUttError: If a Conv2d subsampling front-end is in use and
            ``check_short_utt`` reports the input as too short.
        ValueError: If conditioning is required at a listed layer but
            ``ctc`` was not supplied.
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

    conv_frontends = (
        Conv2dSubsampling,
        Conv2dSubsampling2,
        Conv2dSubsampling6,
        Conv2dSubsampling8,
    )
    if isinstance(self.embed, conv_frontends):
        # Subsampling front-ends need a minimum number of input frames.
        short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
        if short_status:
            raise TooShortUttError(
                f"has {xs_pad.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                xs_pad.size(1),
                limit_size,
            )
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)

    intermediate_outs = []
    if len(self.interctc_layer_idx) == 0:
        xs_pad, masks = self.encoders(xs_pad, masks)
    else:
        # Layer indices in interctc_layer_idx are 1-based.
        for layer_idx, encoder_layer in enumerate(self.encoders, start=1):
            xs_pad, masks = encoder_layer(xs_pad, masks)
            if layer_idx not in self.interctc_layer_idx:
                continue

            # Layers may carry a (tensor, pos_emb) pair; record the tensor.
            encoder_out = xs_pad[0] if isinstance(xs_pad, tuple) else xs_pad

            # Intermediate outputs are also normalized.
            if self.normalize_before:
                encoder_out = self.after_norm(encoder_out)

            intermediate_outs.append((layer_idx, encoder_out))

            if not self.interctc_use_conditioning:
                continue

            # Fail with a clear message instead of an AttributeError on None.
            if ctc is None:
                raise ValueError(
                    "ctc must be provided when interctc_use_conditioning is enabled"
                )

            conditioning = self.conditioning_layer(ctc.softmax(encoder_out))
            if isinstance(xs_pad, tuple):
                x, pos_emb = xs_pad
                xs_pad = (x + conditioning, pos_emb)
            else:
                xs_pad = xs_pad + conditioning

    if isinstance(xs_pad, tuple):
        xs_pad = xs_pad[0]

    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)

    olens = masks.squeeze(1).sum(1)
    if intermediate_outs:
        return (xs_pad, intermediate_outs), olens, None
    return xs_pad, olens, None