Example #1
    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention",
                                            False)

        self.avg_attn = AverageAttention(self.embed_dim,
                                         dropout=args.attention_dropout)

        # unlike the original paper, we use a single gate
        self.aan_gating_fc = fairseq_transformer.Linear(
            self.embed_dim * 2, self.embed_dim)

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # Use LayerNorm rather than FusedLayerNorm for exporting;
        # char_inputs can be used to determine this.
        # TODO: remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.avg_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.encoder_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = fairseq_transformer.Linear(self.embed_dim,
                                              args.decoder_ffn_embed_dim)
        self.fc2 = fairseq_transformer.Linear(args.decoder_ffn_embed_dim,
                                              self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False
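
The constructor above only builds the modules; the snippet does not show how the single gate is applied. Below is a minimal sketch of the gating step it suggests, assuming the forward pass concatenates the residual input with the average-attention output, projects the result through aan_gating_fc, and mixes the two streams with a sigmoid gate. The function and tensor names are illustrative assumptions, not taken from the source.

import torch
import torch.nn.functional as F

def apply_single_gate(x, attn_out, aan_gating_fc, p_dropout=0.0, training=False):
    """Hypothetical single-gate AAN combination (embed_dim * 2 -> embed_dim)."""
    # Project the concatenated streams down to embed_dim and squash to (0, 1).
    gate = torch.sigmoid(aan_gating_fc(torch.cat([x, attn_out], dim=-1)))
    # Convex combination of the residual input and the average-attention output.
    out = gate * x + (1.0 - gate) * attn_out
    return F.dropout(out, p=p_dropout, training=training)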
Example #2
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.more_dropouts = args.decoder_aan_more_dropouts
        if args.decoder_attn_window_size <= 0:
            self.avg_attn = AverageAttention(self.embed_dim,
                                             dropout=args.attention_dropout)
        else:
            self.avg_attn = AverageWindowAttention(
                self.embed_dim,
                dropout=args.attention_dropout,
                window_size=args.decoder_attn_window_size,
            )
        # self.activation = getattr(args, "decoder_ffn_activation", "relu")
        self.aan_layer_norm = LayerNorm(self.embed_dim)
        if args.no_decoder_aan_ffn:
            self.aan_ffn = None
        else:
            # Use the embedding dim as the AAN FFN hidden size when the
            # decoder_aan_ffn_use_embed_dim flag is set.
            aan_ffn_hidden_dim = (self.embed_dim
                                  if args.decoder_aan_ffn_use_embed_dim else
                                  args.decoder_ffn_embed_dim)
            self.aan_ffn = FeedForwardNetwork(
                self.embed_dim,
                aan_ffn_hidden_dim,
                self.embed_dim,
                num_layers=2,
                dropout=args.relu_dropout,
            )

        if args.no_decoder_aan_gating:
            self.aan_gating_fc = None
        else:
            self.aan_gating_fc = Linear(self.embed_dim * 2, self.embed_dim * 2)
        self.normalize_before = args.decoder_normalize_before

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=args.encoder_embed_dim,
                vdim=args.encoder_embed_dim,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.ffn = FeedForwardNetwork(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            num_layers=2,
            dropout=args.relu_dropout,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False
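
Neither snippet shows the surrounding class or how args is populated. The sketch below assembles a minimal argparse.Namespace covering the attributes the second constructor reads; all values are illustrative defaults, and the class name AANDecoderLayer is an assumption, since the snippet does not name the enclosing class.

from argparse import Namespace

args = Namespace(
    decoder_embed_dim=512,
    decoder_ffn_embed_dim=2048,
    decoder_attention_heads=8,
    encoder_embed_dim=512,
    dropout=0.1,
    relu_dropout=0.0,
    attention_dropout=0.0,
    decoder_normalize_before=False,
    decoder_attn_window_size=0,        # <= 0 selects plain AverageAttention
    decoder_aan_more_dropouts=(),
    no_decoder_aan_ffn=False,
    decoder_aan_ffn_use_embed_dim=False,
    no_decoder_aan_gating=False,
)

# layer = AANDecoderLayer(args, no_encoder_attn=False)  # class name assumed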