def forward(self, input_states, ctx_states, mask_input, mask_ctx, **kwargs):
    """Return output of shape `(seq_len, batch_size, depth)`."""
    if self.pos_init:
        pos_encoding = positional_encodings_like(input_states)
        output = input_states + pos_encoding
    else:
        pos_encoding = None
        output = input_states

    self_padding_mask = (1 - mask_input).byte()
    output = self.self_encoder(x=output, padding_mask=self_padding_mask)

    if self.pos_slf_attn is not None:
        if pos_encoding is None:
            pos_encoding = positional_encodings_like(output)
        output = self.pos_slf_attn(x=pos_encoding, value=output, padding_mask=self_padding_mask)

    inter_padding_mask = (1 - mask_ctx).byte()
    output = self.cond_attn(output, ctx_states, ctx_states,
                            padding_mask=inter_padding_mask, static_kv=True)
    output = self.feedforward(output)
    return output
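# -----------------------------------------------------------------------------
# `positional_encodings_like` is called throughout this file but defined
# elsewhere. As a rough sketch only (an assumption, not the project's actual
# helper), a standard sinusoidal encoding broadcast to the input's shape could
# look like the function below. Note the callers pass both (T, B, C) and
# (B, T, C) tensors, so the real helper presumably infers the time axis; here
# `time_dim` is an explicit, hypothetical parameter.
import math
import torch


def sinusoidal_encodings_like(x, time_dim=0):
    """Sinusoidal position encodings with the same shape, dtype and device as `x`.

    Assumes an even feature dimension in the last axis.
    """
    seq_len, depth = x.size(time_dim), x.size(-1)
    position = torch.arange(seq_len, dtype=torch.float, device=x.device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, depth, 2, dtype=torch.float, device=x.device)
        * (-math.log(10000.0) / depth)
    )
    pe = torch.zeros(seq_len, depth, device=x.device)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    # reshape so the encodings broadcast against x (time along `time_dim`)
    shape = [1] * x.dim()
    shape[time_dim], shape[-1] = seq_len, depth
    return pe.view(shape).expand_as(x).to(x.dtype)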
def extract_features(self, prev_output_tokens, encoder_out=None, **unused):
    """
    Similar to *forward* but only returns features.

    Returns:
        tuple:
            - the decoder's features of shape `(batch, tgt_len, embed_dim)`
            - a dictionary with any model-specific outputs:
                - length-predict module output
                - position-search module output
                - position-predict module output
                - decoder module output
    """
    # embed positions
    inputs_dict = self._bridging(encoder_out=encoder_out, prev_output_tokens=prev_output_tokens)
    x = inputs_dict['inputs']

    if self.project_in_dim is not None:
        x = self.project_in_dim(x)

    positions = positional_encodings_like(x)
    if positions is not None:
        x += positions

    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    # x = x.transpose(0, 1)
    # attn = None
    # inner_states = [x]

    # decoder layers
    # self_attn_masks = self._buffered_nat_mask(x)
    # self_attn_padding_mask = (1 - inputs_dict[LengthPredictorBridge.DECODE_MASK_KEY]).byte()
    # for layer in self.decoder_layers:
    #     x, attn = layer(
    #         x,
    #         encoder_out=encoder_out['encoder_out'] if encoder_out is not None else None,
    #         encoder_padding_mask=encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
    #         self_attn_mask=self_attn_masks,
    #         self_attn_padding_mask=self_attn_padding_mask
    #     )
    #     inner_states.append(x)
    #
    # if self.normalize:
    #     x = self.layer_norm(x)
    #
    # # T x B x C -> B x T x C
    # x = x.transpose(0, 1)
    #
    # if self.project_out_dim is not None:
    #     x = self.project_out_dim(x)
    #
    # inputs_dict['attn'] = attn
    # inputs_dict['inner_states'] = inner_states

    x, inputs_dict = self._decoding(
        x,
        encoder_out,
        inputs_dict,
    )
    return x, inputs_dict
def forward(self, src_tokens, src_lengths, **unused):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`

    Returns:
        dict:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of padding
              elements of shape `(batch, src_len)`
            - **encoder_history** (List[Tensor]): per-layer outputs, each of
              shape `(batch, src_len, embed_dim)`
    """
    # embed tokens and positions
    if self.share_embed:
        x = F.embedding(src_tokens, self.embed_tokens.weight * self.embed_scale)
    else:
        x = self.embed_tokens(src_tokens) * self.embed_scale
    # x = self.embed_scale * self.embed_tokens(src_tokens)
    # if self.embed_positions is not None:
    #     x += self.embed_positions(src_tokens)
    x += positional_encodings_like(x)
    encoder_history = [x]

    # x = F.dropout(x, p=self.dropout, training=self.training)
    x = self.dropout(x)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    # encoder layers
    for layer in self.layers:
        x = layer(x, encoder_padding_mask)
        encoder_history.append(x.transpose(0, 1))

    if self.normalize:
        x = self.layer_norm(x)

    return {
        'encoder_history': encoder_history,            # List<B x T x C>
        'encoder_out': x,                              # T x B x C
        'encoder_padding_mask': encoder_padding_mask,  # B x T
    }
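# -----------------------------------------------------------------------------
# The encoder's padding-mask convention above: the mask is True at <pad>
# positions and is replaced with None when the batch contains no padding at
# all, so downstream layers can skip masking entirely. A minimal,
# self-contained illustration with dummy data (padding_idx = 1 is an
# assumption made only for this example):
import torch

src_tokens = torch.tensor([[5, 6, 7, 1],
                           [8, 9, 1, 1]])            # (batch=2, src_len=4), 1 = <pad>
padding_idx = 1
encoder_padding_mask = src_tokens.eq(padding_idx)    # (batch, src_len), True at pads
if not encoder_padding_mask.any():
    encoder_padding_mask = None                      # no padding anywhere in the batch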
def forward(self, input_states, ctx_states, mask_input, mask_ctx, **kwargs):
    """Return output of shape `(seq_len, batch_size, depth)`."""
    output = self.encode(input_states, mask_input)
    self_padding_mask = (1 - mask_input).byte()

    if self.pos_slf_attn is not None:
        pos_encoding = positional_encodings_like(output)
        output = self.pos_slf_attn(x=pos_encoding, value=output, padding_mask=self_padding_mask)

    key_padding_mask = (1 - mask_ctx).byte()
    output = self.cond_attn(output, ctx_states, ctx_states, key_padding_mask)
    output = self.feedforward(output)
    return output
def forward(self, x, encoder_out=None, encoder_padding_mask=None, self_attn_mask=None,
            self_attn_padding_mask=None, pos_key=None, pos_val=None, **kwargs):
    pos_key, pos_val = self._get_pos_repr(x, pos_key, pos_val)

    x, attn = self.self_attn_block(x=x, padding_mask=self_attn_padding_mask,
                                   attn_mask=self_attn_mask, ret_attn=True,
                                   need_weights=False, pos_key=pos_key, pos_val=pos_val)

    if self.pos_self_attn_block is not None:
        pos_encoding, weights = positional_encodings_like(x), None
        x = self.pos_self_attn_block(
            x=pos_encoding,
            value=x,
            padding_mask=self_attn_padding_mask,
            attn_mask=self_attn_mask,
            need_weights=False,
        )

    if self.encoder_attn_block is not None:
        x, attn = self.encoder_attn_block(
            query=x,
            key=encoder_out,
            value=encoder_out,
            padding_mask=encoder_padding_mask,
            ret_attn=True,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )

    x = self.ffn_block(x)
    return x, attn
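# -----------------------------------------------------------------------------
# In `pos_self_attn_block` above, position encodings act as query and key while
# the layer content serves as value (the same pattern appears in the decoder
# layer below). A minimal single-head sketch of that pattern, under the
# assumption of (seq_len, batch, dim) inputs; names and shapes are illustrative
# and this is not the project's actual block:
import torch
import torch.nn.functional as F


def positional_self_attention(pos_encoding, value, padding_mask=None):
    """pos_encoding, value: (seq_len, batch, dim); padding_mask: (batch, seq_len), True/1 at pads."""
    q = k = pos_encoding.transpose(0, 1)                              # (B, T, C)
    v = value.transpose(0, 1)                                         # (B, T, C)
    scores = torch.matmul(q, k.transpose(1, 2)) / q.size(-1) ** 0.5   # (B, T, T)
    if padding_mask is not None:
        scores = scores.masked_fill(padding_mask.unsqueeze(1).bool(), float('-inf'))
    attn = F.softmax(scores, dim=-1)                                  # weights depend only on positions
    return torch.matmul(attn, v).transpose(0, 1)                      # back to (T, B, C)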
def wrap_position(self, pos_query, candidate):
    """
    Return the position representations used for matching.

    :param pos_query: batch, query_len, state_dim
    :param candidate: batch, target_len, state_dim
    :return:
    """
    if self.position_type == 0:
        return 0, 0

    # batch_size, target_len, state_dim
    position_candidate = positional_encodings_like(candidate)
    if self.position_type == 1:
        return 0, position_candidate

    if self.normalize:
        # batch, query_len, target_len
        position_logits = F.cosine_similarity(pos_query[:, :, None, :],
                                              position_candidate[:, None, :, :],
                                              dim=-1)
    else:
        # batch, query_len, target_len
        position_logits = torch.matmul(pos_query, position_candidate.contiguous().transpose(1, 2))

    # batch_size, query_len, state_dim
    position_query = torch.matmul(position_logits, position_candidate)
    return position_query, position_candidate
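# -----------------------------------------------------------------------------
# A self-contained shape check for the dot-product branch of `wrap_position`
# above (dummy sizes only; this just illustrates how the matching shapes line
# up, it is not part of the model):
import torch

batch, query_len, target_len, state_dim = 2, 3, 5, 8
pos_query = torch.randn(batch, query_len, state_dim)
position_candidate = torch.randn(batch, target_len, state_dim)

position_logits = torch.matmul(pos_query, position_candidate.transpose(1, 2))
position_query = torch.matmul(position_logits, position_candidate)

assert position_logits.shape == (batch, query_len, target_len)
assert position_query.shape == (batch, query_len, state_dim)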
def forward(
    self,
    x,
    encoder_out=None,
    encoder_padding_mask=None,
    incremental_state=None,
    prev_self_attn_state=None,
    prev_attn_state=None,
    self_attn_mask=None,
    self_attn_padding_mask=None,
):
    """
    Args:
        x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
        encoder_padding_mask (ByteTensor): binary ByteTensor of shape
            `(batch, src_len)` where padding elements are indicated by ``1``.

    Returns:
        encoded output of shape `(seq_len, batch, embed_dim)`
    """
    if prev_self_attn_state is not None:
        if incremental_state is None:
            incremental_state = {}
        prev_key, prev_value = prev_self_attn_state
        saved_state = {"prev_key": prev_key, "prev_value": prev_value}
        self.self_attn.set_input_buffer(incremental_state, saved_state)

    residual = x
    x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
    x, attn = self.self_attn(
        query=x,
        key=x,
        value=x,
        key_padding_mask=self_attn_padding_mask,
        incremental_state=incremental_state,
        need_weights=False,
        attn_mask=self_attn_mask,
    )
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = residual + x
    x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

    # if self.pos_self_attn is not None:
    #     residual = x
    #     pos_encoding, weights = positional_encodings_like(x), None
    #     x, attn = self.pos_self_attn(
    #         query=pos_encoding,
    #         key=pos_encoding,
    #         value=x,
    #         key_padding_mask=self_attn_padding_mask,
    #         need_weights=False,
    #         attn_mask=self_attn_mask,
    #     )
    #     x = F.dropout(x, p=self.dropout, training=self.training)
    #     x = residual + x
    #     x = self.maybe_layer_norm(self.pos_self_attn_layer_norm, x, after=True)

    if self.pos_self_attn_block is not None:
        pos_encoding, weights = positional_encodings_like(x), None
        x = self.pos_self_attn_block(
            x=pos_encoding,
            value=x,
            padding_mask=self_attn_padding_mask,
            attn_mask=self_attn_mask,
            need_weights=False,
        )

    if self.encoder_attn is not None:
        if prev_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.encoder_attn.set_input_buffer(incremental_state, saved_state)

        residual = x
        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
        x, attn = self.encoder_attn(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

    residual = x
    x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
    x = self.activation_fn(self.fc1(x))
    x = F.dropout(x, p=self.activation_dropout, training=self.training)
    x = self.fc2(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    x = residual + x
    x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

    if self.onnx_trace and incremental_state is not None:
        saved_state = self.self_attn.get_input_buffer(incremental_state)
        self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
        return x, attn, self_attn_state

    return x, attn