class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__( self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False ): super().__init__() self.embed_dim = args.decoder_embed_dim self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=True, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu") ) self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. # TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, "encoder_embed_dim", None), vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward(self, input): """ Args: input (Tuple): input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)` input[2] (ByteTensor/FloatTensor): encoder padding mask - binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: output (Tuple): output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)` output[1] (ByteTensor/FloatTensor): encoder padding mask output[2] (LongTensor): previous decoder outputs """ # Note: incremental state is not yet supported mt_task = False if isinstance(input, tuple): x = input[0] encoder_out = input[1] encoder_padding_mask = input[2] incremental_state = None mt_task = True else: x = input encoder_out = None encoder_padding_mask = None incremental_state = None if incremental_state is None: self_attn_mask = self.buffered_future_mask(x) else: self_attn_mask = None # TODO: add back prev_self_attn_state, prev_attn_state, # self_attn_padding_mask prev_self_attn_state = None prev_attn_state = None self_attn_padding_mask = None residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if mt_task: return (x, encoder_out, encoder_padding_mask) return x def buffered_future_mask(self, tensor): dim = tensor.size(0) if ( not hasattr(self, "_future_mask") or self._future_mask is None or self._future_mask.device != tensor.device ): self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 ) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 ) return self._future_mask[:dim, :dim] def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
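# A minimal standalone sketch (not part of the class above) of the two residual-block
# orderings that `maybe_layer_norm` switches between: the paper's post-norm
# (`sublayer -> dropout -> add residual -> layernorm`) and the tensor2tensor pre-norm
# (`layernorm -> sublayer -> dropout -> add residual`), selected by
# *args.decoder_normalize_before*. The names `sublayer`, `layer_norm`, and `dropout_p`
# are illustrative placeholders, not attributes of the layer.
def _residual_block_sketch(x, sublayer, layer_norm, dropout_p, normalize_before, training):
    import torch.nn.functional as F
    residual = x
    if normalize_before:      # pre-norm: maybe_layer_norm(..., before=True)
        x = layer_norm(x)
    x = sublayer(x)
    x = F.dropout(x, p=dropout_p, training=training)
    x = residual + x
    if not normalize_before:  # post-norm: maybe_layer_norm(..., after=True)
        x = layer_norm(x)
    return x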
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, no_encoder_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim if args.max_relative_length == -1: self.self_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) else: self.self_attn = RelativeMultiheadAttention( self.embed_dim, args.decoder_attention_heads, args.max_relative_length, dropout=args.attention_dropout, k_only=args.k_only, ) self.dropout = args.dropout self.relu_dropout = args.relu_dropout self.normalize_before = args.decoder_normalize_before self.self_attn_layer_norm = LayerNorm(self.embed_dim) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True self.decoder_position_dropout = args.decoder_position_dropout self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(batch, src_len, embed_dim)` """ residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x, _ = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) attn = None if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = F.relu(self.fc1(x)) x = F.dropout(x, p=self.relu_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace: saved_state = self.self_attn._get_input_buffer(incremental_state) self_attn_state = saved_state["prev_key"], saved_state[ "prev_value"] return x, attn, self_attn_state return x, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def position_dropout(self, x): if self.training and self.decoder_position_dropout != 0: position_mask = (torch.rand(x.size(0)) > self.decoder_position_dropout).view( -1, 1, 1).cuda().half() x = x * position_mask return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
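# Hedged sketch of the `position_dropout` idea used by the class above: during
# training, each target position (a full time step across the batch) is zeroed with
# probability `decoder_position_dropout`. This standalone variant reuses x's device
# and dtype instead of the hard-coded `.cuda().half()`; it is an illustration, not
# the module's own implementation.
def _position_dropout_sketch(x, p, training):
    import torch
    if not training or p == 0:
        return x
    keep = (torch.rand(x.size(0), device=x.device) > p).view(-1, 1, 1).to(x.dtype)
    return x * keep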
class LightConvDecoderLayer(nn.Module): """Decoder layer block. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs. Default: ``False`` kernel_size: kernel size of the convolution """ def __init__(self, args, no_encoder_attn=False, kernel_size=0): super().__init__() self.embed_dim = args.decoder_embed_dim self.conv_dim = args.decoder_conv_dim if args.decoder_glu: self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim) self.act = nn.GLU() else: self.linear1 = Linear(self.embed_dim, self.conv_dim) self.act = None if args.decoder_conv_type == "lightweight": self.conv = LightweightConv( self.conv_dim, kernel_size, padding_l=kernel_size - 1, weight_softmax=args.weight_softmax, num_heads=args.decoder_attention_heads, weight_dropout=args.weight_dropout, ) elif args.decoder_conv_type == "dynamic": self.conv = DynamicConv( self.conv_dim, kernel_size, padding_l=kernel_size - 1, weight_softmax=args.weight_softmax, num_heads=args.decoder_attention_heads, weight_dropout=args.weight_dropout, ) else: raise NotImplementedError self.linear2 = Linear(self.conv_dim, self.embed_dim) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__) self.relu_dropout_module = FairseqDropout( args.relu_dropout, module_name=self.__class__.__name__) self.input_dropout_module = FairseqDropout( args.input_dropout, module_name=self.__class__.__name__) self.normalize_before = args.decoder_normalize_before self.conv_layer_norm = LayerNorm(self.embed_dim) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True def forward( self, x, encoder_out, encoder_padding_mask, incremental_state, prev_conv_state=None, prev_attn_state=None, conv_mask=None, conv_padding_mask=None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(batch, src_len, embed_dim)` """ residual = x x = self.maybe_layer_norm(self.conv_layer_norm, x, before=True) if prev_conv_state is not None: if incremental_state is None: incremental_state = {} self.conv._set_input_buffer(incremental_state, prev_conv_state) x = self.input_dropout_module(x) x = self.linear1(x) if self.act is not None: x = self.act(x) x = self.conv(x, incremental_state=incremental_state) x = self.linear2(x) x = self.dropout_module(x) x = residual + x x = self.maybe_layer_norm(self.conv_layer_norm, x, after=True) attn = None if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x = self.dropout_module(x) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = F.relu(self.fc1(x)) x = self.relu_dropout_module(x) x = self.fc2(x) x = self.dropout_module(x) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) return x, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn def extra_repr(self): return ( "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}" .format( self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before, ))
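# Small illustration of why `linear1` projects to 2 * conv_dim when `decoder_glu` is
# set: nn.GLU splits its input in half along the last dimension and gates one half
# with the other, so the convolution still receives conv_dim channels. Shapes only;
# the sizes below are made-up examples.
def _glu_branch_shapes():
    import torch
    import torch.nn as nn
    embed_dim, conv_dim, seq, bsz = 8, 8, 5, 2
    linear1 = nn.Linear(embed_dim, 2 * conv_dim)
    act = nn.GLU()
    x = torch.randn(seq, bsz, embed_dim)
    y = act(linear1(x))
    assert y.shape == (seq, bsz, conv_dim)
    return y.shape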
class LocalTransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, no_encoder_attn=False, num_layer=0): super().__init__() self.embed_dim = args.decoder_embed_dim self.self_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.dropout = args.dropout self.relu_dropout = args.relu_dropout self.normalize_before = args.decoder_normalize_before self.self_attn_layer_norm = LayerNorm(self.embed_dim) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True self.onnx_trace = False self.kernel_size = args.kernel_size self.padding_idx = 1 self.use_local_decoder = args.use_local_decoder if type(self.kernel_size) == list: self.kernel_size = self.kernel_size[num_layer] def prepare_for_onnx_export_(self): self.onnx_trace = True def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        ############################# ADDED PART ####################################
        # For self-attention: pad the target sequence up to a multiple of the kernel
        # size and reshape it so that each attention call only covers kernel_size
        # positions; the padded tail is removed again after self-attention.
        if self.use_local_decoder:
            tgt_len, batch_size, embed_dim = x.size()
            size_to_add = self.kernel_size - tgt_len % self.kernel_size
            x2 = torch.zeros(
                tgt_len + size_to_add, batch_size, embed_dim,
                dtype=x.dtype, device=x.device,
            )
            x2[:tgt_len, :batch_size, :] = x
            x = x2.view(self.kernel_size, -1, embed_dim)
            if self_attn_padding_mask is None:
                # `if not self_attn_padding_mask` would raise on a multi-element
                # tensor; only build a fresh mask when none was passed in.
                self_attn_padding_mask = torch.zeros(
                    batch_size, tgt_len, dtype=torch.uint8, device=x.device
                )
            self_attn_padding_mask2 = torch.zeros(
                batch_size, tgt_len + size_to_add,
                dtype=encoder_padding_mask.dtype,
                device=encoder_padding_mask.device,
            )
            self_attn_padding_mask2.fill_(1)
            self_attn_padding_mask2[:, :tgt_len] = self_attn_padding_mask
            self_attn_padding_mask = self_attn_padding_mask2.view(-1, self.kernel_size)
        ############################# END ADDED PART ###################################
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        ############################# MODIFIED PART ####################################
        current_attn_mask = self_attn_mask
        if self.use_local_decoder:
            current_attn_mask = self.buffered_future_mask(x) if incremental_state is None else None
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=current_attn_mask,
        )
        ############################# END MODIFIED PART ####################################
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
        ############################# ADDED PART ####################################
        # Undo the local-attention reshape and drop the padded positions.
        if self.use_local_decoder:
            x2 = x.view(-1, batch_size, self.embed_dim)
            x = x2[:tgt_len, :, :]
        ############################# END ADDED PART ####################################
        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn

    # Normally defined on LocalTransformerDecoder only; duplicated here because the
    # layer rebuilds the future mask itself after the local-attention reshape.
    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, '_future_mask')
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]
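# Hedged sketch of the blockwise reshape the "ADDED PART" above performs: the target
# sequence is padded to a multiple of kernel_size, viewed with kernel_size as the
# attention length, and the padded tail is dropped afterwards. Toy tensors only; this
# is not the layer's exact code path (it additionally rebuilds the padding mask).
def _local_blocks_sketch(x, kernel_size):
    import torch
    tgt_len, bsz, dim = x.size()
    pad = (kernel_size - tgt_len % kernel_size) % kernel_size
    x2 = torch.zeros(tgt_len + pad, bsz, dim, dtype=x.dtype, device=x.device)
    x2[:tgt_len] = x
    blocks = x2.view(kernel_size, -1, dim)        # attention sees length kernel_size
    restored = blocks.view(-1, bsz, dim)[:tgt_len]  # undo the view, drop the padding
    assert torch.equal(restored, x)
    return blocks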
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block (post-norm variant with explicit hyperparameters instead of args)."""

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = 'relu',
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        export: bool = False,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)

        self.encoder_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            kdim=embedding_dim,
            vdim=embedding_dim,
            dropout=attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)

        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
        self.need_attn = False

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        residual = x
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        residual = x
        if prev_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.encoder_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.encoder_attn(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.encoder_attn_layer_norm(x)

        residual = x
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.final_layer_norm(x)
        return x, attn

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
class transformer_with_copyDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, 'activation_fn', 'relu')) self.activation_dropout = getattr(args, 'activation_dropout', 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, 'relu_dropout', 0) self.normalize_before = args.decoder_normalize_before self.self_attn_layer_norm = LayerNorm(self.embed_dim) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(batch, src_len, embed_dim)` """ residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) self_attn_state = saved_state["prev_key"], saved_state[ "prev_value"] return x, attn, self_attn_state return x, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
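# Hedged illustration of the cached-state shapes that `prev_self_attn_state` /
# `_set_input_buffer` pass around in the layers above: keys and values are kept as
# (batch, num_heads, cached_len, head_dim), so each incremental decoding step only
# appends one new position. The sizes below are made up for the example.
def _cached_state_shapes_example():
    import torch
    bsz, num_heads, cached_len, head_dim = 2, 4, 7, 16
    prev_key = torch.zeros(bsz, num_heads, cached_len, head_dim)
    prev_value = torch.zeros(bsz, num_heads, cached_len, head_dim)
    saved_state = {"prev_key": prev_key, "prev_value": prev_value}
    return saved_state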
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, self_attn_pattern=None, encoder_attn_pattern=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): super().__init__() if self_attn_pattern is not None and args.PRUNE_DEC_SELF_ATTN: cpu_np_pattern = self_attn_pattern.cpu().numpy() d1_bounds, d2_bounds = find_bounds(cpu_np_pattern) prune_random = False try: if args.RANDOM_PRUNE: #random prune prune_random = True self.self_attn_mask = torch.from_numpy( random_mask(cpu_np_pattern, args.TAU)) if args.CUDA: self.self_attn_mask = self.self_attn_mask.cuda() except: pass if not prune_random: cpu_np_pattern = cpu_np_pattern[:, :, 0:d1_bounds, 0:d2_bounds] target_percentile = args.TAU * 100 threshold = np.percentile(cpu_np_pattern, target_percentile, interpolation='nearest') self.self_attn_mask = (self_attn_pattern <= threshold) else: self.self_attn_mask = None if encoder_attn_pattern is not None and args.PRUNE_ENC_DEC_ATTN: cpu_np_pattern = encoder_attn_pattern.cpu().numpy() d1_bounds, d2_bounds = find_bounds(cpu_np_pattern) prune_random = False try: if args.RANDOM_PRUNE: #random prune prune_random = True self.encoder_attn_mask = torch.from_numpy( random_mask(cpu_np_pattern, args.TAU)) if args.CUDA: self.encoder_attn_mask = self.encoder_attn_mask.cuda() except: pass if not prune_random: cpu_np_pattern = cpu_np_pattern[:, :, 0:d1_bounds, 0:d2_bounds] target_percentile = args.TAU * 100 threshold = np.percentile(cpu_np_pattern, target_percentile, interpolation='nearest') self.encoder_attn_mask = (encoder_attn_pattern <= threshold) else: self.encoder_attn_mask = None self.embed_dim = args.decoder_embed_dim self.cross_self_attention = getattr(args, "cross_self_attention", False) self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention, args=args, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu")) self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. 
# TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, "encoder_embed_dim", None), vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, encoder_decoder_attention=True, args=args, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ if need_head_weights: need_attn = True residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) if prev_self_attn_state is not None: prev_key, prev_value = prev_self_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, saved_state) _self_attn_input_buffer = self.self_attn._get_input_buffer( incremental_state) if self.cross_self_attention and not ( incremental_state is not None and _self_attn_input_buffer is not None and "prev_key" in _self_attn_input_buffer): if self_attn_mask is not None: assert encoder_out is not None self_attn_mask = torch.cat((x.new_zeros( x.size(0), encoder_out.size(0)), self_attn_mask), dim=1) if self_attn_padding_mask is not None: if encoder_padding_mask is None: assert encoder_out is not None encoder_padding_mask = self_attn_padding_mask.new_zeros( encoder_out.size(1), encoder_out.size(0)) self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1) assert encoder_out is not None y = torch.cat((encoder_out, x), dim=0) else: y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, prune_attn_mask=self.self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.self_attn_layer_norm(x) if self.encoder_attn is not None: residual = x if self.normalize_before: x = self.encoder_attn_layer_norm(x) if prev_attn_state is not None: 
prev_key, prev_value = prev_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] assert incremental_state is not None self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn(query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, prune_attn_mask=self.encoder_attn_mask) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.encoder_attn_layer_norm(x) residual = x if self.normalize_before: x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=float(self.activation_dropout), training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.final_layer_norm(x) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) assert saved_state is not None if self_attn_padding_mask is not None: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"], ] else: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"] ] return x, attn, self_attn_state return x, attn, None def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn @torch.jit.export def reorder_incremental_state( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor, ): """Scriptable reorder incremental state in transformer layers.""" self.self_attn.reorder_incremental_state(incremental_state, new_order) if self.encoder_attn is not None: self.encoder_attn.reorder_incremental_state( incremental_state, new_order)
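# Minimal sketch of the percentile-based pruning mask built in __init__ above:
# positions whose recorded pattern score falls at or below the TAU-quantile threshold
# are flagged (the layer passes the result as `prune_attn_mask`). This sketch ignores
# the `find_bounds` cropping and the random-prune branch of the original, and uses
# numpy's `method="nearest"` (numpy >= 1.22); the original's `interpolation='nearest'`
# is the pre-1.22 spelling of the same option.
def _percentile_prune_mask_sketch(pattern, tau):
    import numpy as np
    cpu_pattern = pattern.cpu().numpy()
    threshold = np.percentile(cpu_pattern, tau * 100, method="nearest")
    return pattern <= threshold  # boolean mask with the same shape as `pattern`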
class HybridRNNDecoder(FairseqIncrementalDecoder): """ Decoder with general structure of Chen et al., The Best of Both Worlds: Combining Recent Advances in Neural Machine Translation, 2018. https://arxiv.org/abs/1804.09849 """ def _init_dims(self, args, src_dict, dst_dict, embed_tokens): self.dropout = args.dropout embed_dim = embed_tokens.embedding_dim self.embed_tokens = embed_tokens self.lstm_units = args.decoder_lstm_units self.num_layers = args.decoder_layers self.initial_input_dim = embed_dim self.encoder_output_dim = args.encoder_embed_dim if args.decoder_reduced_attention_dim is None: self.attention_dim = self.encoder_output_dim else: self.attention_dim = args.decoder_reduced_attention_dim self.input_dim = self.lstm_units + self.attention_dim self.num_attention_heads = args.decoder_attention_heads self.bottleneck_dim = args.decoder_out_embed_dim def _init_components(self, args, src_dict, dst_dict, embed_tokens): self.initial_rnn_layer = nn.LSTM(input_size=self.initial_input_dim, hidden_size=self.lstm_units) self.proj_encoder_layer = None if self.attention_dim != self.encoder_output_dim: self.proj_encoder_layer = fairseq_transformer.Linear( self.encoder_output_dim, self.attention_dim) self.proj_layer = None if self.lstm_units != self.attention_dim: self.proj_layer = fairseq_transformer.Linear( self.lstm_units, self.attention_dim) self.attention = MultiheadAttention( self.attention_dim, self.num_attention_heads, dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.extra_rnn_layers = nn.ModuleList([]) for _ in range(self.num_layers - 1): self.extra_rnn_layers.append( nn.LSTM(input_size=self.input_dim, hidden_size=self.lstm_units)) self.bottleneck_layer = None if self.bottleneck_dim is not None: self.out_embed_dim = self.bottleneck_dim self.bottleneck_layer = fairseq_transformer.Linear( self.input_dim, self.out_embed_dim) else: self.out_embed_dim = self.input_dim self.embed_out = nn.Parameter( torch.Tensor(len(dst_dict), self.out_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=self.out_embed_dim**-0.5) self.vocab_reduction_module = None if args.vocab_reduction_params: self.vocab_reduction_module = vocab_reduction.VocabReduction( src_dict, dst_dict, args.vocab_reduction_params, fp16=args.fp16) self.onnx_trace = False def __init__(self, args, src_dict, dst_dict, embed_tokens): super().__init__(dst_dict) self._init_dims(args, src_dict, dst_dict, embed_tokens) self._init_components(args, src_dict, dst_dict, embed_tokens) # Enable dependency injection by subclasses def _unpack_encoder_out(self, encoder_out): """Allow taking encoder_out from different architecture which may have different formats. """ return encoder_out def _init_hidden(self, encoder_out, batch_size): """ Initialize with latent code if available otherwise zeros.""" return torch.zeros([1, batch_size, self.lstm_units]) def _concat_latent_code(self, x, encoder_out): """Concat latent code, if available in encoder_out, which is the case in subclass. 
""" return x def prepare_for_onnx_export_(self): self.onnx_trace = True def _embed_prev_outputs(self, prev_output_tokens, incremental_state=None): if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) return x, prev_output_tokens def forward( self, prev_output_tokens, encoder_out, incremental_state=None, possible_translation_tokens=None, timestep=None, ): x, prev_output_tokens = self._embed_prev_outputs( prev_output_tokens=prev_output_tokens, incremental_state=incremental_state) return self._forward_given_embeddings( embed_out=x, prev_output_tokens=prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, possible_translation_tokens=possible_translation_tokens, timestep=timestep, ) def _forward_given_embeddings( self, embed_out, prev_output_tokens, encoder_out, incremental_state=None, possible_translation_tokens=None, timestep=None, ): x = embed_out (encoder_x, src_tokens, encoder_padding_mask) = self._unpack_encoder_out(encoder_out) bsz, seqlen = prev_output_tokens.size() state_outputs = [] if incremental_state is not None: prev_states = utils.get_incremental_state(self, incremental_state, "cached_state") if prev_states is None: prev_states = self._init_prev_states(encoder_out) # final 2 states of list are projected key and value saved_state = { "prev_key": prev_states[-2], "prev_value": prev_states[-1] } self.attention._set_input_buffer(incremental_state, saved_state) if incremental_state is not None: # first num_layers pairs of states are (prev_hidden, prev_cell) # for each layer h_prev = prev_states[0] c_prev = prev_states[1] else: h_prev = self._init_hidden(encoder_out, bsz).type_as(x) c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x) x = self._concat_latent_code(x, encoder_out) x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev)) if incremental_state is not None: state_outputs.extend([h_next, c_next]) x = F.dropout(x, p=self.dropout, training=self.training) if self.proj_encoder_layer is not None: encoder_x = self.proj_encoder_layer(encoder_x) attention_in = x if self.proj_layer is not None: attention_in = self.proj_layer(x) attention_out, attention_weights = self.attention( query=attention_in, key=encoder_x, value=encoder_x, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training), ) for i, layer in enumerate(self.extra_rnn_layers): residual = x rnn_input = torch.cat([x, attention_out], dim=2) rnn_input = self._concat_latent_code(rnn_input, encoder_out) if incremental_state is not None: # first num_layers pairs of states are (prev_hidden, prev_cell) # for each layer h_prev = prev_states[2 * i + 2] c_prev = prev_states[2 * i + 3] else: h_prev = self._init_hidden(encoder_out, bsz).type_as(x) c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x) x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev)) if incremental_state is not None: state_outputs.extend([h_next, c_next]) x = F.dropout(x, p=self.dropout, training=self.training) x = x + residual x = torch.cat([x, attention_out], dim=2) x = self._concat_latent_code(x, encoder_out) if self.bottleneck_layer is not None: x = self.bottleneck_layer(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if (self.vocab_reduction_module is not None and possible_translation_tokens is None): decoder_input_tokens = prev_output_tokens.contiguous() 
possible_translation_tokens = self.vocab_reduction_module( src_tokens, decoder_input_tokens=decoder_input_tokens) output_weights = self.embed_out if possible_translation_tokens is not None: output_weights = output_weights.index_select( dim=0, index=possible_translation_tokens) logits = F.linear(x, output_weights) if incremental_state is not None: # encoder projections can be reused at each incremental step state_outputs.extend([prev_states[-2], prev_states[-1]]) utils.set_incremental_state(self, incremental_state, "cached_state", state_outputs) return logits, attention_weights, possible_translation_tokens def max_positions(self): """Maximum output length supported by the decoder.""" return int(1e5) # an arbitrary large number def _init_prev_states(self, encoder_out): """ Initial (hidden, cell) values for LSTM layers are zero. For encoder-decoder attention, key and value are computed once from the encoder outputs and stay the same throughout decoding. """ (encoder_x, src_tokens, encoder_padding_mask) = self._unpack_encoder_out(encoder_out) batch_size = torch.onnx.operators.shape_as_tensor(encoder_x)[1] if self.proj_encoder_layer is not None: encoder_x = self.proj_encoder_layer(encoder_x) states = [] for _ in range(self.num_layers): hidden = self._init_hidden(encoder_out, batch_size).type_as(encoder_x) cell = torch.zeros([1, batch_size, self.lstm_units]).type_as(encoder_x) states.extend([hidden, cell]) # (key, value) for encoder-decoder attention computed from encoder # output and remain the same throughout decoding key = self.attention.k_proj(encoder_x) value = self.attention.v_proj(encoder_x) # (key, value) kept in shape (bsz, num_heads, seq_len, head_dim) # to avoid repeated transpose operations seq_len, batch_size_int, _ = encoder_x.shape num_heads = self.attention.num_heads head_dim = self.attention.head_dim key = (key.view(seq_len, batch_size_int * num_heads, head_dim).transpose(0, 1).view(batch_size_int, num_heads, seq_len, head_dim)) value = (value.view(seq_len, batch_size_int * num_heads, head_dim).transpose(0, 1).view( batch_size_int, num_heads, seq_len, head_dim)) states.extend([key, value]) return states def reorder_incremental_state(self, incremental_state, new_order): # parent reorders attention model super().reorder_incremental_state(incremental_state, new_order) cached_state = utils.get_incremental_state(self, incremental_state, "cached_state") if cached_state is None: return # Last 2 elements of prev_states are encoder projections # used for ONNX export for i, state in enumerate(cached_state[:-2]): cached_state[i] = state.index_select(1, new_order) utils.set_incremental_state(self, incremental_state, "cached_state", cached_state)
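# Hedged sketch of the key/value reshape done in `_init_prev_states` above: the
# projected encoder states go from (seq_len, bsz * num_heads, head_dim) to
# (bsz, num_heads, seq_len, head_dim) so they can sit directly in the attention
# module's incremental buffer and avoid repeated transposes. Toy sizes; `projected`
# stands in for the output of k_proj / v_proj.
def _static_kv_reshape_example():
    import torch
    seq_len, bsz, num_heads, head_dim = 6, 2, 4, 8
    projected = torch.randn(seq_len, bsz, num_heads * head_dim)
    key = (
        projected.view(seq_len, bsz * num_heads, head_dim)
        .transpose(0, 1)
        .view(bsz, num_heads, seq_len, head_dim)
    )
    assert key.shape == (bsz, num_heads, seq_len, head_dim)
    return key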
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, index, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim kernel_size = args.decoder_kernel_size_list[index] self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, 'activation_fn', 'relu') ) self.activation_dropout = getattr(args, 'activation_dropout', 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, 'relu_dropout', 0) self.normalize_before = args.decoder_normalize_before if args.decoder_branch_type is None: self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=True, ) else: layers = [] embed_dims = [] heads = [] num_types = len(args.decoder_branch_type) for layer_type in args.decoder_branch_type: embed_dims.append(int(layer_type.split(':')[2])) heads.append(int(layer_type.split(':')[3])) layers.append(self.get_layer(args, index, embed_dims[-1], heads[-1], layer_type, add_bias_kv, add_zero_attn)) assert sum(embed_dims) == self.embed_dim, (sum(embed_dims), self.embed_dim) self.self_attn = MultiBranch(layers, embed_dims) # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. 
# TODO remove this once we update apex with the fix export = getattr(args, 'char_inputs', False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, 'encoder_embed_dim', None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim, init=args.ffn_init) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim, init=args.ffn_init) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def get_layer(self, args, index, out_dim, num_heads, layer_type, add_bias_kv, add_zero_attn): kernel_size = layer_type.split(':')[1] if kernel_size == 'default': kernel_size = args.decoder_kernel_size_list[index] else: kernel_size = int(kernel_size) layer_type = layer_type.split(':')[0] if layer_type == 'lightweight': layer = LightweightConv( out_dim, kernel_size, padding_l=kernel_size-1, weight_softmax=args.weight_softmax, num_heads=num_heads, weight_dropout=args.weight_dropout, with_linear=args.conv_linear, ) elif layer_type == 'dynamic': layer = DynamicConv( out_dim, kernel_size, padding_l=kernel_size-1, weight_softmax=args.weight_softmax, num_heads=num_heads, weight_dropout=args.weight_dropout, with_linear=args.conv_linear, glu=args.decoder_glu, ) elif layer_type == 'attn': layer = MultiheadAttention( embed_dim=out_dim, num_heads=num_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=True, ) else: raise NotImplementedError return layer def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) self_attn_state = saved_state["prev_key"], saved_state["prev_value"] return x, attn, self_attn_state return x, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
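# Hedged sketch (not the real MultiBranch module) of the invariant asserted in
# __init__ above: each branch owns a slice of the embedding, and the concatenated
# branch outputs must add back up to decoder_embed_dim. Plain Linear layers stand in
# for the attention/convolution branches purely to show the bookkeeping.
def _multi_branch_sketch():
    import torch
    import torch.nn as nn
    embed_dim, embed_dims = 8, [2, 6]        # e.g. parsed from decoder_branch_type
    assert sum(embed_dims) == embed_dim
    branches = [nn.Linear(d, d) for d in embed_dims]
    x = torch.randn(5, 3, embed_dim)          # (seq_len, batch, embed_dim)
    pieces = torch.split(x, embed_dims, dim=-1)
    out = torch.cat([branch(piece) for branch, piece in zip(branches, pieces)], dim=-1)
    assert out.shape == x.shape
    return out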
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim self.cross_self_attention = getattr(args, "cross_self_attention", False) self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu")) self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. # TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, "encoder_embed_dim", None), vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, encoder_out2=None, balance_weight=None, encoder_padding_mask2=None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). 
Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ #print(encoder_out2) #print(balance_weight) attn2 = None if need_head_weights: need_attn = True residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) if prev_self_attn_state is not None: print("prev_self_attn_state not None") prev_key, prev_value = prev_self_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, saved_state) _self_attn_input_buffer = self.self_attn._get_input_buffer( incremental_state) if self.cross_self_attention and not ( incremental_state is not None and _self_attn_input_buffer is not None and "prev_key" in _self_attn_input_buffer): if self_attn_mask is not None: assert encoder_out is not None self_attn_mask = torch.cat((x.new_zeros( x.size(0), encoder_out.size(0)), self_attn_mask), dim=1) if self_attn_padding_mask is not None: if encoder_padding_mask is None: assert encoder_out is not None encoder_padding_mask = self_attn_padding_mask.new_zeros( encoder_out.size(1), encoder_out.size(0)) self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1) assert encoder_out is not None y = torch.cat((encoder_out, x), dim=0) #print("Here cross self") else: #print("Here not cross self") y = x #print("self_attn") #print('input x', x) x, attn, _ = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) #print("end self_attn") x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.self_attn_layer_norm(x) #print('output x', x) if self.encoder_attn is not None: #print("Not None") residual = x if self.normalize_before: x = self.encoder_attn_layer_norm(x) if prev_attn_state is not None: #print("prev_attn_state not None") prev_key, prev_value = prev_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] assert incremental_state is not None self.encoder_attn._set_input_buffer(incremental_state, saved_state) #TODO: CJA #print("------", encoder_out.shape, x.shape) #print('input', x) #print('encoder_out', encoder_out) if encoder_out2 is not None: if balance_weight is not None: #print("need_head_weights", need_head_weights) #print("Incremental_state: ",incremental_state ) if need_head_weights: x, attn, attn2 = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key2=encoder_out2, value2=encoder_out2, balance_weight=balance_weight, key_padding_mask=encoder_padding_mask, key_padding_mask2=encoder_padding_mask2, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) #print('output', x) else: #print('here') #print('incremental_state', incremental_state) x, attn, _ = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key2=encoder_out2, value2=encoder_out2, balance_weight=balance_weight, key_padding_mask=encoder_padding_mask, key_padding_mask2=encoder_padding_mask2, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), 
need_head_weights=need_head_weights, ) else: #print("Incremental_state: ",incremental_state ) #print(encoder_out.shape) #print(encoder_out2.shape) concat_out = torch.cat([encoder_out, encoder_out2], dim=0) #print(".......") #print(encoder_out.shape, encoder_out2.shape) #print(concat_out.shape) #print("1: ", encoder_padding_mask.shape) #print("2: ",encoder_padding_mask2.shape) encoder_padding = torch.cat( [encoder_padding_mask, encoder_padding_mask2], dim=1) #print(encoder_padding_mask[0], encoder_padding_mask2[0]) #print(encoder_padding.shape) x, attn, _ = self.encoder_attn( query=x, key=concat_out, value=concat_out, key_padding_mask=encoder_padding, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) else: x, attn, _ = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) #print("!!!!!!!", x.shape) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.encoder_attn_layer_norm(x) residual = x if self.normalize_before: x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=float(self.activation_dropout), training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.final_layer_norm(x) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) assert saved_state is not None if self_attn_padding_mask is not None: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"], ] else: #print("here") self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"] ] return x, attn, self_attn_state if attn2 is not None: return x, attn, attn2 else: return x, attn, None def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn
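# --------------------------------------------------------------------------
# Illustrative sketch (assumption, not the original module): the layer above
# relies on a customized attention that accepts key2/value2/balance_weight,
# which is not included in these sources. One plausible reading, mirrored here
# with the stock torch.nn.MultiheadAttention, is to attend over both encoder
# memories separately and blend the two context vectors with balance_weight.
# When balance_weight is None the layer instead concatenates the two memories
# along the time axis; only the weighted variant is sketched.
import torch
import torch.nn as nn

def balanced_cross_attention(attn1, attn2, query, mem1, mem2, balance_weight,
                             pad1=None, pad2=None):
    """query: (tgt_len, batch, dim); mem1/mem2: (src_len_i, batch, dim).
    balance_weight: scalar in [0, 1] giving the share of the second memory."""
    ctx1, w1 = attn1(query, mem1, mem1, key_padding_mask=pad1)
    ctx2, w2 = attn2(query, mem2, mem2, key_padding_mask=pad2)
    ctx = (1.0 - balance_weight) * ctx1 + balance_weight * ctx2
    return ctx, (w1, w2)

if __name__ == "__main__":
    dim, heads = 16, 4
    a1 = nn.MultiheadAttention(dim, heads)
    a2 = nn.MultiheadAttention(dim, heads)
    q = torch.randn(5, 2, dim)        # (tgt_len, batch, dim)
    m1 = torch.randn(7, 2, dim)       # first encoder memory
    m2 = torch.randn(9, 2, dim)       # second encoder memory
    ctx, _ = balanced_cross_attention(a1, a2, q, m1, m2, balance_weight=0.3)
    print(ctx.shape)                  # torch.Size([5, 2, 16])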
class MaskDecoderLayer(nn.Module): def __init__(self, args, no_encoder_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim self.self_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.dropout = args.dropout self.relu_dropout = args.relu_dropout self.normalize_before = args.decoder_normalize_before self.self_attn_layer_norm = LayerNorm(self.embed_dim) if no_encoder_attn: self.source_encoder_attn = None self.mask_encoder_attn = None self.encoder_attn_layer_norm = None self.concat_dense = None else: self.source_encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.mask_encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.concat_dense = Linear(2 * self.embed_dim, self.embed_dim, bias=True) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward(self, x, source_encoder_out, source_encoder_padding_mask, mask_encoder_out, mask_encoder_padding_mask, incremental_state, prev_self_attn_state=None, prev_source_attn_state=None, prev_mask_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None): residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x, _ = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) attn_source = None attn_mask = None if self.source_encoder_attn is not None: residual = x source_x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) mask_x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) self.set_attention_input_buffer(self.source_encoder_attn, incremental_state, prev_source_attn_state) self.set_attention_input_buffer(self.mask_encoder_attn, incremental_state, prev_mask_attn_state) source_x, attn_source = self.source_encoder_attn( query=source_x, key=source_encoder_out, value=source_encoder_out, key_padding_mask=source_encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) mask_x, attn_mask = self.mask_encoder_attn( query=mask_x, key=mask_encoder_out, value=mask_encoder_out, key_padding_mask=mask_encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x = torch.cat([source_x, mask_x], dim=-1) x = F.dropout(x, p=self.dropout, training=self.training) x = F.relu(self.concat_dense(x)) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = F.relu(self.fc1(x)) x = F.dropout(x, p=self.relu_dropout, training=self.training) x = self.fc2(x) x = 
F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace: saved_state = self.self_attn._get_input_buffer(incremental_state) self_attn_state = saved_state["prev_key"], saved_state[ "prev_value"] return x, attn_source, attn_mask, self_attn_state return x, attn_source, attn_mask def set_attention_input_buffer(self, attention_layer, incremental_state, previous_attn_state): if previous_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = previous_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} attention_layer._set_input_buffer(incremental_state, saved_state) def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
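# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original sources): MaskDecoderLayer
# above runs two parallel encoder attentions (source and mask), concatenates
# the two contexts on the feature dimension, and projects back to embed_dim
# through concat_dense with a ReLU before adding the residual. The standalone
# helper below reproduces just that fusion step with stock modules.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConcatFusion(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.concat_dense = nn.Linear(2 * dim, dim)
        self.dropout = dropout

    def forward(self, source_ctx, mask_ctx, residual):
        # (seq_len, batch, dim) each -> (seq_len, batch, 2*dim) -> (.., dim)
        fused = torch.cat([source_ctx, mask_ctx], dim=-1)
        fused = F.dropout(fused, p=self.dropout, training=self.training)
        fused = F.relu(self.concat_dense(fused))
        return residual + fused

if __name__ == "__main__":
    fusion = ConcatFusion(8)
    s = torch.randn(5, 2, 8)
    m = torch.randn(5, 2, 8)
    r = torch.randn(5, 2, 8)
    print(fusion(s, m, r).shape)      # torch.Size([5, 2, 8])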
class TransformerDecoderLayer(nn.Module): def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim self.cross_self_attention = getattr(args, 'cross_self_attention', False) if args.div or args.entmax: self.self_attn = ConstrainedMultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, args=args, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention, cur_attn_type='ds') else: self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, 'activation_fn', 'relu')) self.activation_dropout = getattr(args, 'activation_dropout', 0) if self.activation_dropout == 0: self.activation_dropout = getattr(args, 'relu_dropout', 0) self.normalize_before = args.decoder_normalize_before export = getattr(args, 'char_inputs', False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: if args.div or args.entmax: self.self_attn = ConstrainedMultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, args=args, kdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, 'encoder_embed_dim', None), dropout=args.attention_dropout, encoder_decoder_attention=True, cur_attn_type='ds', ) else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, 'encoder_embed_dim', None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, need_attn=False, need_head_weights=False, ): if need_head_weights: need_attn = True residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] self.self_attn._set_input_buffer(incremental_state, saved_state) if self.cross_self_attention and not ( incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)): if self_attn_mask is not None: self_attn_mask = torch.cat((x.new( x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1) if self_attn_padding_mask is not None: if encoder_padding_mask is None: encoder_padding_mask = self_attn_padding_mask.new( encoder_out.size(1), encoder_out.size(0)).zero_() self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1) 
y = torch.cat((encoder_out, x), dim=0) else: y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) if self_attn_padding_mask is not None: self_attn_state = saved_state["prev_key"], saved_state[ "prev_value"], saved_state["prev_key_padding_mask"] else: self_attn_state = saved_state["prev_key"], saved_state[ "prev_value"] return x, attn, self_attn_state return x, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
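# --------------------------------------------------------------------------
# Illustrative sketch (assumption): the ConstrainedMultiheadAttention selected
# when args.div or args.entmax is set is not included in these sources. As a
# stand-in, the function below implements sparsemax (the alpha = 2 member of
# the entmax family) as a drop-in replacement for softmax over attention
# scores, which is the kind of sparse normalizer the entmax option suggests.
import torch

def sparsemax(scores, dim=-1):
    """Project scores onto the probability simplex (Martins & Astudillo, 2016),
    producing attention weights with exact zeros."""
    z, _ = torch.sort(scores, dim=dim, descending=True)
    cumsum = z.cumsum(dim)
    k = torch.arange(1, scores.size(dim) + 1,
                     device=scores.device, dtype=scores.dtype)
    shape = [1] * scores.dim()
    shape[dim] = -1
    k = k.view(shape)
    # Support size k(z): largest k with 1 + k * z_(k) > sum of the top-k scores.
    support = 1 + k * z > cumsum
    k_z = support.sum(dim=dim, keepdim=True)
    tau = (cumsum.gather(dim, k_z - 1) - 1) / k_z.to(scores.dtype)
    return torch.clamp(scores - tau, min=0.0)

if __name__ == "__main__":
    scores = torch.tensor([[2.0, 1.0, 0.1]])
    print(sparsemax(scores))          # tensor([[1., 0., 0.]]) -- sparse weights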
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__( self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False ): super().__init__() self.embed_dim = args.decoder_embed_dim embed_dim = self.embed_dim self.cross_self_attention = getattr(args, "cross_self_attention", False) self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention, ) # self.dropout = [0.05, 0.1, 0.25, 0.3] # self.dropout = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.3] self.dropout = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3] self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu") ) self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. 
# TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) # self.self_attn_layer_norm = SlimmableLayernorm([int(self.embed_dim / 4), int(self.embed_dim * 2 / 4), int(self.embed_dim * 3 / 4), self.embed_dim]) self.self_attn_layer_norm = SlimmableLayernorm([int(embed_dim * 4 / 16), int(embed_dim * 5 / 16), int(embed_dim * 6 / 16), int(embed_dim * 7 / 16), int(embed_dim * 8 / 16), int(embed_dim * 9 / 16), int(embed_dim * 10 / 16), int(embed_dim * 11 / 16), int(embed_dim * 12 / 16), int(embed_dim * 13 / 16), int(embed_dim * 14 / 16), int(embed_dim * 15 / 16), embed_dim]) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, "encoder_embed_dim", None), vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = SlimmableLayernorm([int(self.embed_dim / 4), int(self.embed_dim * 2 / 4), int(self.embed_dim * 3 / 4), self.embed_dim]) # self.fc1 = SLinear([int(self.embed_dim / 4), int(self.embed_dim * 2 / 4), int(self.embed_dim * 3 / 4), self.embed_dim], # [int(args.decoder_ffn_embed_dim / 4), int(args.decoder_ffn_embed_dim * 2 / 4),int(args.decoder_ffn_embed_dim * 3 / 4), args.decoder_ffn_embed_dim]) # self.fc2 = SLinear([int(args.decoder_ffn_embed_dim / 4), int(args.decoder_ffn_embed_dim * 2 / 4), int(args.decoder_ffn_embed_dim * 3 / 4), args.decoder_ffn_embed_dim], # [int(self.embed_dim / 4), int(self.embed_dim * 2 / 4), int(self.embed_dim * 3 / 4), self.embed_dim]) # self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) # self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) # self.final_layer_norm = SlimmableLayernorm([int(self.embed_dim / 4), int(self.embed_dim * 2 / 4), int(self.embed_dim * 3 / 4), self.embed_dim]) self.encoder_attn_layer_norm = SlimmableLayernorm([int(embed_dim * 4 / 16), int(embed_dim * 5 / 16), int(embed_dim * 6 / 16), int(embed_dim * 7 / 16), int(embed_dim * 8 / 16), int(embed_dim * 9 / 16), int(embed_dim * 10 / 16), int(embed_dim * 11 / 16), int(embed_dim * 12 / 16), int(embed_dim * 13 / 16), int(embed_dim * 14 / 16), int(embed_dim * 15 / 16), embed_dim]) self.fc1 = SLinear(embed_dim, args.encoder_ffn_embed_dim) self.fc2 = SLinear(args.encoder_ffn_embed_dim, embed_dim) self.final_layer_norm = SlimmableLayernorm([int(embed_dim * 4 / 16), int(embed_dim * 5 / 16), int(embed_dim * 6 / 16), int(embed_dim * 7 / 16), int(embed_dim * 8 / 16), int(embed_dim * 9 / 16), int(embed_dim * 10 / 16), int(embed_dim * 11 / 16), int(embed_dim * 12 / 16), int(embed_dim * 13 / 16), int(embed_dim * 14 / 16), int(embed_dim * 15 / 16), embed_dim]) self.final_layer_norm = SlimmableLayernorm([int(embed_dim * 4 / 16), int(embed_dim * 5 / 16), int(embed_dim * 6 / 16), int(embed_dim * 7 / 16), int(embed_dim * 8 / 16), int(embed_dim * 9 / 16), int(embed_dim * 10 / 16), int(embed_dim * 11 / 16), int(embed_dim * 12 / 16), int(embed_dim * 13 / 16), int(embed_dim * 14 / 16), int(embed_dim * 15 / 16), embed_dim]) self.linear_list = [int(embed_dim * 4 / 16), int(embed_dim * 5 / 16), int(embed_dim * 6 / 16), int(embed_dim * 7 / 16), int(embed_dim * 8 / 16), int(embed_dim * 9 / 16), int(embed_dim * 10 / 16), int(embed_dim * 11 / 16), int(embed_dim * 12 / 16), int(embed_dim * 13 / 16), int(embed_dim * 14 / 16), int(embed_dim * 15 / 16), embed_dim] self.ffn_list = [int(args.encoder_ffn_embed_dim * 4 / 16), int(args.encoder_ffn_embed_dim * 
5 / 16), int(args.encoder_ffn_embed_dim * 6 / 16), int(args.encoder_ffn_embed_dim * 7 / 16), int(args.encoder_ffn_embed_dim * 8 / 16), int(args.encoder_ffn_embed_dim * 9 / 16), int(args.encoder_ffn_embed_dim * 10 / 16), int(args.encoder_ffn_embed_dim * 11 / 16), int(args.encoder_ffn_embed_dim * 12 / 16), int(args.encoder_ffn_embed_dim * 13 / 16), int(args.encoder_ffn_embed_dim * 14 / 16), int(args.encoder_ffn_embed_dim * 15 / 16), args.encoder_ffn_embed_dim] self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, idx, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ # pdb.set_trace() if need_head_weights: need_attn = True residual = x if self.normalize_before: x = self.self_attn_layer_norm(x, index[0]) if prev_self_attn_state is not None: prev_key, prev_value = prev_self_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, saved_state) _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) if self.cross_self_attention and not ( incremental_state is not None and _self_attn_input_buffer is not None and "prev_key" in _self_attn_input_buffer ): if self_attn_mask is not None: assert encoder_out is not None self_attn_mask = torch.cat( (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 ) if self_attn_padding_mask is not None: if encoder_padding_mask is None: assert encoder_out is not None encoder_padding_mask = self_attn_padding_mask.new_zeros( encoder_out.size(1), encoder_out.size(0) ) self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1 ) assert encoder_out is not None y = torch.cat((encoder_out, x), dim=0) else: y = x x, attn = self.self_attn( index=idx, query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout[idx[0]], training=self.training) x = residual + x if not self.normalize_before: x = self.self_attn_layer_norm(x, idx[0]) if self.encoder_attn is not None: residual = x if self.normalize_before: x = self.encoder_attn_layer_norm(x, idx[0]) if prev_attn_state is not None: prev_key, prev_value = prev_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] assert incremental_state is not None 
self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( index=idx, query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = F.dropout(x, p=self.dropout[idx[0]], training=self.training) x = residual + x if not self.normalize_before: x = self.encoder_attn_layer_norm(x, idx[0]) residual = x if self.normalize_before: x = self.final_layer_norm(x, idx[0]) x = self.activation_fn(self.fc1(x, self.linear_list[idx[0]], self.ffn_list[idx[1]])) x = F.dropout(x, p=self.dropout[idx[1]], training=self.training) x = self.fc2(x, self.ffn_list[idx[1]], self.linear_list[idx[0]]) x = F.dropout(x, p=self.dropout[idx[0]], training=self.training) x = residual + x if not self.normalize_before: x = self.final_layer_norm(x, idx[0]) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) assert saved_state is not None if self_attn_padding_mask is not None: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"], ] else: self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] return x, attn, self_attn_state return x, attn, None def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn
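# --------------------------------------------------------------------------
# Illustrative sketch (assumption): the SLinear / SlimmableLayernorm modules
# used above are not included in these sources. A common slimmable-network
# formulation, sketched here, keeps one full-width weight and slices it to the
# currently active width at forward time; the width lists built in __init__
# (embed_dim * 4/16 ... embed_dim) would supply those active widths.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SlimLinear(nn.Module):
    def __init__(self, max_in, max_out):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(max_out, max_in))
        self.bias = nn.Parameter(torch.zeros(max_out))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, in_dim, out_dim):
        # Use only the top-left (out_dim x in_dim) block of the full weight.
        return F.linear(x, self.weight[:out_dim, :in_dim], self.bias[:out_dim])

if __name__ == "__main__":
    embed_dim, ffn_dim = 16, 64
    widths = [int(embed_dim * i / 16) for i in range(4, 17)]   # 4/16 ... 16/16
    fc1 = SlimLinear(embed_dim, ffn_dim)
    x = torch.randn(5, 2, widths[0])                            # narrowest width
    y = fc1(x, in_dim=widths[0], out_dim=ffn_dim // 2)
    print(y.shape)                                              # (5, 2, 32)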
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__( self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False ): super().__init__() self.embed_dim = args.decoder_embed_dim self.cross_self_attention = getattr(args, "cross_self_attention", False) self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu") ) self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. # TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, "encoder_embed_dim", None), vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). 
Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ if need_head_weights: need_attn = True residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) if prev_self_attn_state is not None: prev_key, prev_value = prev_self_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, saved_state) _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) if self.cross_self_attention and not ( incremental_state is not None and _self_attn_input_buffer is not None and "prev_key" in _self_attn_input_buffer ): if self_attn_mask is not None: assert encoder_out is not None self_attn_mask = torch.cat( (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 ) if self_attn_padding_mask is not None: if encoder_padding_mask is None: assert encoder_out is not None encoder_padding_mask = self_attn_padding_mask.new_zeros( encoder_out.size(1), encoder_out.size(0) ) self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1 ) assert encoder_out is not None y = torch.cat((encoder_out, x), dim=0) else: y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.self_attn_layer_norm(x) if self.encoder_attn is not None: residual = x if self.normalize_before: x = self.encoder_attn_layer_norm(x) if prev_attn_state is not None: prev_key, prev_value = prev_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] assert incremental_state is not None self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.encoder_attn_layer_norm(x) residual = x if self.normalize_before: x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=float(self.activation_dropout), training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.final_layer_norm(x) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) assert saved_state is not None if self_attn_padding_mask is not None: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"], ] else: self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] return x, attn, self_attn_state return x, attn, None def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn def compute_macs_params(self, T=1, S=1): macs = 0 n_params = 0 macs_attn = 0 # LayerNorm n_params += sum([p.numel() for p in self.self_attn_layer_norm.parameters()]) n_params 
+= sum([p.numel() for p in self.final_layer_norm.parameters()]) # self attention self_attn_layer = self.self_attn.compute_macs_params(T=T, S=T) macs += self_attn_layer['macs'] n_params += self_attn_layer['params'] macs_attn += self_attn_layer['macs_attn'] # Encoder-decoder attn if self.encoder_attn is not None: # encoder-decoder scaled-dot-product attention enc_attn = self.encoder_attn.compute_macs_params(T=T, S=S) macs += enc_attn['macs'] n_params += enc_attn['params'] macs_attn += enc_attn['macs_attn'] # FFN fc1_params = sum([p.numel() for p in self.fc1.parameters()]) macs += (fc1_params * T) n_params += fc1_params fc2_params = sum([p.numel() for p in self.fc2.parameters()]) macs += (fc2_params * T) n_params += fc2_params return { 'name': self.__class__.__name__, 'macs': macs, 'params': n_params, 'macs_attn': macs_attn }
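# --------------------------------------------------------------------------
# Illustrative sketch (assumption): compute_macs_params above delegates to a
# MultiheadAttention.compute_macs_params that is not shown here. The helper
# below only mirrors the counting convention visible in this layer, where a
# Linear applied at every decoder position contributes `params * T` MACs.
import torch.nn as nn

def linear_macs_params(linear: nn.Linear, T: int):
    """Rough MAC/parameter count for a Linear evaluated at T time steps."""
    n_params = sum(p.numel() for p in linear.parameters())
    return {"params": n_params, "macs": n_params * T}

if __name__ == "__main__":
    fc1 = nn.Linear(512, 2048)
    print(linear_macs_params(fc1, T=30))   # {'params': 1050624, 'macs': 31518720}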
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, no_encoder_attn=False, copyNet=False): super().__init__() self.embed_dim = args.decoder_embed_dim self.self_attn = MultiheadAttention( self.embed_dim, args.encoder_attention_heads, dropout=args.attention_dropout, ) self.dropout = args.dropout self.relu_dropout = args.relu_dropout self.normalize_before = args.decoder_normalize_before self.self_attn_layer_norm = LayerNorm(self.embed_dim) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True self.onnx_trace = False # self.copyNet = copyNet if self.copyNet: self.target_encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.incorpor_weights_W = Linear(args.decoder_embed_dim, 1) self.incorpor_weights_U = Linear(args.decoder_embed_dim, 1) else: self.target_encoder_attn = None self.incorpor_weights_W = None self.incorpor_weights_U = None def prepare_for_onnx_export_(self): self.onnx_trace = True def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, TM=None, TM_padding=None): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(batch, src_len, embed_dim)` """ residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x, _ = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) source_attn = None retrieve_attn = None p_copy = None if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, source_attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) if self.target_encoder_attn is not None: assert TM.size(0) > 1 and TM.size(1) == x.size(1) and TM.size(2) == x.size(2), "TM: {}, x: {}".format(TM.size(), x.size()) assert TM.size(0) ==TM_padding.size(1) and TM.size(1) == TM_padding.size(0), "TM: {}, TM_padding: {}".format(TM.size(), TM_padding.size()) target_x, retrieve_attn = self.target_encoder_attn( query=x, key=TM, value=TM, key_padding_mask=TM_padding, incremental_state=incremental_state, static_kv=True, need_weights=True, ) p_copy = torch.sigmoid(self.incorpor_weights_W(x) + self.incorpor_weights_U(target_x)) p_copy = p_copy.transpose(0, 1) # T x B x 1 -> B x T x 1 x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = F.relu(self.fc1(x)) x = F.dropout(x, p=self.relu_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace: saved_state = self.self_attn._get_input_buffer(incremental_state) self_attn_state = saved_state["prev_key"], saved_state["prev_value"] return x, source_attn, self_attn_state return x, source_attn, retrieve_attn, p_copy def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
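# --------------------------------------------------------------------------
# Illustrative sketch (assumption): the copyNet branch above produces a scalar
# gate p_copy = sigmoid(W x + U target_x) per position. How p_copy is consumed
# downstream is not shown in these sources; in a standard pointer/copy setup
# it mixes the vocabulary distribution with a copy distribution, as sketched
# below with hypothetical helper names.
import torch
import torch.nn as nn

class CopyGate(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.w = nn.Linear(dim, 1)   # plays the role of incorpor_weights_W
        self.u = nn.Linear(dim, 1)   # plays the role of incorpor_weights_U

    def forward(self, decoder_state, retrieved_state):
        # Both inputs: (tgt_len, batch, dim) -> gate of shape (tgt_len, batch, 1)
        return torch.sigmoid(self.w(decoder_state) + self.u(retrieved_state))

def mix_distributions(p_vocab, p_copy_dist, p_copy):
    # p_vocab / p_copy_dist: (batch, tgt_len, vocab); p_copy: (batch, tgt_len, 1)
    return (1.0 - p_copy) * p_vocab + p_copy * p_copy_dist

if __name__ == "__main__":
    gate = CopyGate(8)
    x = torch.randn(5, 2, 8)
    tm_ctx = torch.randn(5, 2, 8)
    print(gate(x, tm_ctx).shape)     # torch.Size([5, 2, 1])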
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim self.cross_self_attention = getattr(args, 'cross_self_attention', False) self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention, ) self.dropout = args.dropout bi_context_attn = getattr(args, 'input_form', None) self.bi_context_attn = (bi_context_attn == 'sep') self.share_key_proj = getattr(args, 'sep_attn_share_key_proj', False) self.activation_fn = utils.get_activation_fn( activation=getattr(args, 'activation_fn', 'relu') ) self.activation_dropout = getattr(args, 'activation_dropout', 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, 'relu_dropout', 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. 
# TODO remove this once we update apex with the fix export = getattr(args, 'char_inputs', False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: if not self.bi_context_attn: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, 'encoder_embed_dim', None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, 'encoder_embed_dim', None), dropout=args.attention_dropout, encoder_decoder_attention=True, qkv_same_dim=not self.share_key_proj, ) # share key proj is query actually self.aug_encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, 'encoder_embed_dim', None), dropout=args.attention_dropout, encoder_decoder_attention=True, shared_q_proj_weight=self.encoder_attn.q_proj_weight if self.share_key_proj else None, qkv_same_dim=not self.share_key_proj, ) self.context_value_weight = getattr(args, 'ctx_value_weight', 0.5) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, need_attn=False, need_head_weights=False, bi_context=None, bi_context_padding_mask=None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). 
Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ if need_head_weights: need_attn = True residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] self.self_attn._set_input_buffer(incremental_state, saved_state) if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)): if self_attn_mask is not None: self_attn_mask = torch.cat((x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1) if self_attn_padding_mask is not None: if encoder_padding_mask is None: encoder_padding_mask = self_attn_padding_mask.new(encoder_out.size(1), encoder_out.size(0)).zero_() self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1) y = torch.cat((encoder_out, x), dim=0) else: y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) if self.bi_context_attn: bx, battn = self.aug_encoder_attn( query=x, key=bi_context, value=bi_context, key_padding_mask=bi_context_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = (1. 
- self.context_value_weight) * x + self.context_value_weight * bx x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) if self_attn_padding_mask is not None: self_attn_state = saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"] else: self_attn_state = saved_state["prev_key"], saved_state["prev_value"] return x, attn, self_attn_state return x, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
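# --------------------------------------------------------------------------
# Illustrative sketch (assumption): with sep_attn_share_key_proj the layer
# above passes its primary attention's q_proj_weight into aug_encoder_attn as
# shared_q_proj_weight. That fairseq-side plumbing is not shown here; the
# minimal equivalent with plain tensors is two query projections backed by the
# same Parameter, so both attention branches stay tied during training.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedQueryBranches(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q_proj_weight = nn.Parameter(torch.empty(dim, dim))
        nn.init.xavier_uniform_(self.q_proj_weight)

    def forward(self, x):
        q_main = F.linear(x, self.q_proj_weight)   # primary cross-attention query
        q_aug = F.linear(x, self.q_proj_weight)    # augmented branch reuses the weight
        return q_main, q_aug

if __name__ == "__main__":
    m = SharedQueryBranches(8)
    q1, q2 = m(torch.randn(4, 2, 8))
    print(torch.allclose(q1, q2))   # True: both branches use the same weight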
class TarcTransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__( self, args, no_encoder_attn=False, num_cross_attentions=0, add_bias_kv=False, add_zero_attn=False ): super().__init__() self.num_cross_attentions = num_cross_attentions self.embed_dim = args.decoder_embed_dim self.cross_self_attention = getattr(args, "cross_self_attention", False) self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu") ) self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. # TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, "encoder_embed_dim", None), vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) # This is my main modification: cross-attentions to attend the other decoder outputs self.cross_attentions = nn.ModuleList() self.cross_attentions_norm = nn.ModuleList() for i in range( num_cross_attentions ): self.cross_attentions.append( MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=self.embed_dim, vdim=self.embed_dim, dropout=args.attention_dropout, encoder_decoder_attention=True, ) ) self.cross_attentions_norm.append( LayerNorm(self.embed_dim, export=export) ) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out: Optional[List[torch.Tensor]] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, prev_cross_attn_state: Optional[List[List[torch.Tensor]]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: 
Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ if need_head_weights: need_attn = True assert len(self.cross_attentions)+1 == len(encoder_out) residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) if prev_self_attn_state is not None: prev_key, prev_value = prev_self_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, saved_state) _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) if self.cross_self_attention and not ( incremental_state is not None and _self_attn_input_buffer is not None and "prev_key" in _self_attn_input_buffer ): if self_attn_mask is not None: assert encoder_out[0] is not None self_attn_mask = torch.cat( (x.new_zeros(x.size(0), encoder_out[0].size(0)), self_attn_mask), dim=1 ) if self_attn_padding_mask is not None: if encoder_padding_mask is None: assert encoder_out[0] is not None encoder_padding_mask = self_attn_padding_mask.new_zeros( encoder_out[0].size(1), encoder_out[0].size(0) ) self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1 ) assert encoder_out[0] is not None y = torch.cat((encoder_out[0], x), dim=0) else: y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.self_attn_layer_norm(x) cross_attn_x = x if self.encoder_attn is not None: residual = x if self.normalize_before: x = self.encoder_attn_layer_norm(x) if prev_attn_state is not None: prev_key, prev_value = prev_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] assert incremental_state is not None self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out[0], value=encoder_out[0], key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.encoder_attn_layer_norm(x) if self.num_cross_attentions > 0: residual = cross_attn_x all_att_output = torch.zeros_like(cross_attn_x) if self.normalize_before: cross_attn_x = self.cross_attentions_norm[0](cross_attn_x) for i in range( len(self.cross_attentions) ): if prev_cross_attn_state is not None: prev_key, prev_value = prev_cross_attn_state[i][:2] cross_saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if 
len(prev_cross_attn_state[i]) >= 3: cross_saved_state["prev_key_padding_mask"] = prev_cross_attn_state[i][2] assert incremental_state is not None self.cross_attentions[i]._set_input_buffer(incremental_state, cross_saved_state) att_output, attn = self.cross_attentions[i]( query=cross_attn_x, key=encoder_out[i+1], value=encoder_out[i+1], key_padding_mask=None, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) att_output = F.dropout(att_output, p=self.dropout, training=self.training) all_att_output = att_output + all_att_output if self.encoder_attn is not None: x = x + all_att_output # encoder_attn and cross_attentions use the same residual, so no need to add it twice else: x = residual + x + all_att_output if not self.normalize_before: x = self.cross_attentions_norm[0](x) residual = x if self.normalize_before: x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=float(self.activation_dropout), training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before: x = self.final_layer_norm(x) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) assert saved_state is not None if self_attn_padding_mask is not None: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"], ] else: self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] return x, attn, self_attn_state return x, attn, None def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn @torch.jit.export def reorder_incremental_state( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor, ): """Scriptable reorder incremental state in transformer layers.""" self.self_attn.reorder_incremental_state(incremental_state, new_order) if self.encoder_attn is not None: self.encoder_attn.reorder_incremental_state(incremental_state, new_order) if self.num_cross_attentions > 0: [attn.reorder_incremental_state(incremental_state, new_order) for attn in self.cross_attentions]
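# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original sources): the cross_attentions
# ModuleList above attends over the outputs of other decoders (encoder_out[1:])
# and sums the dropped-out contexts into a single update. The standalone
# helper below reproduces that accumulation with stock nn.MultiheadAttention.
import torch
import torch.nn as nn
import torch.nn.functional as F

def sum_cross_attention_contexts(attentions, query, memories, dropout, training):
    """attentions: iterable of nn.MultiheadAttention; memories: list of
    (src_len_i, batch, dim) tensors; query: (tgt_len, batch, dim)."""
    total = torch.zeros_like(query)
    for attn, mem in zip(attentions, memories):
        ctx, _ = attn(query, mem, mem)
        total = total + F.dropout(ctx, p=dropout, training=training)
    return total

if __name__ == "__main__":
    dim, heads = 16, 4
    attns = nn.ModuleList([nn.MultiheadAttention(dim, heads) for _ in range(2)])
    q = torch.randn(5, 2, dim)
    mems = [torch.randn(7, 2, dim), torch.randn(3, 2, dim)]
    out = sum_cross_attention_contexts(attns, q, mems, dropout=0.1, training=False)
    print(out.shape)                  # torch.Size([5, 2, 16])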
class AttentionDecoderLayer(nn.Module): def __init__( self, embed_dim, attention_heads, self_attention=True, add_bias_kv=False, add_zero_attn=False, attention_dropout=0.1, dropout=0.3, normalize_before=False, ): super().__init__() self.embed_dim = embed_dim self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=attention_heads, dropout=attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=True) if self_attention else None self.dropout = dropout self.normalize_before = normalize_before self.encoder_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=attention_heads, kdim=None, vdim=None, dropout=attention_dropout, encoder_decoder_attention=True) if not self_attention else None self.attn_layer_norm = LayerNorm(self.embed_dim, export=False) def forward(self, x, encoder_out=None, incremental_state=None, encoder_padding_mask=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, need_attn=False, need_head_weights=False, **unused): # encoder_out = None # encoder_padding_mask = None # if encoder_outs is not None: # if 'encoder_out' in encoder_outs.keys(): # encoder_out = encoder_outs['encoder_out'] # if 'encoder_padding_mask' in encoder_outs.keys(): # encoder_padding_mask = encoder_outs['encoder_padding_mask'] residual = x x = self.maybe_layer_norm(x, before=True) if self.self_attn is not None: if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_self_attn_state) >= 3: saved_state[ "prev_key_padding_mask"] = prev_self_attn_state[2] self.self_attn._set_input_buffer(incremental_state, saved_state) x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask) if self.encoder_attn is not None: if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, _ = self.encoder_attn(query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=False, need_head_weights=False) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(x, after=True) return x, None def maybe_layer_norm(self, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return self.attn_layer_norm(x) else: return x # import torch # device = torch.device("cuda:0") # x = torch.rand(4, 2, 8).to(device) # encoder = AttentionDecoderLayer(8, 4).to(device) # mask = torch.zeros(2, 4).bool().to(device) # y = encoder(x, incremental_state=None) # print(y.size())
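# --------------------------------------------------------------------------
# Illustrative usage sketch: a cleaned-up version of the commented-out snippet
# at the end of AttentionDecoderLayer above. It assumes the class and its
# fairseq dependencies are importable in the current module; note that forward
# returns an (output, attn) tuple, so the result must be unpacked before
# calling .size().
import torch

def _demo_attention_decoder_layer():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    layer = AttentionDecoderLayer(embed_dim=8, attention_heads=4).to(device)
    x = torch.rand(4, 2, 8, device=device)       # (seq_len, batch, embed_dim)
    out, _ = layer(x, incremental_state=None)    # forward returns (x, None)
    print(out.size())                            # torch.Size([4, 2, 8])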
class TransformerSentenceEmbeddingDecoderLayer(TransformerDecoderLayer): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, add_bias_kv=False, add_zero_attn=False, do_trans=True): super().__init__(args) self.embed_dim = args.decoder_embed_dim self.args = args self.do_trans = do_trans self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=True) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, 'activation_fn', 'relu')) self.activation_dropout = getattr(args, 'activation_dropout', 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, 'relu_dropout', 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. # TODO remove this once we update apex with the fix export = getattr(args, 'char_inputs', False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False if do_trans: self.decoder_fc1 = Linear(self.embed_dim + self.args.latent_size, self.embed_dim) else: self.decoder_fc1 = Linear( self.embed_dim + self.args.latent_size * 2, self.embed_dim) def forward( self, x, sent_emb=None, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) size = (x.size()[0], x.size()[1], sent_emb.size()[-1]) concat_sent_emb = torch.cat((x, sent_emb.expand(size)), dim=2) x = self.decoder_fc1(concat_sent_emb) x = F.relu(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) self_attn_state = saved_state["prev_key"], saved_state["prev_value"] return x, attn, self_attn_state return x, attn
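# Minimal sketch of the sentence-embedding fusion performed above by
# `decoder_fc1` (the do_trans=True branch): a per-sentence latent vector is
# broadcast along the time axis, concatenated with the decoder states, and
# projected back to embed_dim. All sizes here are illustrative stand-ins for
# args.decoder_embed_dim / args.latent_size.
import torch
import torch.nn as nn
import torch.nn.functional as F


def _demo_sentence_embedding_fusion(seq_len=5, bsz=2, embed_dim=8, latent_size=4):
    x = torch.randn(seq_len, bsz, embed_dim)     # decoder states, T x B x C
    sent_emb = torch.randn(1, bsz, latent_size)  # one latent vector per sentence
    fuse = nn.Linear(embed_dim + latent_size, embed_dim)
    expanded = sent_emb.expand(seq_len, bsz, latent_size)
    fused = F.relu(fuse(torch.cat((x, expanded), dim=2)))
    return fused  # same T x B x C shape as x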
class HybridRNNDecoder(FairseqIncrementalDecoder): """ Transformer decoder consisting of *args.decoder_layers* layers. Each layer is a :class:`TransformerDecoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): super().__init__(dictionary) self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout self.share_input_output_embed = args.share_decoder_input_output_embed embed_dim = embed_tokens.embedding_dim self.embed_tokens = embed_tokens self.lstm_units = args.decoder_lstm_units self.num_layers = args.decoder_layers self.initial_input_dim = embed_dim self.encoder_output_dim = args.encoder_embed_dim if args.decoder_reduced_attention_dim is None: self.attention_dim = self.encoder_output_dim else: self.attention_dim = args.decoder_reduced_attention_dim self.input_dim = self.lstm_units + self.attention_dim self.num_attention_heads = args.decoder_attention_heads self.bottleneck_dim = args.decoder_out_embed_dim self.initial_rnn_layer = nn.LSTM( input_size=self.initial_input_dim, hidden_size=self.lstm_units ) self.initial_layernorm = LayerNorm(self.lstm_units) self.proj_encoder_layer = None if self.attention_dim != self.encoder_output_dim: self.proj_encoder_layer = Linear( self.encoder_output_dim, self.attention_dim ) self.proj_layer = None if self.lstm_units != self.attention_dim: self.proj_layer = Linear( self.lstm_units, self.attention_dim ) self.attention = MultiheadAttention( self.attention_dim, self.num_attention_heads, dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.extra_rnn_layers = nn.ModuleList([]) self.extra_layernorms = nn.ModuleList([]) for _ in range(self.num_layers - 1): self.extra_rnn_layers.append( nn.LSTM(input_size=self.input_dim, hidden_size=self.lstm_units) ) self.extra_layernorms.append( LayerNorm(self.lstm_units) ) self.bottleneck_layer = None if self.bottleneck_dim is not None: self.out_embed_dim = self.bottleneck_dim self.bottleneck_layer = Linear( self.input_dim, self.out_embed_dim ) else: self.out_embed_dim = self.input_dim if not self.share_input_output_embed: self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.out_embed_dim)) nn.init.normal_(self.embed_out, mean=0, std=self.out_embed_dim ** -0.5) else: assert self.bottleneck_dim == args.decoder_embed_dim, (self.bottleneck_dim, args.decoder_embed_dim) def _unpack_encoder_out(self, encoder_out): """ Allow taking encoder_out from different architecture which may have different formats. """ # return encoder_out['encoder_out'], encoder_out['encoder_padding_mask'] return encoder_out.encoder_out, encoder_out.encoder_padding_mask def _init_hidden(self, encoder_out, batch_size): """ Initialize with latent code if available otherwise zeros.""" return torch.zeros([1, batch_size, self.lstm_units]) def _concat_latent_code(self, x, encoder_out): """ Concat latent code, if available in encoder_out, which is the case in subclass. 
""" return x def _embed_prev_outputs(self, prev_output_tokens, incremental_state=None): if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C x = x.transpose(0, 1) return x, prev_output_tokens def forward( self, prev_output_tokens, encoder_out, incremental_state=None, possible_translation_tokens=None, timestep=None, ): x, prev_output_tokens = self._embed_prev_outputs( prev_output_tokens=prev_output_tokens, incremental_state=incremental_state ) return self._forward_given_embeddings( embed_out=x, prev_output_tokens=prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, possible_translation_tokens=possible_translation_tokens, timestep=timestep, ) def _forward_given_embeddings( self, embed_out, prev_output_tokens, encoder_out, incremental_state=None, possible_translation_tokens=None, timestep=None, ): x = embed_out (encoder_x, encoder_padding_mask) = self._unpack_encoder_out(encoder_out) bsz, seqlen = prev_output_tokens.size() state_outputs = [] if incremental_state is not None: prev_states = utils.get_incremental_state( self, incremental_state, "cached_state" ) if prev_states is None: prev_states = self._init_prev_states(encoder_out) # final 2 states of list are projected key and value saved_state = {"prev_key": prev_states[-2], "prev_value": prev_states[-1]} self.attention._set_input_buffer(incremental_state, saved_state) if incremental_state is not None: # first num_layers pairs of states are (prev_hidden, prev_cell) # for each layer h_prev = prev_states[0] c_prev = prev_states[1] else: h_prev = self._init_hidden(encoder_out, bsz).type_as(x) c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x) x = self._concat_latent_code(x, encoder_out) x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev)) x = self.initial_layernorm(x) if incremental_state is not None: state_outputs.extend([h_next, c_next]) x = F.dropout(x, p=self.dropout, training=self.training) if self.proj_encoder_layer is not None: encoder_x = self.proj_encoder_layer(encoder_x) attention_in = x if self.proj_layer is not None: attention_in = self.proj_layer(x) attention_out, attention_weights = self.attention( query=attention_in, key=encoder_x, value=encoder_x, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training), ) for i, layer in enumerate(self.extra_rnn_layers): residual = x rnn_input = torch.cat([x, attention_out], dim=2) rnn_input = self._concat_latent_code(rnn_input, encoder_out) if incremental_state is not None: # first num_layers pairs of states are (prev_hidden, prev_cell) # for each layer h_prev = prev_states[2 * i + 2] c_prev = prev_states[2 * i + 3] else: h_prev = self._init_hidden(encoder_out, bsz).type_as(x) c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x) x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev)) if incremental_state is not None: state_outputs.extend([h_next, c_next]) x = F.dropout(x, p=self.dropout, training=self.training) x = x + residual x = self.extra_layernorms[i](x) x = torch.cat([x, attention_out], dim=2) x = self._concat_latent_code(x, encoder_out) if self.bottleneck_layer is not None: x = self.bottleneck_layer(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.share_input_output_embed: logits = F.linear(x, self.embed_tokens.weight) else: logits = F.linear(x, self.embed_out) if incremental_state is not None: # encoder 
projections can be reused at each incremental step state_outputs.extend([prev_states[-2], prev_states[-1]]) utils.set_incremental_state( self, incremental_state, "cached_state", state_outputs ) return logits, attention_weights def max_positions(self): """Maximum output length supported by the decoder.""" return int(1024) # an arbitrary large number def _init_prev_states(self, encoder_out): """ Initial (hidden, cell) values for LSTM layers are zero. For encoder-decoder attention, key and value are computed once from the encoder outputs and stay the same throughout decoding. """ (encoder_x, encoder_padding_mask) = self._unpack_encoder_out(encoder_out) batch_size = torch.onnx.operators.shape_as_tensor(encoder_x)[1] if self.proj_encoder_layer is not None: encoder_x = self.proj_encoder_layer(encoder_x) states = [] for _ in range(self.num_layers): hidden = self._init_hidden(encoder_out, batch_size).type_as(encoder_x) cell = torch.zeros([1, batch_size, self.lstm_units]).type_as(encoder_x) states.extend([hidden, cell]) # (key, value) for encoder-decoder attention computed from encoder # output and remain the same throughout decoding key = self.attention.k_proj(encoder_x) value = self.attention.v_proj(encoder_x) # (key, value) kept in shape (bsz, num_heads, seq_len, head_dim) # to avoid repeated transpose operations seq_len, batch_size_int, _ = encoder_x.shape num_heads = self.attention.num_heads head_dim = self.attention.head_dim key = ( key.view(seq_len, batch_size_int * num_heads, head_dim) .transpose(0, 1) .view(batch_size_int, num_heads, seq_len, head_dim) ) value = ( value.view(seq_len, batch_size_int * num_heads, head_dim) .transpose(0, 1) .view(batch_size_int, num_heads, seq_len, head_dim) ) states.extend([key, value]) return states def reorder_incremental_state(self, incremental_state, new_order): # parent reorders attention model super().reorder_incremental_state(incremental_state, new_order) cached_state = utils.get_incremental_state( self, incremental_state, "cached_state" ) if cached_state is None: return # Last 2 elements of prev_states are encoder projections # used for ONNX export for i, state in enumerate(cached_state[:-2]): cached_state[i] = state.index_select(1, new_order) for i in [-2, -1]: cached_state[i] = cached_state[i].index_select(0, new_order) utils.set_incremental_state( self, incremental_state, "cached_state", cached_state )
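# Self-contained sketch of the (key, value) cache layout built in
# `_init_prev_states` above: an encoder projection of shape
# (src_len, batch, embed_dim) is rearranged into
# (batch, num_heads, src_len, head_dim) so it can be computed once and reused
# at every incremental decoding step. Dimensions are example values only.
import torch


def _demo_attention_cache_layout(src_len=7, bsz=3, num_heads=4, head_dim=5):
    key = torch.randn(src_len, bsz, num_heads * head_dim)
    cached = (
        key.view(src_len, bsz * num_heads, head_dim)
        .transpose(0, 1)
        .view(bsz, num_heads, src_len, head_dim)
    )
    # element (b, h, t, :) is head h of batch element b at source position t
    assert cached.shape == (bsz, num_heads, src_len, head_dim)
    return cached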
class TransformerAANDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs. Default: ``False`` """ def __init__(self, args, no_encoder_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim self.dropout = args.dropout self.relu_dropout = args.relu_dropout self.more_dropouts = args.decoder_aan_more_dropouts if args.decoder_attn_window_size <= 0: self.avg_attn = AverageAttention(self.embed_dim, dropout=args.attention_dropout) else: self.avg_attn = AverageWindowAttention( self.embed_dim, dropout=args.attention_dropout, window_size=args.decoder_attn_window_size, ) # self.activation = getattr(args, "decoder_ffn_activation", "relu") self.aan_layer_norm = LayerNorm(self.embed_dim) if args.no_decoder_aan_ffn: self.aan_ffn = None else: aan_ffn_hidden_dim = (args.decoder_ffn_embed_dim if args.decoder_aan_ffn_use_embed_dim else args.decoder_ffn_embed_dim) self.aan_ffn = FeedForwardNetwork( self.embed_dim, aan_ffn_hidden_dim, self.embed_dim, num_layers=2, dropout=args.relu_dropout, ) if args.no_decoder_aan_gating: self.aan_gating_fc = None else: self.aan_gating_fc = Linear(self.embed_dim * 2, self.embed_dim * 2) self.normalize_before = args.decoder_normalize_before if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=args.encoder_embed_dim, vdim=args.encoder_embed_dim, dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.ffn = FeedForwardNetwork( self.embed_dim, args.decoder_ffn_embed_dim, self.embed_dim, num_layers=2, dropout=args.relu_dropout, ) self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out, encoder_padding_mask, incremental_state, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(batch, src_len, embed_dim)` """ residual = x if "residual" in self.more_dropouts: residual = F.dropout(residual, p=self.dropout, training=self.training) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.avg_attn._set_input_buffer(incremental_state, saved_state) x, _ = self.avg_attn( value=x, mask_future_timesteps=True, incremental_state=incremental_state, mask_trick=self.training, ) if "after_avg" in self.more_dropouts: x = F.dropout(x, p=self.dropout, training=self.training) if self.aan_layer_norm is not None: x = self.maybe_layer_norm(self.aan_layer_norm, x, before=True) if self.aan_ffn is not None: x = self.aan_ffn(x) if "after_ffn" in self.more_dropouts: x = F.dropout(x, p=self.dropout, training=self.training) if self.aan_gating_fc is not None: i, f = self.aan_gating_fc(torch.cat([residual, x], dim=-1)).chunk(2, dim=-1) x = torch.sigmoid(f) * residual + torch.sigmoid(i) * x if "after_gating" in self.more_dropouts: x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if self.aan_layer_norm is not None: x = self.maybe_layer_norm(self.aan_layer_norm, x, after=True) attn = None if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.ffn(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) return x, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn def extra_repr(self): return "dropout={}, more_dropouts={}".format(self.dropout, self.more_dropouts)
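# Minimal sketch of the two-gate mixing used by TransformerAANDecoderLayer
# above: `aan_gating_fc` produces an input gate i and a forget gate f from the
# concatenation of the residual and the averaged representation, and the two
# streams are combined with independent sigmoids. Shapes are illustrative.
import torch
import torch.nn as nn


def _demo_aan_gating(seq_len=4, bsz=2, embed_dim=8):
    residual = torch.randn(seq_len, bsz, embed_dim)
    avg_out = torch.randn(seq_len, bsz, embed_dim)  # output of the average attention
    aan_gating_fc = nn.Linear(embed_dim * 2, embed_dim * 2)
    i, f = aan_gating_fc(torch.cat([residual, avg_out], dim=-1)).chunk(2, dim=-1)
    return torch.sigmoid(f) * residual + torch.sigmoid(i) * avg_out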
class TransformerDecoderLayerPhase2(nn.Module): """Second phase of the decoder layer block. This layer takes its input from the encoder and from the first-pass decoder. See papers.nips.cc/paper/6775-deliberation-networks-sequence-generation-beyond-one-pass-decoding.pdf """ def __init__( self, args, no_encoder_decoder_attn=False, add_bias_kv=False, add_zero_attn=False, ): super().__init__() self.embed_dim = args.decoder_embed_dim self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=True, ) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu")) self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determine this. # TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_decoder_attn: self.encoder_attn = None self.decoder_attn = None self.encoder_attn_layer_norm = None self.decoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.decoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.decoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out=None, encoder_padding_mask=None, decoder_out=None, incremental_state=None, prev_self_attn_state=None, prev_encoder_attn_state=None, prev_decoder_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ residual = x x_self_attention = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x_self_attention, attn = self.self_attn( query=x_self_attention, key=x_self_attention, value=x_self_attention, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x_self_attention = F.dropout(x_self_attention, p=self.dropout, training=self.training) x_self_attention = residual + x_self_attention x_self_attention = self.maybe_layer_norm(self.self_attn_layer_norm, x_self_attention, after=True) if self.encoder_attn is not None: residual = x x_encoder_attention = self.maybe_layer_norm( self.encoder_attn_layer_norm, x, before=True) if prev_encoder_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_encoder_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x_encoder_attention, attn = self.encoder_attn( query=x_encoder_attention, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x_encoder_attention = F.dropout(x_encoder_attention, p=self.dropout, training=self.training) x_encoder_attention = residual + x_encoder_attention x_encoder_attention = self.maybe_layer_norm( self.encoder_attn_layer_norm, x_encoder_attention, after=True) if self.decoder_attn is not None: residual = x x_decoder_attention = self.maybe_layer_norm( self.decoder_attn_layer_norm, x, before=True) if prev_decoder_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_decoder_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.decoder_attn._set_input_buffer(incremental_state, saved_state) x_decoder_attention, attn = self.decoder_attn( query=x_decoder_attention, key=decoder_out, value=decoder_out, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x_decoder_attention = F.dropout(x_decoder_attention, p=self.dropout, training=self.training) x_decoder_attention = residual + x_decoder_attention x_decoder_attention = self.maybe_layer_norm( self.decoder_attn_layer_norm, x_decoder_attention, after=True) x = x_self_attention + x_encoder_attention + x_decoder_attention residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) self_attn_state = saved_state["prev_key"], saved_state["prev_value"] return x, attn, self_attn_state return x, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
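# Hedged sketch of the deliberation idea implemented by
# TransformerDecoderLayerPhase2 above (see the paper linked in the class
# docstring): the second-pass state attends to the encoder output and to the
# first-pass decoder output, and the self/encoder/decoder attention branches
# are summed before the feed-forward block. torch.nn.MultiheadAttention is
# only a stand-in for the fairseq attention module; shapes are (seq, batch, embed).
import torch
import torch.nn as nn


def _demo_second_pass_mixing(tgt_len=5, src_len=7, bsz=2, embed_dim=8, heads=2):
    x = torch.randn(tgt_len, bsz, embed_dim)            # second-pass decoder states
    encoder_out = torch.randn(src_len, bsz, embed_dim)  # encoder states
    decoder_out = torch.randn(tgt_len, bsz, embed_dim)  # first-pass decoder states
    self_attn = nn.MultiheadAttention(embed_dim, heads)
    enc_attn = nn.MultiheadAttention(embed_dim, heads)
    dec_attn = nn.MultiheadAttention(embed_dim, heads)
    x_self, _ = self_attn(x, x, x)
    x_enc, _ = enc_attn(x, encoder_out, encoder_out)
    x_dec, _ = dec_attn(x, decoder_out, decoder_out)
    return x_self + x_enc + x_dec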
class AANDecoderLayer(nn.Module): """ Based on https://arxiv.org/abs/1805.00631 """ def __init__(self, args): super().__init__() self.embed_dim = args.decoder_embed_dim self.cross_self_attention = getattr(args, "cross_self_attention", False) self.avg_attn = AverageAttention(self.embed_dim, dropout=args.attention_dropout) # differently than original paper, we use a single gate self.aan_gating_fc = fairseq_transformer.Linear( self.embed_dim * 2, self.embed_dim) self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, "activation_fn", "relu")) self.activation_dropout = getattr(args, "activation_dropout", 0) if self.activation_dropout == 0: # for backwards compatibility with models that use args.relu_dropout self.activation_dropout = getattr(args, "relu_dropout", 0) self.normalize_before = args.decoder_normalize_before # use layerNorm rather than FusedLayerNorm for exporting. # char_inputs can be used to determint this. # TODO remove this once we update apex with the fix export = getattr(args, "char_inputs", False) self.avg_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, "encoder_embed_dim", None), vdim=getattr(args, "encoder_embed_dim", None), dropout=args.attention_dropout, encoder_decoder_attention=True, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.fc1 = fairseq_transformer.Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = fairseq_transformer.Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, need_attn=False, need_head_weights=False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). 
Returns: encoded output of shape `(seq_len, batch, embed_dim)` The following are used for export tracing: prev_self_attn_state: [prev_sum, prev_pos] assumes AverageAttention without mask trick prev_attn_state: [prev_key, prev_value] """ if need_head_weights: need_attn = True residual = x x = self.maybe_layer_norm(self.avg_attn_layer_norm, x, before=True) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_sum, prev_pos = prev_self_attn_state # (batch, embed) -> (seq, batch, embed) prev_sum = prev_sum.unsqueeze(0) saved_state = {"prev_sum": prev_sum, "prev_pos": prev_pos} self.avg_attn._set_input_buffer(incremental_state, saved_state) x, _ = self.avg_attn( value=x, mask_future_timesteps=True, incremental_state=incremental_state, mask_trick=self.training, ) # differently than original paper, we use a single gate gate = torch.sigmoid( self.aan_gating_fc(torch.cat([residual, x], dim=-1))) x = gate * x + (1 - gate) * residual x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.avg_attn_layer_norm, x, after=True) if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) x = self.activation_fn(self.fc1(x)) x = F.dropout(x, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) if self.onnx_trace and incremental_state is not None: saved_state = self.avg_attn._get_input_buffer(incremental_state) # remove sequence axis for export prev_sum = saved_state["prev_sum"] # (seq, batch, embed) -> (batch, embed) prev_sum = prev_sum.squeeze(0) prev_pos = saved_state["prev_pos"] self_attn_state = prev_sum, prev_pos return x, attn, self_attn_state return x, attn, None def maybe_layer_norm(self, layer_norm, x, before=False, after=False): assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
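# Self-contained sketch of the cumulative-average operation that
# AverageAttention (arXiv:1805.00631) applies over decoder states: position t
# sees the running mean of positions 1..t, which is exactly what the
# incremental (prev_sum, prev_pos) buffers above track one step at a time.
# This illustrates the math only; it is not the fairseq module itself.
import torch


def _cumulative_average(x):
    """x: (seq_len, batch, embed_dim) -> running mean over the time axis."""
    steps = torch.arange(1, x.size(0) + 1, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    return torch.cumsum(x, dim=0) / steps


def _demo_incremental_average(seq_len=6, bsz=2, embed_dim=4):
    x = torch.randn(seq_len, bsz, embed_dim)
    prev_sum = torch.zeros(bsz, embed_dim)
    outputs = []
    for pos in range(1, seq_len + 1):  # `pos` plays the role of prev_pos
        prev_sum = prev_sum + x[pos - 1]
        outputs.append(prev_sum / pos)
    out = torch.stack(outputs, dim=0)
    assert torch.allclose(out, _cumulative_average(x), atol=1e-6)
    return out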
class TransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, LayerNum=None): super().__init__() global tmp_file self.args = args if not hasattr(self.args, 'mixed_precision'): self.args.mixed_precision = False if not hasattr(self.args, 'plot_variance'): self.args.plot_variance = False if not hasattr(self.args, 'plot_gradient'): self.args.plot_gradient = False self.normalize_before = args.decoder_normalize_before self.embed_dim = args.decoder_embed_dim self.cross_self_attention = getattr(args, 'cross_self_attention', False) self.layer_num = LayerNum if 'adaptive' in args.init_type: assert not self.normalize_before self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention) assert not no_encoder_attn self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, 'encoder_embed_dim', None), dropout=args.attention_dropout, encoder_decoder_attention=True) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) if 'adaptive-profiling' == args.init_type: if not tmp_file: tmp_file = open('profile.ratio.init', 'w') self.self_ratio_change = nn.Parameter( torch.ones(self.embed_dim)) self.encoder_ratio_change = nn.Parameter( torch.ones(self.embed_dim)) self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim)) else: if not tmp_file: tmp_file = open('profile.ratio.init', 'r') layer_iter, next_value = [ float(tup) for tup in tmp_file.readline().split() ] print('layer_num: {}, layer_iter: {}'.format( self.layer_num, layer_iter)) assert layer_iter == 3 * self.layer_num + 1 print('decoder self ratio: {}'.format(next_value)) self.self_ratio_change = nn.Parameter( torch.ones(self.embed_dim)) self.self_ratio_change.data.fill_(next_value) layer_iter, next_value = [ float(tup) for tup in tmp_file.readline().split() ] print('layer_num: {}, layer_iter: {}'.format( self.layer_num, layer_iter)) assert layer_iter == 3 * self.layer_num + 2 print('decoder en ratio: {}'.format(next_value)) self.encoder_ratio_change = nn.Parameter( torch.ones(self.embed_dim)) self.encoder_ratio_change.data.fill_(next_value) layer_iter, next_value = [ float(tup) for tup in tmp_file.readline().split() ] print('layer_num: {}, layer_iter: {}'.format( self.layer_num, layer_iter)) assert layer_iter == 3 * self.layer_num + 3 print('decoder ffn ratio: {}'.format(next_value)) self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim)) self.fc_ratio_change.data.fill_(next_value) export = getattr(args, 'char_inputs', False) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.encoder_attn_layer_norm = 
LayerNorm(self.embed_dim, export=export) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) else: self.self_attn = MultiheadAttention( embed_dim=self.embed_dim, num_heads=args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not self.cross_self_attention) assert not no_encoder_attn self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, kdim=getattr(args, 'encoder_embed_dim', None), vdim=getattr(args, 'encoder_embed_dim', None), dropout=args.attention_dropout, encoder_decoder_attention=True) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) if args.init_type == 'looklinear': self.fc1.weight.data[int(args.decoder_ffn_embed_dim / 2):, :] = -self.fc1.weight.data[ 0:int(args.decoder_ffn_embed_dim / 2), :] self.fc2.weight.data[:, int(args.decoder_ffn_embed_dim / 2):] = -self.fc2.weight.data[:, 0:int( args.decoder_ffn_embed_dim / 2)] export = getattr(args, 'char_inputs', False) if args.init_type != 'rezero': self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) self.final_layer_norm = LayerNorm(self.embed_dim, export=export) else: self.self_attn_layer_norm = None self.encoder_attn_layer_norm = None self.final_layer_norm = None if 'rezero' in args.init_type: self.rezero_weight = nn.Parameter(torch.Tensor([0])) else: assert args.init_type == 'default' self.rezero_weight = None self.dropout = args.dropout self.activation_fn = utils.get_activation_fn( activation=getattr(args, 'activation_fn', 'relu')) self.activation_dropout = getattr(args, 'activation_dropout', 0) if self.activation_dropout == 0: self.activation_dropout = getattr(args, 'relu_dropout', 0) self.need_attn = True self.onnx_trace = False if args.fp16: self.in_type = torch.half else: self.in_type = torch.float def prepare_for_onnx_export_(self): self.onnx_trace = True def forward( self, x, encoder_out=None, encoder_padding_mask=None, incremental_state=None, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None, need_attn=False, need_head_weights=False, ): not_initialized = ('adaptive-profiling' == self.args.init_type) and ( 1.0 == self.self_ratio_change.min()) and self.training if need_head_weights: need_attn = True residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if self.args.mixed_precision: x = x.type(self.in_type) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] self.self_attn._set_input_buffer(incremental_state, saved_state) if self.cross_self_attention and not ( incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)): if self_attn_mask is not None: self_attn_mask = torch.cat((x.new( x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1) if self_attn_padding_mask is not None: if encoder_padding_mask is None: encoder_padding_mask = self_attn_padding_mask.new( encoder_out.size(1), encoder_out.size(0)).zero_() self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1) 
y = torch.cat((encoder_out, x), dim=0) else: y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) if self.args.mixed_precision: x = x.float() if 'adaptive' in self.args.init_type: if not_initialized: global decoder_ratio, tmp_file tmp_layer_ind = self.layer_num * 3 + 1 tmp_ratio = decoder_ratio tmp_file.write('{} {}\n'.format(tmp_layer_ind, tmp_ratio)) self.self_ratio_change.data.fill_(tmp_ratio) print('decoder self attn ratio: {}'.format(tmp_ratio)) input_std = np.var((residual * self.self_ratio_change).clone( ).cpu().float().data.contiguous().view(-1).numpy()) output_std = np.var( x.clone().cpu().float().data.contiguous().view(-1).numpy()) decoder_ratio = np.sqrt(input_std + output_std) x0 = x + residual * self.self_ratio_change elif self.rezero_weight is not None: x0 = residual + self.rezero_weight * x else: x0 = residual + x x0 = self.maybe_layer_norm(self.self_attn_layer_norm, x0, after=True) if self.args.plot_gradient: x0.register_hook(lambda grad: print('{} decoder s-att: {}'.format( self.layer_num, grad.norm().item()))) x = x0 if self.encoder_attn is not None: residual = x x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x0, before=True) if self.args.mixed_precision: x = x.type(self.in_type) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state[:2] saved_state = {"prev_key": prev_key, "prev_value": prev_value} if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = F.dropout(x, p=self.dropout, training=self.training) if self.args.mixed_precision: x = x.float() if 'adaptive' in self.args.init_type: if not_initialized: tmp_layer_ind = self.layer_num * 3 + 2 tmp_ratio = decoder_ratio tmp_file.write('{} {}\n'.format(tmp_layer_ind, tmp_ratio)) self.encoder_ratio_change.data.fill_(tmp_ratio) print('decoder encoder attn ratio: {}'.format(tmp_ratio)) input_std = np.var( (residual * self.encoder_ratio_change).clone().cpu( ).float().data.contiguous().view(-1).numpy()) output_std = np.var( x.clone().cpu().float().data.contiguous().view( -1).numpy()) decoder_ratio = np.sqrt(input_std + output_std) x1 = x + residual * self.encoder_ratio_change elif self.rezero_weight is not None: x1 = residual + self.rezero_weight * x else: x1 = residual + x x1 = self.maybe_layer_norm(self.encoder_attn_layer_norm, x1, after=True) if self.args.plot_gradient: x1.register_hook(lambda grad: print( '{} decoder e-att: {}'.format(self.layer_num, grad.norm().item()))) x = x1 residual = x x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) if self.args.mixed_precision: x = x.type(self.in_type) bx = self.fc1(x) hx = self.activation_fn(bx) x = F.dropout(hx, p=self.activation_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) if self.args.mixed_precision: x = x.float() if 'adaptive' in self.args.init_type: if not_initialized: tmp_layer_ind = self.layer_num * 3 + 3 tmp_ratio = decoder_ratio tmp_file.write('{} 
{}\n'.format(tmp_layer_ind, tmp_ratio)) self.fc_ratio_change.data.fill_(tmp_ratio) print('decoder ffn ratio: {}'.format(tmp_ratio)) input_var = np.var((residual * self.fc_ratio_change).clone( ).cpu().float().data.contiguous().view(-1).numpy()) output_var = np.var( x.clone().cpu().float().data.contiguous().view(-1).numpy()) decoder_ratio = np.sqrt(input_var + output_var) x2 = x + residual * self.fc_ratio_change elif self.rezero_weight is not None: x2 = residual + self.rezero_weight * x else: x2 = residual + x x2 = self.maybe_layer_norm(self.final_layer_norm, x2, after=True) if self.args.plot_gradient: x2.register_hook(lambda grad: print('{} decoder ffn: {}'.format( self.layer_num, grad.norm().item()))) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) if self_attn_padding_mask is not None: self_attn_state = saved_state["prev_key"], saved_state[ "prev_value"], saved_state["prev_key_padding_mask"] else: self_attn_state = saved_state["prev_key"], saved_state[ "prev_value"] return x2, attn, self_attn_state return x2, attn def maybe_layer_norm(self, layer_norm, x, before=False, after=False): if self.args.init_type == 'rezero': return x assert before ^ after if after ^ self.normalize_before: return layer_norm(x) else: return x def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
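# Hedged sketch of the variance-based profiling run above when
# args.init_type == 'adaptive-profiling': for each residual branch the scalar
# written to profile.ratio.init is sqrt(Var[scaled residual] + Var[branch output]),
# so the rescaled residual and the new sublayer output enter the sum with
# balanced magnitude (an Admin-style initialization). Plain tensors stand in
# for the real activations; the file I/O used above is omitted.
import numpy as np
import torch


def _profile_ratio(residual, branch_out, current_ratio):
    input_var = np.var((residual * current_ratio).detach().cpu().float().numpy().reshape(-1))
    output_var = np.var(branch_out.detach().cpu().float().numpy().reshape(-1))
    return float(np.sqrt(input_var + output_var))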
class FixupTransformerDecoderLayer(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *args.decoder_normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def fixup_initialization(self, args): temp_state_dic = {} de_layers = args.decoder_layers if args.Tfixup: for name, param in self.named_parameters(): if name in [ "fc1.weight", "fc2.weight", "self_attn.out_proj.weight", "encoder_attn.out_proj.weight", ]: temp_state_dic[name] = (9 * de_layers)**(-1. / 4.) * param elif name in [ "self_attn.v_proj.weight", "encoder_attn.v_proj.weight", ]: temp_state_dic[name] = (9 * de_layers)**(-1. / 4.) * (param * (2**0.5)) for name in self.state_dict(): if name not in temp_state_dic: temp_state_dic[name] = self.state_dict()[name] self.load_state_dict(temp_state_dic) def __init__(self, args, no_encoder_attn=False): super().__init__() self.embed_dim = args.decoder_embed_dim if args.max_relative_length == -1: self.self_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) else: self.self_attn = RelativeMultiheadAttention( self.embed_dim, args.decoder_attention_heads, args.max_relative_length, dropout=args.attention_dropout, k_only=args.k_only, ) self.dropout = args.dropout self.relu_dropout = args.relu_dropout self.normalize_before = args.decoder_normalize_before if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = MultiheadAttention( self.embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, ) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) self.noLN = args.dont_use_layernorm if not self.noLN: self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.final_layer_norm = LayerNorm(self.embed_dim) self.need_attn = True self.onnx_trace = False def prepare_for_onnx_export_(self): self.onnx_trace = True def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None, self_attn_padding_mask=None): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. 
Returns: encoded output of shape `(batch, src_len, embed_dim)` """ residual = x if self.normalize_before and not self.noLN: x = self.self_attn_layer_norm(x) if prev_self_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_self_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.self_attn._set_input_buffer(incremental_state, saved_state) x, _ = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before and not self.noLN: x = self.self_attn_layer_norm(x) attn = None if self.encoder_attn is not None: residual = x if self.normalize_before and not self.noLN: x = self.encoder_attn_layer_norm(x) if prev_attn_state is not None: if incremental_state is None: incremental_state = {} prev_key, prev_value = prev_attn_state saved_state = {"prev_key": prev_key, "prev_value": prev_value} self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=(not self.training and self.need_attn), ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before and not self.noLN: x = self.encoder_attn_layer_norm(x) residual = x if self.normalize_before and not self.noLN: x = self.final_layer_norm(x) x = F.relu(self.fc1(x)) x = F.dropout(x, p=self.relu_dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x if not self.normalize_before and not self.noLN: x = self.final_layer_norm(x) if self.onnx_trace: saved_state = self.self_attn._get_input_buffer(incremental_state) self_attn_state = saved_state["prev_key"], saved_state[ "prev_value"] return x, attn, self_attn_state return x, attn def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn
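# Minimal sketch of the T-Fixup style rescaling done by `fixup_initialization`
# above: in a decoder with N layers, the FFN weights and attention output
# projections are multiplied by (9N)^(-1/4), and the value projections by an
# extra sqrt(2), so the stack can train without warmup (and, with
# args.dont_use_layernorm, without layer normalization). A plain nn.Linear
# stands in for the attention/FFN sub-modules.
import torch
import torch.nn as nn


def _tfixup_scale_(linear: nn.Linear, decoder_layers: int, is_value_proj: bool = False):
    scale = (9 * decoder_layers) ** (-0.25)
    if is_value_proj:
        scale *= 2 ** 0.5
    with torch.no_grad():
        linear.weight.mul_(scale)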