class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): if ``True`` the encoder-attention
            sub-layer (and its layer norm) is not built and is skipped in
            :meth:`forward` (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module
            (default: False).
        add_zero_attn (bool, optional): forwarded to the self-attention module
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention", False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        """Switch the layer into ONNX-trace mode (see the return of forward)."""
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder states used as key/value of
                the encoder attention (and prepended to the self-attention
                keys when cross-self-attention is enabled).
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): state passed through to the
                attention modules; required whenever *prev_self_attn_state*
                or *prev_attn_state* is given.
            prev_self_attn_state (List[Tensor], optional): `[key, value]` or
                `[key, value, key_padding_mask]` written into the
                self-attention input buffer before attending.
            prev_attn_state (List[Tensor], optional): same layout, written
                into the encoder-attention input buffer.
            self_attn_mask (Tensor, optional): attention mask handed to the
                self-attention module.
            self_attn_padding_mask (Tensor, optional): key padding mask handed
                to the self-attention module.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            a tuple ``(x, attn, self_attn_state)`` where ``x`` is the encoded
            output of shape `(seq_len, batch, embed_dim)`, ``attn`` is the
            second value returned by the last attention call, and
            ``self_attn_state`` is the list of saved self-attention buffers
            when tracing for ONNX with an incremental state, else ``None``.
        """
        if need_head_weights:
            need_attn = True

        # ---- self-attention sub-layer ----
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        # With cross-self-attention the encoder states are prepended to the
        # keys/values, unless a "prev_key" is already buffered in the
        # incremental state.
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                # zero columns: the prepended encoder positions are not masked
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # ---- encoder-attention sub-layer (optional) ----
        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # ---- position-wise feed-forward sub-layer ----
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        if self.onnx_trace and incremental_state is not None:
            # When tracing, hand the buffered self-attention state back to the
            # caller explicitly.
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Record whether attention weights are wanted outside of training."""
        self.need_attn = need_attn

    def compute_macs_params(self, T=1, S=1):
        """Count parameters and multiply-accumulates of this layer.

        Args:
            T: value passed as both ``T`` and ``S`` to the self-attention
                counter and as ``T`` to the encoder-attention counter; also
                the factor by which the FFN MACs are scaled.
            S: value passed as ``S`` to the encoder-attention counter.

        Returns:
            dict with keys ``'name'``, ``'macs'``, ``'params'`` and
            ``'macs_attn'``.
        """
        macs = 0
        n_params = 0
        macs_attn = 0

        # LayerNorm (contributes parameters only, no MACs are added)
        n_params += sum(p.numel() for p in self.self_attn_layer_norm.parameters())
        n_params += sum(p.numel() for p in self.final_layer_norm.parameters())

        # self attention
        self_attn_layer = self.self_attn.compute_macs_params(T=T, S=T)
        macs += self_attn_layer['macs']
        n_params += self_attn_layer['params']
        macs_attn += self_attn_layer['macs_attn']

        # Encoder-decoder attn
        if self.encoder_attn is not None:
            # The encoder-attention layer norm only exists together with the
            # encoder attention; its parameters belong to the total as well.
            if self.encoder_attn_layer_norm is not None:
                n_params += sum(
                    p.numel() for p in self.encoder_attn_layer_norm.parameters()
                )

            enc_attn = self.encoder_attn.compute_macs_params(T=T, S=S)
            macs += enc_attn['macs']
            n_params += enc_attn['params']
            macs_attn += enc_attn['macs_attn']

        # FFN
        fc1_params = sum(p.numel() for p in self.fc1.parameters())
        macs += (fc1_params * T)
        n_params += fc1_params

        fc2_params = sum(p.numel() for p in self.fc2.parameters())
        macs += (fc2_params * T)
        n_params += fc2_params

        return {
            'name': self.__class__.__name__,
            'macs': macs,
            'params': n_params,
            'macs_attn': macs_attn
        }
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`

        The renaming is done in place on *state_dict*; keys that are not
        present are left untouched.
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape (T_tgt, T_src), where
                T_tgt is the length of query, while T_src is the length of key,
                though here both query and key is x here,
                attn_mask[t_tgt, t_src] = 1 means when calculating embedding
                for t_tgt, t_src is excluded (or masked out), =0 means it is
                included in attention

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if attn_mask is not None:
            # Turn the binary mask into an additive one in the dtype of x:
            # anything in original attn_mask = 1, becomes a large negative value
            # anything in original attn_mask = 0, becomes 0
            # Note that we cannot use -inf here, because at some edge cases,
            # the attention weight (before softmax) for some padded element in
            # query will become -inf, which results in NaN in model parameters.
            # -1e8 is not representable in half precision, so a smaller
            # magnitude is used for anything that is not float32.
            # TODO: to formally solve this problem, we need to change fairseq's
            # MultiheadAttention. We will do this later on.
            fill_value = -1e8 if x.dtype == torch.float32 else -1e4
            attn_mask = attn_mask.to(x.dtype).masked_fill(
                attn_mask.to(torch.bool), fill_value
            )
        # The mask has to be handed to the attention module, otherwise it has
        # no effect on the result.
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # position-wise feed-forward sub-layer
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x

    def compute_macs_params(self, S=1):
        """Count parameters and multiply-accumulates of this layer.

        Args:
            S: value passed as both ``T`` and ``S`` to the self-attention
                counter; also the factor by which the FFN MACs are scaled.

        Returns:
            dict with keys ``'name'``, ``'macs'``, ``'params'`` and
            ``'macs_attn'``.
        """
        macs = 0
        n_params = 0
        macs_attn = 0

        # Layer Norms
        # MACS are zero for LayerNorm because they can be fused
        n_params += sum(p.numel() for p in self.self_attn_layer_norm.parameters())
        n_params += sum(p.numel() for p in self.final_layer_norm.parameters())

        # Attn
        self_attn_layer = self.self_attn.compute_macs_params(T=S, S=S)
        macs += self_attn_layer['macs']
        n_params += self_attn_layer['params']
        macs_attn += self_attn_layer['macs_attn']

        # FFN
        fc1_params = sum(p.numel() for p in self.fc1.parameters())
        # scale MACS by S because S tokens can be processed in parallel
        macs += (fc1_params * S)
        n_params += fc1_params

        fc2_params = sum(p.numel() for p in self.fc2.parameters())
        # scale MACS by S because S tokens can be processed in parallel
        macs += (fc2_params * S)
        n_params += fc2_params

        return {
            'name': self.__class__.__name__,
            'macs': macs,
            'params': n_params,
            'macs_attn': macs_attn
        }