    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        if args.max_relative_length == -1:
            self.self_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
        else:
            # Relative position-aware self-attention when a maximum relative
            # length is configured.
            self.self_attn = RelativeMultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                args.max_relative_length, dropout=args.attention_dropout,
                k_only=args.k_only,
            )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        # Additional attention over an external context stream, with its own
        # layer norm.
        self.context_attn = MultiheadAttention(
            self.embed_dim, args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.context_attn_layer_norm = LayerNorm(self.embed_dim)
        # Maps the concatenation of two embed_dim feature vectors to a single
        # gate value per position.
        self.gate_linear = Linear(2 * self.embed_dim, 1)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True
        self.onnx_trace = False
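# The forward pass that consumes `gate_linear` is not part of this excerpt.
# Below is a minimal, hypothetical sketch of how a scalar gate with this
# shape is commonly used to fuse two attention streams; the tensor names
# (`enc_out`, `ctx_out`) and the sigmoid fusion are assumptions for
# illustration only, not the repository's actual implementation.
import torch

def fuse_with_gate(gate_linear, enc_out, ctx_out):
    # enc_out, ctx_out: outputs of two attention blocks, each of shape
    # (seq_len, batch, embed_dim). Concatenating them gives 2 * embed_dim
    # features, which gate_linear reduces to one scalar per position.
    gate = torch.sigmoid(gate_linear(torch.cat([enc_out, ctx_out], dim=-1)))
    # Convex combination of the two streams, weighted by the learned gate.
    return gate * enc_out + (1.0 - gate) * ctx_out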
    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        if args.max_relative_length == -1:
            self.self_attn = MultiheadAttention(
                self.embed_dim, args.encoder_attention_heads,
                dropout=args.attention_dropout,
            )
        else:
            self.self_attn = RelativeMultiheadAttention(
                self.embed_dim, args.encoder_attention_heads,
                args.max_relative_length, dropout=args.attention_dropout,
                k_only=args.k_only,
            )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.layer_norms = nn.ModuleList(
            [LayerNorm(self.embed_dim) for i in range(2)])
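# For reference, the args fields read by this encoder-layer __init__,
# collected in one place. The numeric values below are placeholders chosen
# for illustration only; the real values come from the parsed command-line
# arguments.
from types import SimpleNamespace

example_encoder_args = SimpleNamespace(
    encoder_embed_dim=512,
    encoder_attention_heads=8,
    encoder_ffn_embed_dim=2048,
    attention_dropout=0.1,
    dropout=0.3,
    relu_dropout=0.1,
    encoder_normalize_before=False,
    max_relative_length=-1,   # -1 selects plain MultiheadAttention
    k_only=True,              # only consulted when relative attention is used
)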
    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        if args.max_relative_length == -1:
            self.self_attn = MultiheadAttention(
                self.embed_dim, args.encoder_attention_heads,
                dropout=args.attention_dropout,
            )
        else:
            self.self_attn = RelativeMultiheadAttention(
                self.embed_dim, args.encoder_attention_heads,
                args.max_relative_length, dropout=args.attention_dropout,
            )
        # Attention over an external context, with its own layer norm.
        self.context_attn = MultiheadAttention(
            self.embed_dim, args.encoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.context_attn_layer_norm = LayerNorm(self.embed_dim)
        # Two embed_dim -> embed_dim projections (w2 is bias-free); how they
        # are combined is defined in the forward pass, which is not part of
        # this excerpt.
        self.w1 = Linear(self.embed_dim, self.embed_dim)
        self.w2 = Linear(self.embed_dim, self.embed_dim, bias=False)
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)
class connEnDeTransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing
    with: `dropout -> add residual`. We default to the approach in the paper,
    but the tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): disable attention over encoder
            outputs (default: False).
    """

    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        if args.max_relative_length == -1:
            self.self_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
        else:
            # Relative position-aware self-attention when a maximum relative
            # length is configured.
            self.self_attn = RelativeMultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                args.max_relative_length, dropout=args.attention_dropout,
                k_only=args.k_only,
            )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True
        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
                prev_self_attn_state=None, prev_attn_state=None,
                self_attn_mask=None, self_attn_padding_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape
                `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated
                by ``1``.
        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # Decoder self-attention block.
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        # Encoder-decoder attention block.
        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=True,
                # need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        # Position-wise feed-forward block.
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if self.onnx_trace:
            # For ONNX export, also return the cached self-attention
            # key/value state.
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state

        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
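# A minimal usage sketch for a single connEnDeTransformerDecoderLayer call.
# It assumes the helpers referenced above (MultiheadAttention, LayerNorm,
# Linear, ...) resolve to their fairseq implementations; the dimensions and
# dropout values are placeholders for illustration only.
import torch
from types import SimpleNamespace

example_decoder_args = SimpleNamespace(
    decoder_embed_dim=512, decoder_attention_heads=8,
    decoder_ffn_embed_dim=2048, attention_dropout=0.1, dropout=0.3,
    relu_dropout=0.1, decoder_normalize_before=False,
    max_relative_length=-1, k_only=True,
)
layer = connEnDeTransformerDecoderLayer(example_decoder_args)
layer.eval()  # disable dropout for a deterministic shape check

tgt_len, src_len, batch = 7, 11, 2
x = torch.randn(tgt_len, batch, example_decoder_args.decoder_embed_dim)
encoder_out = torch.randn(src_len, batch, example_decoder_args.decoder_embed_dim)
# Per the docstring this is a binary mask of shape (batch, src_len) with 1
# marking padding; all zeros here means no source padding.
encoder_padding_mask = torch.zeros(batch, src_len, dtype=torch.uint8)

# incremental_state=None runs the layer in parallel (training-style) mode;
# a causal self_attn_mask would normally be passed as well during training.
x, attn = layer(x, encoder_out, encoder_padding_mask, incremental_state=None)
print(x.shape)  # torch.Size([7, 2, 512])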