class AANDecoderLayer(nn.Module):
    """
    Based on https://arxiv.org/abs/1805.00631
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        self.avg_attn = AverageAttention(
            self.embed_dim, dropout=args.attention_dropout
        )
        # unlike the original paper, we use a single gate
        self.aan_gating_fc = fairseq_transformer.Linear(
            self.embed_dim * 2, self.embed_dim
        )

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.avg_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.encoder_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = fairseq_transformer.Linear(
            self.embed_dim, args.decoder_ffn_embed_dim
        )
        self.fc2 = fairseq_transformer.Linear(
            args.decoder_ffn_embed_dim, self.embed_dim
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary ByteTensor of
                shape `(batch, src_len)` where padding elements are indicated
                by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).
        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`

        The following are used for export tracing:
            prev_self_attn_state: [prev_sum, prev_pos]
                assumes AverageAttention without mask trick
            prev_attn_state: [prev_key, prev_value]
        """
        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.avg_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_sum, prev_pos = prev_self_attn_state
            # (batch, embed) -> (seq, batch, embed)
            prev_sum = prev_sum.unsqueeze(0)
            saved_state = {"prev_sum": prev_sum, "prev_pos": prev_pos}
            self.avg_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.avg_attn(
            value=x,
            mask_future_timesteps=True,
            incremental_state=incremental_state,
            mask_trick=self.training,
        )

        # unlike the original paper, we use a single gate
        gate = torch.sigmoid(self.aan_gating_fc(torch.cat([residual, x], dim=-1)))
        x = gate * x + (1 - gate) * residual

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.avg_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.avg_attn._get_input_buffer(incremental_state)
            # remove sequence axis for export
            prev_sum = saved_state["prev_sum"]
            # (seq, batch, embed) -> (batch, embed)
            prev_sum = prev_sum.squeeze(0)
            prev_pos = saved_state["prev_pos"]
            self_attn_state = prev_sum, prev_pos
            return x, attn, self_attn_state

        return x, attn, None

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        # apply the norm only when (pre-norm and before=True) or
        # (post-norm and after=True)
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
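# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how
# AANDecoderLayer might be constructed and called, assuming the fairseq-style
# `args` namespace fields read in __init__ above. The hyperparameter values
# and tensor sizes below are illustrative only.
#
#     from argparse import Namespace
#     import torch
#
#     args = Namespace(
#         decoder_embed_dim=512,
#         decoder_ffn_embed_dim=2048,
#         decoder_attention_heads=8,
#         encoder_embed_dim=512,
#         dropout=0.1,
#         attention_dropout=0.1,
#         activation_dropout=0.0,
#         activation_fn="relu",
#         decoder_normalize_before=False,
#         cross_self_attention=False,
#         char_inputs=False,
#     )
#     layer = AANDecoderLayer(args).eval()
#
#     seq_len, batch, src_len = 7, 2, 11
#     x = torch.randn(seq_len, batch, args.decoder_embed_dim)
#     encoder_out = torch.randn(src_len, batch, args.encoder_embed_dim)
#     # out: (seq_len, batch, embed_dim); attn: encoder-decoder attention weights
#     out, attn, _ = layer(x, encoder_out=encoder_out)
# ---------------------------------------------------------------------------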
def __init__(self, args, no_encoder_attn=False):
    super().__init__()
    self.embed_dim = args.decoder_embed_dim
    self.dropout = args.dropout
    self.relu_dropout = args.relu_dropout
    self.more_dropouts = args.decoder_aan_more_dropouts

    if args.decoder_attn_window_size <= 0:
        self.avg_attn = AverageAttention(
            self.embed_dim, dropout=args.attention_dropout
        )
    else:
        self.avg_attn = AverageWindowAttention(
            self.embed_dim,
            dropout=args.attention_dropout,
            window_size=args.decoder_attn_window_size,
        )
    # self.activation = getattr(args, "decoder_ffn_activation", "relu")
    self.aan_layer_norm = LayerNorm(self.embed_dim)
    if args.no_decoder_aan_ffn:
        self.aan_ffn = None
    else:
        # hidden dim of the AAN FFN: embed_dim when
        # decoder_aan_ffn_use_embed_dim is set, else decoder_ffn_embed_dim
        aan_ffn_hidden_dim = (
            args.decoder_embed_dim
            if args.decoder_aan_ffn_use_embed_dim
            else args.decoder_ffn_embed_dim
        )
        self.aan_ffn = FeedForwardNetwork(
            self.embed_dim,
            aan_ffn_hidden_dim,
            self.embed_dim,
            num_layers=2,
            dropout=args.relu_dropout,
        )

    if args.no_decoder_aan_gating:
        self.aan_gating_fc = None
    else:
        self.aan_gating_fc = Linear(self.embed_dim * 2, self.embed_dim * 2)
    self.normalize_before = args.decoder_normalize_before

    if no_encoder_attn:
        self.encoder_attn = None
        self.encoder_attn_layer_norm = None
    else:
        self.encoder_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            kdim=args.encoder_embed_dim,
            vdim=args.encoder_embed_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

    self.ffn = FeedForwardNetwork(
        self.embed_dim,
        args.decoder_ffn_embed_dim,
        self.embed_dim,
        num_layers=2,
        dropout=args.relu_dropout,
    )

    self.final_layer_norm = LayerNorm(self.embed_dim)
    self.need_attn = True
    self.onnx_trace = False
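# ---------------------------------------------------------------------------
# Gating sketch (illustrative only, not code from this module's forward pass).
# Unlike AANDecoderLayer above, which uses a single gate, the
# `Linear(embed_dim * 2, embed_dim * 2)` projection built in the __init__
# above is sized for the two-gate formulation of the AAN paper
# (https://arxiv.org/abs/1805.00631): the projection of
# [residual; avg_attn_out] is split into an input gate and a forget gate.
# The helper below is a hypothetical standalone illustration of that
# computation.
def _aan_two_gate_sketch(residual, avg_attn_out, gating_fc):
    # gating_fc is assumed to be a Linear(embed_dim * 2, embed_dim * 2)
    i, f = gating_fc(torch.cat([residual, avg_attn_out], dim=-1)).chunk(2, dim=-1)
    # input gate scales the layer input, forget gate scales the averaged output
    return torch.sigmoid(i) * residual + torch.sigmoid(f) * avg_attn_out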