Example #1
    def forward(self,
                trg_embed: Tensor = None,
                encoder_output=None,
                src_mask: Tensor = None,
                trg_mask: Tensor = None,
                **kwargs):
        """
        Transformer decoder forward pass.

        :param trg_embed: embedded targets
        :param encoder_output: source representations
        :param src_mask: to mask out source paddings
        :param trg_mask: to mask out target paddings
                         Note that a subsequent mask is applied here.
        :return: output (decoder logits), x (pre-output hidden states),
            None, None
        """
        assert trg_mask is not None, "trg_mask is required for Transformer"

        x = self.pe(trg_embed)  # add position encoding to word embedding
        x = self.emb_dropout(x)

        trg_mask = trg_mask & subsequent_mask(
            trg_embed.size(1)).type_as(trg_mask)

        for layer in self.layers:
            x = layer(x=x,
                      memory=encoder_output.states,
                      src_mask=src_mask,
                      trg_mask=trg_mask)

        x = self.layer_norm(x)
        output = self.output_layer(x)

        return output, x, None, None
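
The trg_mask & subsequent_mask(...) line combines the target padding mask with a causal (lower-triangular) mask so that each position can only attend to itself and earlier positions. Below is a minimal, self-contained sketch of that combination; the subsequent_mask helper shown here is an illustrative stand-in, not necessarily the exact implementation used by this class:

import torch

def subsequent_mask(size: int) -> torch.Tensor:
    # causal mask of shape 1 x size x size: True where attention is allowed
    return torch.tril(torch.ones(1, size, size)).bool()

pad_index = 0
trg_input = torch.tensor([[5, 7, 3, 0, 0]])           # batch=1, trg_len=5, right-padded
trg_mask = (trg_input != pad_index).unsqueeze(1)      # 1 x 1 x 5 padding mask
combined = trg_mask & subsequent_mask(trg_input.size(1)).type_as(trg_mask)
print(combined.int())
# tensor([[[1, 0, 0, 0, 0],
#          [1, 1, 0, 0, 0],
#          [1, 1, 1, 0, 0],
#          [1, 1, 1, 0, 0],
#          [1, 1, 1, 0, 0]]])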
Example #2
    def forward(
            self,
            trg_embed: Tensor = None,
            encoder_outputs: dict = None,  # a dict keyed by encoder name, not a single tensor
            src_mask: Tensor = None,
            trg_mask: Tensor = None,
            inflection_mask: Tensor = None,
            **kwargs):
        """
        Transformer decoder forward pass.

        :param trg_embed: embedded targets
        :param encoder_outputs: dict of source representations, one per encoder
        :param src_mask: to mask out source paddings
        :param trg_mask: to mask out target paddings
                         Note that a subsequent mask is applied here.
        :param inflection_mask: mask for the inflection encoder outputs
        :return: output (decoder logits), x (pre-output hidden states),
            attn dict of context attn probs (batch x layer x head x tgt x src), None
        """
        assert trg_mask is not None, "trg_mask is required for Transformer"

        # encoder_outputs is not an ideal name for this dict because it also
        # contains encoder_hidden, which transformers do not use
        enc_outputs = {k: v[0] for k, v in encoder_outputs.items()}

        x = self.pe(trg_embed)  # add position encoding to word embedding
        x = self.emb_dropout(x)

        trg_mask = trg_mask & subsequent_mask(
            trg_embed.size(1)).type_as(trg_mask)

        self_attn_layers = []
        context_attn_layers = defaultdict(list)
        for layer in self.layers:
            # todo: multiple (variable?) numbers of context attentions
            # return a dictionary of context attentions
            x, self_attn, ctx_attn = layer(x=x,
                                           memories=enc_outputs,
                                           src_mask=src_mask,
                                           trg_mask=trg_mask,
                                           inflection_mask=inflection_mask)
            self_attn_layers.append(self_attn)
            for enc_name, enc_ctx_attn in ctx_attn.items():
                context_attn_layers[enc_name].append(enc_ctx_attn)

        # todo: inflection_trg attention
        trg_trg_attn = torch.stack(self_attn_layers, dim=1)
        attn = {"trg_trg": trg_trg_attn}
        for enc_name, enc_ctx_attn_layers in context_attn_layers.items():
            enc_stacked_attn = torch.stack(enc_ctx_attn_layers, dim=1)
            attn[enc_name + "_trg"] = enc_stacked_attn

        x = self.layer_norm(x)
        output = self.output_layers["vocab"](x)

        return output, x, attn, None
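
The per-layer self-attention and context-attention tensors are stacked along a new layer dimension, which is where the batch x layer x head x tgt x src shape from the docstring comes from. A small shape check with arbitrary made-up dimensions:

import torch

batch, layers, heads, tgt_len, src_len = 2, 4, 8, 5, 7
# one attention tensor per decoder layer, each of shape batch x head x tgt x src
per_layer = [torch.softmax(torch.randn(batch, heads, tgt_len, src_len), dim=-1)
             for _ in range(layers)]
stacked = torch.stack(per_layer, dim=1)
print(stacked.shape)  # torch.Size([2, 4, 8, 5, 7]) -> batch x layer x head x tgt x src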
Example #3
    def forward(self,
                trg_embed: Tensor = None,
                encoder_output: Tensor = None,
                encoder_hidden: Tensor = None,
                src_mask: Tensor = None,
                unroll_steps: int = None,
                hidden: Tensor = None,
                trg_mask: Tensor = None,
                **kwargs):
        """
        Transformer decoder forward pass.

        :param trg_embed: embedded targets
        :param encoder_output: source representations
        :param encoder_hidden: unused
        :param src_mask: to mask out source paddings
        :param unroll_steps: unused
        :param hidden: unused
        :param trg_mask: to mask out target paddings
                         Note that a subsequent mask is applied here.
        :param kwargs: additional keyword arguments (unused)
        :return: output (decoder logits), x (pre-output hidden states),
            None, None
        """
        assert trg_mask is not None, "trg_mask required for Transformer"

        x = self.pe(trg_embed)  # add position encoding to word embedding
        x = self.emb_dropout(x)

        trg_mask = trg_mask & subsequent_mask(
            trg_embed.size(1)).type_as(trg_mask)

        for layer in self.layers:
            x = layer(x=x,
                      memory=encoder_output,
                      src_mask=src_mask,
                      trg_mask=trg_mask)

        x = self.layer_norm(x)

        output = self.output_layer(x)

        # generic decoder return signature (output, hidden, att_scores,
        # att_vectors); the last three don't matter for the Transformer
        return output, x, None, None
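
The output_layer projects the normalized decoder states from the model dimension to vocabulary logits. A tiny stand-alone illustration of that shape change, assuming the output layer is a plain nn.Linear (hidden size 16 and vocabulary size 100 are arbitrary):

import torch
import torch.nn as nn

batch, trg_len, hidden_size, vocab_size = 2, 5, 16, 100
x = torch.randn(batch, trg_len, hidden_size)                   # pre-output decoder states
output_layer = nn.Linear(hidden_size, vocab_size, bias=False)  # assumed projection
output = output_layer(x)                                       # logits over the vocabulary
print(x.shape, "->", output.shape)  # torch.Size([2, 5, 16]) -> torch.Size([2, 5, 100])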
Example #4
    def forward(self,
                trg_embed: Tensor = None,
                encoder_output: Tensor = None,
                encoder_hidden: Tensor = None,
                src_mask: Tensor = None,
                unroll_steps: int = None,
                hidden: Tensor = None,
                trg_mask: Tensor = None,
                **kwargs):
        """
        Transformer decoder forward pass.

        :param trg_embed: embedded targets
        :param encoder_output: source representations
        :param encoder_hidden: unused
        :param src_mask: to mask out source paddings
        :param unroll_steps: unused
        :param hidden: unused
        :param trg_mask: to mask out target paddings
                         Note that a subsequent mask is applied here.
        :param kwargs: additional keyword arguments (unused)
        :return: output (decoder logits), x (pre-output hidden states),
            None, None
        """
        assert trg_mask is not None, "trg_mask required for Transformer"

        x = self.pe(trg_embed)  # add position encoding to word embedding
        x = self.emb_dropout(x)

        trg_mask = trg_mask & subsequent_mask(
            trg_embed.size(1)).type_as(trg_mask)

        # collect each layer's hidden states (not used in the return below)
        inbetween_layers = []
        for layer in self.layers:
            x = layer(x=x,
                      memory=encoder_output,
                      src_mask=src_mask,
                      trg_mask=trg_mask)
            inbetween_layers.append(x)

        x = self.layer_norm(x)
        output = self.output_layer(x)
        return output, x, None, None
Example #5
    def _forward_pre_output(self,
                            trg_embed: Tensor = None,
                            encoder_output: Tensor = None,
                            src_mask: Tensor = None,
                            trg_mask: Tensor = None,
                            **kwargs):
        """
        Transformer decoder forward pass.

        :param trg_embed: embedded targets
        :param encoder_output: source representations
        :param src_mask: to mask out source paddings
        :param trg_mask: to mask out target paddings
                         Note that a subsequent mask is applied here.
        :return: x (pre-output hidden states), attn dict with "trg_trg" and
            "src_trg" context attn probs (batch x layer x head x tgt x src), None
        """
        assert trg_mask is not None, "trg_mask is required for Transformer"

        x = self.pe(trg_embed)  # add position encoding to word embedding
        x = self.emb_dropout(x)

        trg_mask = trg_mask & subsequent_mask(
            trg_embed.size(1)).type_as(trg_mask)

        self_attn_layers = []
        context_attn_layers = []
        for layer in self.layers:
            x, self_attn, ctx_attn = layer(x=x,
                                           memory=encoder_output,
                                           src_mask=src_mask,
                                           trg_mask=trg_mask)
            self_attn_layers.append(self_attn)
            context_attn_layers.append(ctx_attn)

        x = self.layer_norm(x)

        trg_trg_attn = torch.stack(self_attn_layers, dim=1)
        src_trg_attn = torch.stack(context_attn_layers, dim=1)
        attn = {"trg_trg": trg_trg_attn, "src_trg": src_trg_attn}

        return x, attn, None
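
The attention dictionary returned here is convenient for inspection or visualization; for example, the per-layer, per-head probabilities can be collapsed into a single source-target alignment matrix per sentence. A sketch with random stand-in values following the batch x layer x head x tgt x src convention from the docstring:

import torch

batch, layers, heads, tgt_len, src_len = 2, 4, 8, 5, 7
# random stand-in for attn["src_trg"]: batch x layer x head x tgt x src
src_trg_attn = torch.softmax(torch.randn(batch, layers, heads, tgt_len, src_len), dim=-1)

# average over the layer and head dimensions -> one alignment matrix per sentence
alignments = src_trg_attn.mean(dim=(1, 2))   # batch x tgt x src
print(alignments.shape)                      # torch.Size([2, 5, 7])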