def forward(
        self,
        targets: Tensor,
        input_lengths: Optional[Tensor] = None,
        memory: Tensor = None,
) -> Tuple[Tensor, List[Tensor], List[Tensor]]:
    self_attns, memory_attns = list(), list()

    # Positions that are real target tokens (not padding).
    non_pad_mask = get_pad_mask(targets, pad_id=self.pad_id).eq(False)

    # Self-attention mask: block padded keys and future (subsequent) positions.
    self_attn_mask = get_attn_pad_mask(targets, self.pad_id) | get_subsequent_mask(targets)

    # Memory mask: block padded encoder frames, expanded to every target position.
    memory_mask = get_pad_mask(memory, input_lengths).squeeze(-1).unsqueeze(1).expand(-1, targets.size(1), -1)

    # Scaled token embedding + positional encoding, followed by dropout.
    output = self.input_dropout(
        self.embedding(targets) * self.logit_scale + self.positional_encoding(targets.size(1))
    )

    for layer in self.layers:
        output, self_attn, memory_attn = layer(output, memory, non_pad_mask, self_attn_mask, memory_mask)
        self_attns.append(self_attn)
        memory_attns.append(memory_attn)

    return output, self_attns, memory_attns
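# For reference, a minimal sketch of what a causal-mask helper in the spirit of
# get_subsequent_mask typically computes; the actual helper used above may differ,
# and `import torch` is assumed at module level. It marks positions j > i as True
# (blocked), so ORing it with a key-padding mask prevents attention to both
# padding and future targets.
def subsequent_mask_sketch(targets: Tensor) -> Tensor:
    batch_size, length = targets.size(0), targets.size(1)
    # Upper-triangular (excluding the diagonal) boolean mask: T x T.
    causal = torch.triu(torch.ones(length, length, device=targets.device), diagonal=1).bool()
    # Broadcast over the batch: B x T x T.
    return causal.unsqueeze(0).expand(batch_size, -1, -1)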
def forward(self, inputs: Tensor, input_lengths: Tensor = None) -> Tuple[Tensor, List[Tensor]]:
    self_attns = list()

    # Positions that are real frames (not padding).
    non_pad_mask = get_pad_mask(inputs, input_lengths=input_lengths).eq(False)

    # Expand the padding mask so every query position shares the same key mask: B x T x T.
    length = inputs.size(1)
    self_attn_mask = get_pad_mask(inputs, input_lengths).squeeze(-1).unsqueeze(1).expand(-1, length, -1)

    # Project inputs to the model dimension, add positional encoding, apply dropout.
    output = self.input_dropout(
        self.input_layer_norm(self.input_proj(inputs)) + self.positional_encoding(inputs.size(1))
    )

    for layer in self.layers:
        output, attn = layer(output, non_pad_mask, self_attn_mask)
        self_attns.append(attn)

    return output, self_attns
def forward(self, inputs: Tensor, input_lengths: Tensor = None):
    """
    Args:
        inputs: B x T_input x D
        input_lengths: B x 1
    """
    self_attns = list()

    non_pad_mask = get_pad_mask(inputs, input_lengths=input_lengths).eq(False)
    self_attn_mask = get_attn_pad_mask(inputs, input_lengths, inputs.size(1))

    output = self.input_dropout(
        self.input_norm(self.input_proj(inputs)) + self.positional_encoding(inputs.size(1))
    )

    for layer in self.layers:
        output, attn = layer(output, non_pad_mask, self_attn_mask)
        self_attns.append(attn)

    return output, self_attns
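# For reference, a minimal sketch of a length-based attention padding mask in the
# spirit of get_attn_pad_mask; the actual helper used above may differ, and
# `import torch` is assumed at module level. Key positions beyond each sequence's
# true length are marked True (masked), and the mask is expanded so every query
# position shares it.
def attn_pad_mask_sketch(inputs: Tensor, input_lengths: Tensor, expand_length: int) -> Tensor:
    batch_size, key_length = inputs.size(0), inputs.size(1)
    # True where the key position is padding: B x T_key.
    pad_mask = torch.arange(key_length, device=inputs.device).unsqueeze(0) >= input_lengths.view(-1, 1)
    # Repeat for every query position: B x T_query x T_key.
    return pad_mask.unsqueeze(1).expand(batch_size, expand_length, key_length)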
def forward(self, inputs: Tensor, input_lengths: Optional[Tensor] = None, memory: Tensor = None):
    self_attns, memory_attns = list(), list()
    batch_size, output_length = inputs.size(0), inputs.size(1)

    # Positions that are real target tokens (not padding).
    non_pad_mask = get_pad_mask(inputs, pad_id=self.pad_id).eq(False)

    # Self-attention mask over the target sequence, built from the pad id.
    self_attn_mask = get_decoder_self_attn_mask(inputs, inputs, self.pad_id)

    # Mask padded encoder outputs for the encoder-decoder (memory) attention.
    memory_mask = get_attn_pad_mask(memory, input_lengths, output_length)

    # Token embedding + positional encoding, followed by dropout.
    output = self.input_dropout(
        self.embedding(inputs) + self.positional_encoding(inputs.size(1))
    )

    for layer in self.layers:
        output, self_attn, memory_attn = layer(output, memory, non_pad_mask, self_attn_mask, memory_mask)
        self_attns.append(self_attn)
        memory_attns.append(memory_attn)

    return output, self_attns, memory_attns
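# For reference, a minimal sketch of a decoder self-attention mask in the spirit of
# get_decoder_self_attn_mask; the actual helper used above may differ, and
# `import torch` is assumed at module level. The usual construction ORs a key-side
# padding mask (derived from the pad id) with a causal mask like the one sketched
# earlier, so attention can reach neither pad tokens nor future positions.
def decoder_self_attn_mask_sketch(seq_k: Tensor, seq_q: Tensor, pad_id: int) -> Tensor:
    # True where the key token is padding, repeated for every query: B x T_q x T_k.
    key_pad_mask = seq_k.eq(pad_id).unsqueeze(1).expand(-1, seq_q.size(1), -1)
    # True above the diagonal, i.e. query i may not attend to key j > i: B x T_q x T_k.
    causal_mask = torch.triu(
        torch.ones(seq_q.size(1), seq_k.size(1), device=seq_q.device), diagonal=1
    ).bool().unsqueeze(0).expand(seq_k.size(0), -1, -1)
    return key_pad_mask | causal_mask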