Example #1
0
class TransformerEncoder(FairseqEncoder):
    """
    Stack of *args.encoder_layers* :class:`TransformerEncoderLayer` modules
    applied to scaled token embeddings (plus optional positional embeddings).

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        # stored with the weights; upgrade_state_dict_named() looks up the
        # matching "<name>.version" entry when loading a checkpoint
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.padding_idx = embed_tokens.padding_idx

        embed_dim = embed_tokens.embedding_dim
        self.embed_tokens = embed_tokens
        # token embeddings are multiplied by sqrt(embed_dim) in forward()
        self.embed_scale = math.sqrt(embed_dim)

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions, embed_dim, self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )

        # optional normalization applied after the last layer
        self.layer_norm = (
            LayerNorm(embed_dim) if args.encoder_normalize_before else None
        )

    def forward(self, src_tokens, src_lengths):
        """
        Encode a batch of source sentences.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not read by this method

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or ``None``
                  when the batch contains no padding
        """
        # scaled token embeddings, optionally offset by positional embeddings
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # the mask is dropped entirely when no position is padding
        pad_positions = src_tokens.eq(self.padding_idx)
        encoder_padding_mask = pad_positions if pad_positions.any() else None

        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.layer_norm:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        The dict is updated in place: ``encoder_out`` is re-indexed along
        dimension 1 and ``encoder_padding_mask`` along dimension 0.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        hidden = encoder_out['encoder_out']
        if hidden is not None:
            encoder_out['encoder_out'] = hidden.index_select(1, new_order)

        pad_mask = encoder_out['encoder_padding_mask']
        if pad_mask is not None:
            encoder_out['encoder_padding_mask'] = \
                pad_mask.index_select(0, new_order)

        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        limit = self.max_source_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions())
        return limit

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # discard any stored sinusoidal table and install a placeholder
            # `_float_tensor` entry instead
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)

        # let every layer upgrade its own entries (layer norms)
        for idx, layer in enumerate(self.layers):
            layer.upgrade_state_dict_named(state_dict, f"{name}.layers.{idx}")

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #2
0
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers*
    :class:`TransformerEncoderLayer` layers, followed by an optional
    aggregation step (selected by *args.agg_method*) that combines the final
    layer's output with the outputs of the *args.agg_layers* layers before it.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.padding_idx = embed_tokens.padding_idx

        # aggregation settings: forward() handles 'add', 'fc' and 'attn';
        # any other value leaves the last layer's output untouched
        self.agg_method = args.agg_method
        self.agg_layers = args.agg_layers

        embed_dim = embed_tokens.embedding_dim
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )

        # attention used by the 'fc' and 'attn' aggregation modes
        self.attn = MultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
        # projects the concatenation of `agg_layers` hidden states ('fc' mode)
        self.fc = Linear(args.agg_layers * embed_dim, embed_dim)

        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))

        act_dropout = getattr(args, 'activation_dropout', 0)
        if act_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            act_dropout = getattr(args, 'relu_dropout', 0)
        self.activation_dropout = act_dropout

        self.layer_norm = (
            LayerNorm(embed_dim) if args.encoder_normalize_before else None
        )

    def forward(self, src_tokens, src_lengths):
        """
        Encode a batch of source sentences and aggregate layer outputs.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not read by this method

        Returns:
            dict:
                - **encoder_out** (Tensor): the aggregated encoder output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or ``None``
                  when the batch contains no padding
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # the mask is dropped entirely when no position is padding
        pad_positions = src_tokens.eq(self.padding_idx)
        encoder_padding_mask = pad_positions if pad_positions.any() else None

        # run the stack, remembering every layer's output
        layer_outputs = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            layer_outputs.append(x)

        # the (up to) `agg_layers` outputs immediately before the final one;
        # the final output itself stays in `x`
        history = layer_outputs[-1 - self.agg_layers:-1]

        method = self.agg_method
        if method == 'add':
            # sum the earlier outputs into the final output (in place)
            for past in history:
                x += past
        elif method == 'fc':
            # fuse the earlier outputs with a projection, then use the result
            # as the query over the final output
            fused = self.activation_fn(self.fc(torch.cat(history, 2)))
            fused = F.dropout(
                fused, p=self.activation_dropout, training=self.training)
            x, _ = self.attn(
                query=fused,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
            )
        elif method == 'attn':
            # chain attention through the earlier outputs, oldest first, and
            # finish by attending over the final output
            summary = history[0]
            for later in history[1:]:
                summary, _ = self.attn(
                    query=summary,
                    key=later,
                    value=later,
                    key_padding_mask=encoder_padding_mask,
                )
                summary = F.dropout(
                    summary, p=self.dropout, training=self.training)
            x, _ = self.attn(
                query=summary,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
            )

        x = F.dropout(x, p=self.dropout, training=self.training)
        if self.layer_norm:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        The dict is updated in place: ``encoder_out`` is re-indexed along
        dimension 1 and ``encoder_padding_mask`` along dimension 0.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        hidden = encoder_out['encoder_out']
        if hidden is not None:
            encoder_out['encoder_out'] = hidden.index_select(1, new_order)

        pad_mask = encoder_out['encoder_padding_mask']
        if pad_mask is not None:
            encoder_out['encoder_padding_mask'] = \
                pad_mask.index_select(0, new_order)

        return encoder_out

    def reorder_encoder_input(self, encoder_input, new_order):
        """
        Reorder encoder input according to *new_order*.

        The dict is updated in place: ``src_tokens`` is re-indexed along
        dimension 0.

        Args:
            encoder_input: dict holding a ``'src_tokens'`` entry
            new_order (LongTensor): desired order

        Returns:
            *encoder_input* rearranged according to *new_order*
        """
        tokens = encoder_input['src_tokens']
        if tokens is not None:
            encoder_input['src_tokens'] = tokens.index_select(0, new_order)
        return encoder_input

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        limit = self.max_source_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions())
        return limit

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # discard any stored sinusoidal table and install a placeholder
            # `_float_tensor` entry instead
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            float_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[float_key] = torch.FloatTensor(1)

        # let every layer upgrade its own entries (layer norms)
        for idx, layer in enumerate(self.layers):
            layer.upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, idx))

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
class TransformerEncoderAug(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`. The input embedding is augmented
    with per-position scores supplied by a source-side language model.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        lm: source language model; ``forward()`` reads its
            ``embed_tokens.weight``
    """

    def __init__(self, args, dictionary, embed_tokens,lm):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))
        self.src_lm = lm
        # probability with which a source token is swapped for its
        # language-model embedding during training
        self.acl_drop = args.sca_drop

        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, embed_dim, self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args)
            for i in range(args.encoder_layers)
        ])

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

        # merges the token embedding with the language-model embedding
        self.ln1 = Linear(2 *embed_dim, embed_dim)
        # NOTE(review): ln2 and ln3 are not referenced anywhere else in this
        # class; they are left in place so the parameter names stored in
        # existing checkpoints still match. ln2 alone holds
        # 2 * len(dictionary) * len(dictionary) weights.
        self.ln2 = Linear(2*len(dictionary), len(dictionary))
        self.ln3 = Linear(embed_dim, embed_dim,bias = False)

    def forward(self, src_tokens, src_tokens_lm, src_lengths ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_tokens_lm (Tensor): language-model scores for every source
                position. They are multiplied with embedding matrices, so
                the last dimension is presumably the vocabulary size, i.e.
                `(batch, src_len, vocab)` -- TODO confirm against the caller
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not read by this method

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or ``None``
                  when the batch contains no padding
        """
        # compute padding mask (dropped entirely when nothing is padded)
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        src_lm = src_tokens_lm
        if self.training:
            prop = self.acl_drop
            # 1 where the original token is kept, 0 where it is replaced;
            # sampled on the CPU and then moved next to the tensor it
            # multiplies, so the module is not tied to the current CUDA device
            keep_word = (torch.rand(src_tokens.size()) > prop).long()
            # complementary gate: 1 where the language-model input is used
            use_lm = keep_word.eq(0).float().unsqueeze(dim=-1)

            lm_ = src_lm * use_lm.to(src_lm.device)
            # replaced positions become token index 0
            x_ = src_tokens * keep_word.to(src_tokens.device)

            x1 = self.embed_tokens(x_)
            # equals `lm_ @ embed_tokens.weight`: the gated scores projected
            # through the encoder's own embedding matrix
            x2 = F.linear(lm_, self.embed_tokens.weight.t())
            x = self.embed_scale * (x1 + x2)
        else:
            x = self.embed_scale * self.embed_tokens(src_tokens)

        # language-model scores projected through the LM's embedding matrix
        x3 = F.linear(src_tokens_lm, self.src_lm.embed_tokens.weight.t())
        x3 = F.dropout(x3, p=0.1, training=self.training)

        # fuse both views back down to embed_dim
        x = self.ln1(torch.cat((x, x3), dim=-1))

        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.layer_norm:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        The dict is updated in place: ``encoder_out`` is re-indexed along
        dimension 1 and ``encoder_padding_mask`` along dimension 0.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out['encoder_out'] is not None:
            encoder_out['encoder_out'] = \
                encoder_out['encoder_out'].index_select(1, new_order)
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # discard any stored sinusoidal table and install a placeholder
            # `_float_tensor` entry instead
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
        for i, layer in enumerate(self.layers):
            # update layer norms
            layer.upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #4
0
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        # threshold used by forward() to skip whole layers during training
        self.encoder_layerdrop = args.encoder_layerdrop

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(
            embed_dim)

        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            embed_dim,
            self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        # when set, forward() always collects every layer's hidden state
        self.layer_wise_attention = getattr(args, 'layer_wise_attention',
                                            False)

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args, layer_id=i)
            for i in range(args.encoder_layers)
        ])

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None
        if getattr(args, 'layernorm_embedding', False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

    def forward_embedding(self, src_tokens):
        """
        Embed *src_tokens* and apply positions, layer norm and dropout.

        Args:
            src_tokens (LongTensor): tokens of shape `(batch, src_len)`

        Returns:
            tuple:
                - the encoder input after positional embedding, optional
                  embedding layer norm and dropout
                - the (scaled) token embedding lookup on its own
        """
        # embed tokens and positions; `x` starts out as the plain scaled
        # embedding so it is defined even when positional embeddings are
        # disabled
        x = embed = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x, embed

    def forward(self,
                src_tokens,
                src_lengths,
                cls_input=None,
                return_all_hiddens=False,
                **unused):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not read by this method
            cls_input (optional): accepted but not read by this method
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False). Forced to True
                when the encoder was built with *layer_wise_attention*.

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or ``None``
                  when the batch contains no padding
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True; layers
                  skipped by LayerDrop contribute no entry.
        """
        if self.layer_wise_attention:
            return_all_hiddens = True

        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if not self.training or (dropout_probability >
                                     self.encoder_layerdrop):
                x = layer(x, encoder_padding_mask)
                if return_all_hiddens:
                    encoder_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)
            # keep the last recorded state in sync with the normalized output;
            # the list is empty when LayerDrop skipped every layer
            if return_all_hiddens and encoder_states:
                encoder_states[-1] = x

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Tensor fields are replaced on a new namedtuple; the
        ``encoder_states`` list, when present, is updated in place.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out.encoder_out is not None:
            encoder_out = encoder_out._replace(
                encoder_out=encoder_out.encoder_out.index_select(1, new_order))
        if encoder_out.encoder_padding_mask is not None:
            encoder_out = encoder_out._replace(
                encoder_padding_mask=encoder_out.encoder_padding_mask.
                index_select(0, new_order))
        if encoder_out.encoder_embedding is not None:
            encoder_out = encoder_out._replace(
                encoder_embedding=encoder_out.encoder_embedding.index_select(
                    0, new_order))
        if encoder_out.encoder_states is not None:
            for idx, state in enumerate(encoder_out.encoder_states):
                encoder_out.encoder_states[idx] = state.index_select(
                    1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions,
                   self.embed_positions.max_positions())

    def buffered_future_mask(self, tensor):
        """
        Return a `(dim, dim)` mask with ``-inf`` strictly above the diagonal,
        where ``dim = tensor.size(0)``.

        The mask is cached on the module and rebuilt whenever it is missing,
        lives on a different device than *tensor*, or is smaller than the
        requested size.
        """
        dim = tensor.size(0)
        cached = getattr(self, '_future_mask', None)
        if (cached is None or cached.device != tensor.device
                or cached.size(0) < dim):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # discard any stored sinusoidal table and install a placeholder
            # `_float_tensor` entry instead
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                print('deleting {0}'.format(weights_key))
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(
                name)] = torch.FloatTensor(1)
        for i, layer in enumerate(self.layers):
            # update layer norms
            layer.upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i))

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #5
0
class TransformerAvgEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers whose text
    states are merged with pretrained image features. Each layer is a
    :class:`TransformerEncoderLayer`.

    For every source sentence, image ids are looked up per token in the
    ``cap2image`` table, the most frequent ids are kept, their embeddings are
    projected to the text dimension and merged with the encoder states
    according to *args.merge_option*.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    # Values of *args.merge_option* that forward() knows how to apply.
    _SUPPORTED_MERGE_OPTIONS = ("biatt", "att", "att-sum", "att-gate")

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, embed_dim, self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args)
            for _ in range(args.encoder_layers)
        ])

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

        # ---- image section ----
        self.text_dim = embed_dim
        self.L2norm = args.L2norm
        self.total_num_img = args.total_num_img
        self.per_num_img = args.per_num_img

        cap2image_file = getattr(args, "cap2image_file", "data/cap2image.pickle")
        image_embedding_file = getattr(args, "image_embedding_file", "features_resnet50/train-resnet50-avgpool.npy")

        # NOTE: unpickling can execute arbitrary code; only point
        # cap2image_file at trusted data.
        # The context manager closes the handle (it used to be leaked).
        with open(cap2image_file, "rb") as cap2image_handle:
            self.cap2image = pickle.load(cap2image_handle)  # cap_id to image_id

        embedding_weights = np.load(image_embedding_file)
        img_vocab, img_dim = embedding_weights.shape
        # The feature width comes from the file instead of being assumed to
        # be 2048, so features of another width also work.
        self.img_dim = img_dim
        # Row 0 stays all-zero for the padding image id 0; the stored
        # features occupy rows 1..img_vocab.
        embeddings_matrix = np.zeros((img_vocab + 1, img_dim))
        embeddings_matrix[1:] = embedding_weights
        self.img_embeddings = nn.Embedding.from_pretrained(
            torch.FloatTensor(embeddings_matrix),
            freeze=args.image_emb_fix)

        self.merge_option = args.merge_option
        self.dense = nn.Linear(self.img_dim, self.text_dim)

        # Not used by forward(); kept so existing checkpoints still load.
        self.mergeImage = nn.Linear(self.total_num_img, 1)
        if self.merge_option == "att-mul-concat":
            self.proj_attention = SCAttention(self.text_dim, 128)
            self.dense2 = nn.Linear(self.text_dim, 384)
        elif self.merge_option == "att-concat":
            self.dense2 = nn.Linear(2 * self.text_dim, self.text_dim)
        elif self.merge_option == "att-gate":
            self.gate_type = args.gate_type
            self.proj_attention = SCAttention(self.text_dim, self.text_dim)
            if self.gate_type == "neural-gate":
                self.sigmoid = nn.Sigmoid()
                self.gate_dense = nn.Linear(2 * self.text_dim, self.text_dim)
            elif self.gate_type == "scalar-gate":
                self.sigmoid = nn.Sigmoid()
                self.gate_dense = nn.Linear(2 * self.text_dim, 1)
            else:
                # any other gate type mixes with a fixed scalar weight
                self.image_weight = args.image_weight
        else:
            self.proj_attention = SCAttention(self.text_dim, self.text_dim)

    def _retrieve_image_ids(self, src_tokens):
        """Build the ``(batch, total_num_img)`` tensor of image ids for a batch.

        Per sentence: collect up to ``per_num_img`` non-zero image ids for each
        token found in ``cap2image``, keep the ``total_num_img`` most frequent
        ones and right-pad with 0.
        """
        rows = []
        for sentence in src_tokens.tolist():
            candidates = []
            for token in sentence:
                if token in self.cap2image:
                    for image_id in self.cap2image[token][:self.per_num_img]:
                        if image_id != 0:
                            candidates.append(image_id)
            # most_common orders by count, ties in first-seen order
            ranked = [
                image_id
                for image_id, _ in Counter(candidates).most_common(self.total_num_img)
            ]
            ranked.extend([0] * (self.total_num_img - len(ranked)))
            rows.append(ranked)
        # Follow the input's device rather than "cuda if available", so a
        # model kept on CPU still works on a machine that has a GPU.
        return torch.tensor(rows, dtype=torch.long, device=src_tokens.device)

    def _merge_modalities(self, text_repr, text_mask, image_repr, image_mask):
        """Combine text and image representations according to ``merge_option``.

        Returns a tensor with the same leading layout as ``text_repr``
        (batch_size, seq_len, dim).
        """
        option = self.merge_option
        if option not in self._SUPPORTED_MERGE_OPTIONS:
            # Used to surface as an AttributeError on ``None.transpose``.
            raise NotImplementedError(
                "merge_option {!r} is not implemented in forward(); "
                "supported options: {}".format(
                    option, ", ".join(self._SUPPORTED_MERGE_OPTIONS)))

        if option == "biatt":
            image_ctx = self.proj_attention(image_repr, image_mask, text_repr, text_mask)
            return self.proj_attention(text_repr, text_mask, image_ctx, image_mask)

        attended = self.proj_attention(text_repr, text_mask, image_repr, image_mask)
        if option == "att":
            return attended
        if option == "att-sum":
            return text_repr + attended

        # option == "att-gate"
        if self.gate_type in ("neural-gate", "scalar-gate"):
            # both gate types share the formula; only gate_dense's width differs
            gate = self.sigmoid(self.gate_dense(torch.cat([text_repr, attended], dim=-1)))
            return (1 - gate) * text_repr + gate * attended
        return 1.0 * text_repr + self.image_weight * attended

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not used by this encoder

        Returns:
            dict:
                - **encoder_out** (Tensor): the merged text/image output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or None when
                  the batch has no padding

        Raises:
            NotImplementedError: if *merge_option* is not one of
                "biatt", "att", "att-sum", "att-gate".
        """
        batch_image_ids = self._retrieve_image_ids(src_tokens)  # batch_size x num_img
        batch_size, num_img = batch_image_ids.size()
        image_mask = ~batch_image_ids.eq(0)  # True where a real image id is present

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask; text_mask keeps the full mask even when the
        # padding mask is dropped below
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        text_mask = ~encoder_padding_mask
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        image_embedding = self.img_embeddings(batch_image_ids)
        image_embedding = image_embedding.view(batch_size, num_img, self.img_dim)  # batch_size, num_img, dim

        if self.L2norm == "true":
            # NOTE(review): dim=1 normalizes across the image axis, not the
            # feature axis -- confirm this is intended.
            image_embedding = F.normalize(image_embedding, p=2, dim=1)

        text_repr = x.transpose(0, 1)  # T x B x C -> batch_size, seq_len, dim
        image_repr = self.dense(image_embedding)  # image_dim -> text_dim

        output = self._merge_modalities(text_repr, text_mask, image_repr, image_mask)
        x = output.transpose(0, 1)  # batch_size, seq_len, dim -> T x B x C
        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out['encoder_out'] is not None:
            encoder_out['encoder_out'] = \
                encoder_out['encoder_out'].index_select(1, new_order)
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
        for i, layer in enumerate(self.layers):
            # update layer norms
            layer.upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i))

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #6
0
class JointAttentionEncoder(FairseqEncoder):
    """
    Embedding-only encoder for joint attention models.

    No encoder layers are applied: the output is the scaled token embedding
    plus the optional positional and language embeddings, followed by dropout.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        left_pad (bool): whether the input is left-padded (accepted but not
            read by this constructor)
    """
    def __init__(self, args, dictionary, embed_tokens, left_pad):
        super().__init__(dictionary)
        self.dropout = args.dropout

        width = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(width)

        # positional embedding is optional
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                width,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        # language embedding is optional
        if args.language_embeddings:
            self.embed_language = LanguageEmbedding(width)
        else:
            self.embed_language = None

        self.register_buffer('version', torch.Tensor([2]))

    def forward(self, src_tokens, src_lengths):
        """
        Embed the source tokens; no encoder layers are applied.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not read here

        Returns:
            dict:
                - **encoder_out** (Tensor): embedding output of shape
                  `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or None when
                  nothing is padded
        """
        # scaled token embedding, then the optional additive terms
        embedded = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            embedded += self.embed_positions(src_tokens)
        if self.embed_language is not None:
            embedded += self.embed_scale * self.embed_language.view(1, 1, -1)
        embedded = F.dropout(embedded, p=self.dropout, training=self.training)

        # the mask is only reported when at least one position is padding
        pad_mask = src_tokens.eq(self.padding_idx)
        if not pad_mask.any():
            pad_mask = None

        return {
            'encoder_out': embedded.transpose(0, 1),  # B x T x C -> T x B x C
            'encoder_padding_mask': pad_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Rearrange the batch dimension of *encoder_out* to follow *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # (key, batch dimension) for every entry that has to be reordered
        batch_dims = (('encoder_out', 1), ('encoder_padding_mask', 0))
        for key, batch_dim in batch_dims:
            value = encoder_out[key]
            if value is not None:
                encoder_out[key] = value.index_select(batch_dim, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        limit = self.max_source_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions())
        return limit
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding or None): input embedding; when None
            the encoder consumes precomputed encodings (see ``forward``)
        max_source_positions (int): maximum supported source length
        encoder_layers (int): number of encoder layers
        encoder_embed_dim (int): embedding width, also used as the model
            width when *embed_tokens* is None
        encoder_attention_heads (int): passed to each encoder layer
        encoder_ffn_embed_dim (int): passed to each encoder layer
    """

    def __init__(self, args, dictionary, embed_tokens, max_source_positions,
                 encoder_layers, encoder_embed_dim,
                 encoder_attention_heads, encoder_ffn_embed_dim,):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout

        if embed_tokens is None:
            self.padding_idx = 0
            embed_dim = encoder_embed_dim
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.max_source_positions = max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            max_source_positions, embed_dim, self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(encoder_embed_dim,
                 encoder_attention_heads, encoder_ffn_embed_dim, args)
            for _ in range(encoder_layers)
        ])

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

        self.use_seg_pos_emb = getattr(args, 'use_seg_pos_emb', 0)
        if self.use_seg_pos_emb:
            # ids 0 and 1 are passed by forward() as seg_pos; 2 marks padding
            self.seg_pad_idx = 2
            self.seg_pos_emb = Embedding(3, embed_dim, padding_idx=self.seg_pad_idx)

    def forward_embedding(self, src_tokens, seg_pos=-1):
        """Embed *src_tokens* and add positional / segment information.

        Args:
            src_tokens (LongTensor): tokens of shape `(batch, src_len)`
            seg_pos (int): segment id assigned to every non-padding token
                when segment embeddings are enabled

        Returns:
            tuple: ``(x, embed)`` where *embed* is the scaled token embedding
            and *x* is *embed* plus the optional positional and segment
            embeddings, after dropout.
        """
        # embed tokens and positions
        embed = self.embed_scale * self.embed_tokens(src_tokens)
        # Start from the plain embedding so ``x`` is always bound, even when
        # positional embeddings are disabled (this used to raise
        # UnboundLocalError).
        x = embed
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.use_seg_pos_emb:
            # non-padding tokens get seg_pos, padding tokens get seg_pad_idx
            masked_src_tokens = src_tokens.masked_fill(src_tokens.ne(self.padding_idx), seg_pos)
            masked_src_tokens = masked_src_tokens.masked_fill(src_tokens.eq(self.padding_idx), self.seg_pad_idx)
            x = x + self.embed_scale * self.seg_pos_emb(masked_src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x, embed

    def forward(self, src_tokens=None, cls_input=None, return_all_hiddens=False, src_encodings=None,
                encoder_padding_mask=None, attn_mask=None, auxilary_tokens=None):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`; used when the encoder has *embed_tokens*
            cls_input: accepted but not used
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False). Forced to True
                when layer-wise attention is enabled.
            src_encodings (torch.FloatTensor): shape of `(T x B x C)`; used
                when the encoder has no *embed_tokens*
            encoder_padding_mask (torch.Boolean): shape of '(B x T)', where
                paddings are True; required when the encoder has no
                *embed_tokens*, recomputed from the tokens otherwise
            attn_mask: forwarded to every encoder layer
            auxilary_tokens (LongTensor, optional): extra tokens of shape
                `(batch, aux_len)` embedded with segment id 1 and appended to
                the source along the time axis; skipped when None

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the scaled source token
                  embedding, or None without *embed_tokens*
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        if self.layer_wise_attention:
            return_all_hiddens = True

        if self.embed_tokens is not None:
            x, encoder_embedding = self.forward_embedding(src_tokens, seg_pos=0)
            # auxilary_tokens defaults to None; only concatenate when given
            # (embedding None used to crash).
            if auxilary_tokens is not None:
                aug_x, _ = self.forward_embedding(auxilary_tokens, seg_pos=1)
                x = torch.cat([x, aug_x], dim=1)
                src_tokens = torch.cat([src_tokens, auxilary_tokens], dim=1)
            # B x T x C -> T x B x C
            x = x.transpose(0, 1)
            # compute padding mask; bos positions are masked like padding
            encoder_padding_mask = (src_tokens.eq(self.padding_idx) | src_tokens.eq(self.dictionary.bos_index))
        else:
            assert encoder_padding_mask is not None
            # NOTE(review): the mask (1 = padding) is used as pseudo tokens
            # while padding_idx is 0 here -- confirm the positional
            # embedding treats the intended positions as padding.
            src_tokens = encoder_padding_mask.long()
            encoder_embedding = None
            x = src_encodings
            if self.embed_positions is not None:
                x = x + self.embed_positions(src_tokens).transpose(0, 1)
            x = F.dropout(x, p=self.dropout, training=self.training)

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask, attn_mask=attn_mask)
            if return_all_hiddens:
                encoder_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)
            if return_all_hiddens:
                # the last recorded state is replaced by its normalized version
                encoder_states[-1] = x

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
            'encoder_embedding': encoder_embedding,  # B x T x C
            'encoder_states': encoder_states,  # List[T x B x C]
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out['encoder_out'] is not None:
            encoder_out['encoder_out'] = \
                encoder_out['encoder_out'].index_select(1, new_order)
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(0, new_order)
        # The embedding is batch-first (B x T x C); reorder it too so it does
        # not go stale relative to the other entries.
        if encoder_out.get('encoder_embedding', None) is not None:
            encoder_out['encoder_embedding'] = \
                encoder_out['encoder_embedding'].index_select(0, new_order)
        if encoder_out.get('encoder_states', None) is not None:
            for idx, state in enumerate(encoder_out['encoder_states']):
                encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())

    def buffered_future_mask(self, tensor):
        """Return a cached square future mask of side ``tensor.size(0)``.

        The mask is rebuilt when it is missing, on another device than
        *tensor*, or smaller than requested. The size check used to be nested
        inside the rebuild branch, where it could never fire, so a smaller
        cached mask was returned with the wrong shape.
        """
        dim = tensor.size(0)
        cached = getattr(self, '_future_mask', None)
        if cached is None or cached.device != tensor.device or cached.size(0) < dim:
            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
        for i, layer in enumerate(self.layers):
            # update layer norms
            layer.upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i))

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #8
0
class TaLKConvEncoder(FairseqEncoder):
    """
    Encoder built from *args.encoder_layers* :class:`TaLKConvEncoderLayer`
    layers, one kernel size per layer taken from
    *args.encoder_kernel_size_list*.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout = args.dropout

        width = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(width)
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                width,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        # layer number ``depth`` uses the ``depth``-th configured kernel size
        stack = []
        for depth in range(args.encoder_layers):
            kernel = args.encoder_kernel_size_list[depth]
            stack.append(TaLKConvEncoderLayer(args, kernel_size=kernel))
        self.layers = nn.ModuleList(stack)

        self.register_buffer('version', torch.Tensor([2]))
        # the final layer norm only exists when normalize-before is enabled
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(width)

        self.acts_reg = []

    def forward(self, src_tokens, **unused):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or None when
                  nothing is padded
        """
        # scaled token embedding plus optional positions, then dropout
        hidden = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            hidden += self.embed_positions(src_tokens)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        hidden = hidden.transpose(0, 1)

        # the mask is dropped entirely when the batch has no padding
        pad_mask = src_tokens.eq(self.padding_idx)
        if not pad_mask.any():
            pad_mask = None

        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, pad_mask)

        if self.normalize:
            hidden = self.layer_norm(hidden)

        return {
            'encoder_out': hidden,  # T x B x C
            'encoder_padding_mask': pad_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Rearrange the batch dimension of *encoder_out* to follow *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        states = encoder_out['encoder_out']
        if states is not None:
            # T x B x C: batch is dimension 1
            encoder_out['encoder_out'] = states.index_select(1, new_order)
        pad_mask = encoder_out['encoder_padding_mask']
        if pad_mask is not None:
            # B x T: batch is dimension 0
            encoder_out['encoder_padding_mask'] = pad_mask.index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is not None:
            return min(self.max_source_positions,
                       self.embed_positions.max_positions())
        return self.max_source_positions
Example #9
0
class LightConvEncoder(FairseqEncoder):
    """
    LightConv encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`LightConvEncoderLayer`. The token embeddings are first passed
    through a 4-layer unidirectional GRU before the convolutional stack.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout = args.dropout

        width = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions
        self.encoder_embed_dim = args.encoder_embed_dim

        # recurrent pre-encoder applied to the raw token embeddings
        # (batch-first, input and hidden size both encoder_embed_dim)
        self.rnn_layer = torch.nn.GRU(
            args.encoder_embed_dim,
            args.encoder_embed_dim,
            num_layers=4,
            dropout=0.1,
            batch_first=True,
            bidirectional=False,
        )

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(width)
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                width,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        # layer number ``depth`` uses the ``depth``-th configured kernel size
        conv_stack = [
            LightConvEncoderLayer(
                args, kernel_size=args.encoder_kernel_size_list[depth])
            for depth in range(args.encoder_layers)
        ]
        self.layers = nn.ModuleList(conv_stack)

        self.register_buffer('version', torch.Tensor([2]))
        # the final layer norm only exists when normalize-before is enabled
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(width)

    def forward(self, src_tokens, **unused):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or None when
                  nothing is padded
        """
        # character-level setting: run the embeddings through the char-RNN
        # first, then scale its output
        token_embeddings = self.embed_tokens(src_tokens)
        recurrent_out, _ = self.rnn_layer(token_embeddings)
        hidden = self.embed_scale * recurrent_out
        if self.embed_positions is not None:
            hidden += self.embed_positions(src_tokens)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        hidden = hidden.transpose(0, 1)

        # the mask is dropped entirely when the batch has no padding
        pad_mask = src_tokens.eq(self.padding_idx)
        if not pad_mask.any():
            pad_mask = None

        # TODO(naetherm): consider a residual connection around each layer
        for conv_layer in self.layers:
            hidden = conv_layer(hidden, pad_mask)

        if self.normalize:
            hidden = self.layer_norm(hidden)

        return {
            'encoder_out': hidden,  # T x B x C
            'encoder_padding_mask': pad_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Rearrange the batch dimension of *encoder_out* to follow *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # (key, batch dimension) for every entry that has to be reordered
        for key, batch_dim in (('encoder_out', 1), ('encoder_padding_mask', 0)):
            entry = encoder_out[key]
            if entry is not None:
                encoder_out[key] = entry.index_select(batch_dim, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        limit = self.max_source_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions())
        return limit
Example #10
0
class TransformerEncoderC(nn.Module):
    """
    Controller-stream Transformer encoder.

    Stacks *args.encoder_layers* :class:`TransformerEncoderLayerC` layers on
    top of scaled token embeddings plus positional embeddings, and collects
    the two extra values every layer returns alongside its output.

    Args:
        args (argparse.Namespace): parsed command-line arguments; this class
            reads ``dropout``, ``max_source_positions``,
            ``encoder_learned_pos``, ``encoder_layers`` and
            ``encoder_normalize_before``
        vocab: vocabulary object; stored as ``self.vocab`` and not otherwise
            used by this class
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, vocab, embed_tokens):
        super().__init__()
        self.vocab = vocab
        self.dropout = args.dropout

        dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(dim)
        # Unlike the other encoders in this file, positional embeddings are
        # always created here (there is no switch to disable them).
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, dim, self.padding_idx,
            learned=args.encoder_learned_pos)

        self.layers = nn.ModuleList(
            [TransformerEncoderLayerC(args)
             for _ in range(args.encoder_layers)])

        # optional normalization applied after the last layer
        self.g_layer_norm = (LayerNorm(dim)
                             if args.encoder_normalize_before else None)

    def forward(self,
                src_tokens,
                encoder_mode="gumbel",
                encoder_temperature=-1,
                need_weights=False,
                **unused):
        """
        Encode *src_tokens* and gather the per-layer attention outputs.

        Args:
            src_tokens (LongTensor): source token indices; the shape comments
                in the body assume `(batch, src_len)`
            encoder_mode: passed through unchanged to every layer
            encoder_temperature: passed through unchanged to every layer
            need_weights: passed through unchanged to every layer

        Returns:
            dict:
                - **encoder_g**: output of the last layer, after the final
                  layer norm when one is configured
                - **encoder_attn** (list): second value returned by each
                  layer, in layer order
                - **encoder_padding_mask**: ``src_tokens == padding_idx``, or
                  ``None`` when no position is padding
                - **encoder_attn_data_list** (list): third value returned by
                  each layer, in layer order
        """
        # scaled token embeddings + positional embeddings, then dropout
        states = self.embed_scale * self.embed_tokens(src_tokens)
        states += self.embed_positions(src_tokens)
        states = F.dropout(states, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        g = states.transpose(0, 1)

        # drop the mask entirely when the batch contains no padding
        pad_positions = src_tokens.eq(self.padding_idx)
        encoder_padding_mask = pad_positions if pad_positions.any() else None

        attn_list, attn_data_list = [], []
        for layer in self.layers:
            g, layer_attn, layer_attn_data = layer(
                g,
                encoder_padding_mask=encoder_padding_mask,
                encoder_mode=encoder_mode,
                encoder_temperature=encoder_temperature,
                need_weights=need_weights)
            attn_list.append(layer_attn)
            attn_data_list.append(layer_attn_data)

        if self.g_layer_norm:
            g = self.g_layer_norm(g)

        return {
            'encoder_g': g,  # T x B x C
            'encoder_attn': attn_list,  # one entry per layer
            'encoder_padding_mask': encoder_padding_mask,  # B x T
            'encoder_attn_data_list': attn_data_list
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        positional = self.embed_positions
        source_limit = self.max_source_positions
        if positional is not None:
            return min(source_limit, positional.max_positions())
        return source_limit
Example #11
0
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    When *args.add_template* is set, a second stack of layers encodes a
    template sequence; the source states then attend over the template states
    and both are merged through a learned sigmoid gate.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            embed_dim,
            self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args) for _ in range(args.encoder_layers)
        ])

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

        self.add_template = args.add_template

        if self.add_template:
            # separate layer stack (same depth) for the template sequence
            self.template_layers = nn.ModuleList([])
            self.template_layers.extend([
                TransformerEncoderLayer(args)
                for _ in range(args.encoder_layers)
            ])

            if args.encoder_normalize_before:
                self.tp_layer_norm = LayerNorm(embed_dim)
            else:
                self.tp_layer_norm = None

            self.positionwise = PositionWise(embed_dim, embed_dim,
                                             self.dropout)

            # maps [template attention ; source states] to the gate logits
            self.two_encoder_mix = nn.Linear(2 * embed_dim, embed_dim)

            # source-to-template attention
            self.attention = MultiheadAttention(embed_dim,
                                                args.encoder_attention_heads,
                                                dropout=args.attention_dropout)

    def forward(self, src_tokens, src_lengths, template=None):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            template (LongTensor, optional): template tokens, embedded with
                the same token and positional embeddings as the source. Only
                read when the encoder was built with ``add_template``.

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **tp_mean** (Tensor or None): average over source positions
                  of the source-to-template attention output, with padded
                  source positions zeroed; ``None`` when the encoder was built
                  without ``add_template``
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # Only the template branch produces the attention output that
        # ``tp_mean`` summarizes; without it there is nothing to average.
        adj_att_mean = None

        if self.add_template:
            tp = self.embed_scale * self.embed_tokens(template)
            if self.embed_positions is not None:
                tp += self.embed_positions(template)
            tp = F.dropout(tp, p=self.dropout, training=self.training)

            # B x T x C -> T x B x C
            tp = tp.transpose(0, 1)

            # compute padding mask
            tp_encoder_padding_mask = template.eq(self.padding_idx)

            # encoder layers
            for layer in self.template_layers:
                tp = layer(tp, tp_encoder_padding_mask)

            if self.tp_layer_norm is not None:
                tp = self.tp_layer_norm(tp)

            # every source position attends over the template states
            adj_att, _ = self.attention(
                query=x,
                key=tp,
                value=tp,
                key_padding_mask=tp_encoder_padding_mask)
            adj_att = F.dropout(adj_att,
                                p=self.dropout,
                                training=self.training)

            # element-wise gate between template attention and source states
            adj_egd_cat = torch.cat([adj_att, x], dim=-1)
            two_encoder = self.two_encoder_mix(adj_egd_cat)
            gate = torch.sigmoid(two_encoder)
            output = gate.mul(adj_att) + (1 - gate).mul(x)

            x = self.positionwise(output)

            # Zero the padded source positions, then average over time.
            # The mask is cast to float *before* the subtraction because
            # ``1 - mask`` is rejected for bool tensors. As before, the mean
            # divides by the full source length, padding included.
            mean_mask = 1 - encoder_padding_mask.float()  # B x T
            mean_mask = mean_mask.transpose(0, 1).unsqueeze(2)  # T x B x 1
            adj_att_mean = torch.mean(mean_mask * adj_att, dim=0)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
            'tp_mean': adj_att_mean,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        The dictionary is updated in place and also returned.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out['encoder_out'] is not None:
            encoder_out['encoder_out'] = \
                encoder_out['encoder_out'].index_select(1, new_order)
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(0, new_order)
        # ``tp_mean`` is batch-first (time was averaged away in forward), so
        # it has to follow the same permutation as the other entries.
        if encoder_out.get('tp_mean') is not None:
            encoder_out['tp_mean'] = \
                encoder_out['tp_mean'].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions,
                   self.embed_positions.max_positions())

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # sinusoidal weights are not loaded from the checkpoint; only a
            # placeholder ``_float_tensor`` entry is kept
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(
                name)] = torch.FloatTensor(1)
        for i, layer in enumerate(self.layers):
            # update layer norms
            layer.upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i))

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
class SimMTTransformerMultiPassEncoder(FairseqEncoder):
    """
    Multi-pass encoder built from *args.encoder_layers*
    :class:`TransformerEncoderLayerOurs` layers.

    The source sentence is encoded several times in one batch; each pass sees
    a longer prefix of the source, with the first pass limited by
    *args.wait_k*.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout

        dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(dim)
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions, dim, self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayerOurs(args)
             for _ in range(args.encoder_layers)])

        # optional normalization applied after the last layer
        self.layer_norm = (LayerNorm(dim)
                           if args.encoder_normalize_before else None)

        self.wait_k = args.wait_k

    def forward(self, src_tokens, src_lengths):
        """
        Encode the source once per forward pass and flatten the visible
        states of all passes into one long sequence.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`

        Returns:
            dict:
                - **encoder_out** (Tensor): states of every (pass, position)
                  pair that is visible in its pass, of shape
                  `(T_final, batch, embed_dim)`
                - **encoder_padding_mask** (Tensor): padding mask gathered for
                  the same pairs, of shape `(batch, T_final)`
                - **forward_idx** (LongTensor): for each of the `T_final`
                  entries, the index of the pass it came from
                - **src_lengths**: the *src_lengths* argument, unchanged
        """
        # embed tokens and positions
        embedded = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            embedded += self.embed_positions(src_tokens)
        embedded = F.dropout(embedded, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = embedded.transpose(0, 1)

        # padding positions, B x T
        pad_mask = src_tokens.eq(self.padding_idx)

        src_len = x.shape[0]
        # T': number of forward passes
        num_passes = max(src_len - self.wait_k + 1, 1)

        # Row r of time_mask marks the positions hidden from pass r: those
        # with index greater than r + min(wait_k, src_len) - 1.
        steps = torch.arange(src_len)
        first_row = min(self.wait_k, src_len) - 1
        time_mask = (steps[None, :] > steps[:, None]).to(x.device)[first_row:, :]  # T' x T
        visible = ~time_mask

        # one copy of the padding mask per pass, with hidden positions
        # treated like padding
        pad_mask = pad_mask.unsqueeze(1).repeat(1, num_passes, 1)  # B x T' x T
        encoder_padding_mask = pad_mask | time_mask[None, :, :]  # B x T' x T
        encoder_padding_mask = encoder_padding_mask.view(-1, src_len)  # (B * T') x T

        # one copy of the features per pass, folded into the batch dimension
        x = x.unsqueeze(2).repeat(1, 1, num_passes, 1)  # T x B x T' x C
        x = x.view(src_len, -1, x.shape[-1])  # T x (B * T') x C

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.layer_norm:
            x = self.layer_norm(x)

        # unfold the passes again: T x B x T' x C  =>  T' x T x B x C
        x = x.view(src_len, -1, num_passes, x.shape[-1]).permute(2, 0, 1, 3)
        # keep only the (pass, position) pairs visible in their pass
        x = x[visible]  # Tfinal x B x C

        # Tfinal, indicating which token belongs to which forward
        pass_ids = torch.arange(num_passes).to(x.device)
        forward_idx = pass_ids.unsqueeze(1).repeat(1, src_len)[visible]

        encoder_padding_mask = encoder_padding_mask.view(
            -1, num_passes, src_len)[:, visible]

        return {
            'encoder_out': x,
            'encoder_padding_mask': encoder_padding_mask,
            'forward_idx': forward_idx,
            'src_lengths': src_lengths,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        The dictionary is updated in place and also returned. ``forward_idx``
        is left untouched; ``src_lengths`` is always reordered.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # (key, dimension that *new_order* indexes) for the optional entries
        for key, dim in (('encoder_out', 1), ('encoder_padding_mask', 0)):
            value = encoder_out[key]
            if value is not None:
                encoder_out[key] = value.index_select(dim, new_order)
        lengths = encoder_out['src_lengths']
        encoder_out['src_lengths'] = lengths.index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        positional = self.embed_positions
        if positional is not None:
            return min(self.max_source_positions, positional.max_positions())
        return self.max_source_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # drop stored sinusoidal weights, keep only a placeholder tensor
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            placeholder_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[placeholder_key] = torch.FloatTensor(1)

        # update layer norms
        for idx, layer in enumerate(self.layers):
            layer.upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, idx))

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #13
0
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder made of *args.decoder_layers*
    :class:`TransformerDecoderLayer` layers, with an optional single-head
    copy attention over the encoder output (enabled by *args.copy*).

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        # todo: try with input_embed_dim
        self.embed_scale = math.sqrt(embed_dim)

        # calculate copy probability p(z=1) batch
        self.copy = args.copy

        # only project when the embedding width differs from the model width
        if embed_dim != input_embed_dim:
            self.project_in_dim = Linear(input_embed_dim, embed_dim,
                                         bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )

        self.layers = nn.ModuleList(
            [TransformerDecoderLayer(args, no_encoder_attn)
             for _ in range(args.decoder_layers)])

        if self.copy:
            # single-head attention over the encoder output
            self.copy_attn = MultiheadAttention(
                embed_dim,
                1,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            # scalar copy-gate logit per target position
            self.linear_copy = Linear(embed_dim, 1)

        self.adaptive_softmax = None

        if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights:
            self.project_out_dim = Linear(embed_dim, self.output_embed_dim,
                                          bias=False)
        else:
            self.project_out_dim = None

        if args.adaptive_softmax_cutoff is not None:
            tied_inputs = embed_tokens if args.tie_adaptive_weights else None
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=tied_inputs,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # dedicated output projection matrix
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim))
            nn.init.normal_(self.embed_out,
                            mean=0,
                            std=self.output_embed_dim**-0.5)

        use_final_norm = args.decoder_normalize_before and not getattr(
            args, 'no_decoder_final_norm', False)
        self.layer_norm = LayerNorm(embed_dim) if use_final_norm else None

    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None,
                **unused):
        """
        Run the decoder and project its features with ``output_layer``.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        features, extra = self.extract_features(prev_output_tokens,
                                                encoder_out,
                                                incremental_state)
        return self.output_layer(features), extra

    def extract_features(self,
                         prev_output_tokens,
                         encoder_out=None,
                         incremental_state=None,
                         **unused):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with the keys ``attn`` (attention returned by
                  the last layer), ``inner_states`` (input to the first layer
                  followed by every layer's output), ``p_copy`` and
                  ``copy_attn`` (both ``None`` unless copying is enabled)
        """
        # embed positions
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
        else:
            positions = None

        # in incremental mode only the newest step is fed through the layers
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if encoder_out is not None:
            enc_states = encoder_out['encoder_out']
            enc_pad_mask = encoder_out['encoder_padding_mask']
        else:
            enc_states = enc_pad_mask = None

        attn = None
        inner_states = [x]
        # decoder layers
        for layer in self.layers:
            # the future mask is only needed when whole sequences are decoded
            if incremental_state is None:
                future_mask = self.buffered_future_mask(x)
            else:
                future_mask = None
            x, attn = layer(
                x,
                enc_states,
                enc_pad_mask,
                incremental_state,
                self_attn_mask=future_mask,
            )
            inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        p_copy, copy_attn = None, None
        if self.copy:
            # only the attention weights are kept; the attended output is
            # not used
            _, copy_attn = self.copy_attn(
                query=x,
                key=enc_states,
                value=enc_states,
                key_padding_mask=enc_pad_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=True,
            )
            # copy gate computed from the decoder state, moved to batch-first
            p_copy = torch.sigmoid(self.linear_copy(x)).transpose(0, 1)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        extra = {
            'attn': attn,
            'inner_states': inner_states,
            'p_copy': p_copy,
            'copy_attn': copy_attn
        }
        return x, extra

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is not None:
            # the adaptive softmax does its own projection later
            return features
        # project back to size of vocabulary
        if self.share_input_output_embed:
            weight = self.embed_tokens.weight
        else:
            weight = self.embed_out
        return F.linear(features, weight)

    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""

        if getattr(self, 'adaptive_softmax', None) is not None:
            target = None
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out if log_probs else out.exp_()

        logits = net_output[0]

        extra = net_output[1]
        is_copy = 'p_copy' in extra.keys() and extra['p_copy'] is not None
        # NOTE: the copy mixture below is switched off by the hard-coded
        # ``and False``; the plain softmax path is always taken.
        if is_copy and False:
            p_copy = extra['p_copy']
            if 'net_input' in sample.keys():
                enc_seq_ids = sample['net_input']['src_tokens']
            else:
                # for decode step
                enc_seq_ids = sample['src_tokens']
            enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(
                1, extra['copy_attn'].size(1), 1)
            generate_prob = utils.softmax(
                logits, dim=-1, onnx_trace=self.onnx_trace) * (1 - p_copy)
            copy_prob = extra['copy_attn'] * p_copy
            final = generate_prob.scatter_add(2, enc_seq_ids, copy_prob)
            if log_probs:
                return torch.log(final + 1e-15)
            return final

        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(logits, dim=-1, onnx_trace=self.onnx_trace)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        positional = self.embed_positions
        if positional is not None:
            return min(self.max_target_positions, positional.max_positions())
        return self.max_target_positions

    def buffered_future_mask(self, tensor):
        """Return a cached ``dim x dim`` mask with ``-inf`` strictly above
        the diagonal, where ``dim = tensor.size(0)``.

        The cache is rebuilt when it is missing, lives on another device, or
        is smaller than the requested size.
        """
        dim = tensor.size(0)
        cached = getattr(self, '_future_mask', None)
        stale = (cached is None or cached.device != tensor.device
                 or cached.size(0) < dim)
        if stale:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # drop stored sinusoidal weights, keep only a placeholder tensor
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            placeholder_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[placeholder_key] = torch.FloatTensor(1)

        # update layer norms: old checkpoints stored them under an indexed
        # ``layer_norms`` list, newer ones use one named attribute each
        layer_norm_map = {
            '0': 'self_attn_layer_norm',
            '1': 'encoder_attn_layer_norm',
            '2': 'final_layer_norm'
        }
        for idx in range(len(self.layers)):
            for old, new in layer_norm_map.items():
                for suffix in ('weight', 'bias'):
                    old_key = '{}.layers.{}.layer_norms.{}.{}'.format(
                        name, idx, old, suffix)
                    if old_key in state_dict:
                        new_key = '{}.layers.{}.{}.{}'.format(
                            name, idx, new, suffix)
                        state_dict[new_key] = state_dict.pop(old_key)

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
Example #14
0
class JointAttentionDecoder(FairseqIncrementalDecoder):
    """
    JointAttention decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`ProtectedTransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        left_pad (bool, optional): whether the input is left-padded. Default:
            ``False``
    """

    def __init__(
            self,
            args,
            dictionary,
            embed_tokens,
            left_pad=False,
            final_norm=True):
        """Build the embeddings, the layer stack and the output projection.

        Args:
            args (argparse.Namespace): parsed command-line arguments
            dictionary (~fairseq.data.Dictionary): decoding dictionary
            embed_tokens (torch.nn.Embedding): output embedding
            left_pad (bool, optional): accepted but not read in this
                constructor
            final_norm (bool, optional): together with
                ``args.decoder_normalize_before`` decides whether a final
                ``LayerNorm`` is created
        """
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.kernel_size_list = args.kernel_size_list

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        output_embed_dim = args.decoder_output_dim

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        # only project when the embedding width differs from the model width
        if embed_dim != input_embed_dim:
            self.project_in_dim = Linear(
                input_embed_dim, embed_dim, bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions, embed_dim, padding_idx,
                learned=args.decoder_learned_pos,
            )

        if args.language_embeddings:
            self.embed_language = LanguageEmbedding(embed_dim)
        else:
            self.embed_language = None

        decoder_layers = [
            ProtectedTransformerDecoderLayer(args, no_encoder_attn=True)
            for _ in range(args.decoder_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        if embed_dim != output_embed_dim and not args.tie_adaptive_weights:
            self.project_out_dim = Linear(
                embed_dim, output_embed_dim, bias=False)
        else:
            self.project_out_dim = None

        if not self.share_input_output_embed:
            # separate output embedding, initialised like the input one
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), output_embed_dim))
            nn.init.normal_(
                self.embed_out, mean=0, std=output_embed_dim ** -0.5)

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)

        # skip levels 0.1, 0.2, ..., 1.0 and the layers kept at each of them
        num_levels = 10
        level_step = 1. / float(num_levels)
        levels = [round(k * level_step, 2) for k in range(1, num_levels + 1)]
        self.skip_layers = self.jump_or_not(levels)

        # experiment switches read by layer_drop / scale_whole_layer / forward
        self.layer_drop_rate = 0.
        self.scaling = False
        self.Formula = False
        self.source_target_drop = False
        print("SKip Layers :", self.skip_layers)
        print("layer_drop_rate:%f" % self.layer_drop_rate)

    def jump_or_not_manual(self):
        """Return the hand-picked layer indices kept at each skip level.

        Returns:
            dict: maps the levels ``0.33``, ``0.66`` and ``1.0`` to lists of
            layer indices; level ``1.0`` lists the indices ``0..13``.
        """
        return {
            0.33: [0, 1, 4, 8, 12],
            0.66: [0, 1, 2, 4, 6, 8, 10, 12],
            1.0: list(range(14)),
        }

    def jump_or_not(self, levels):
        """Compute, for every level, the strided layer indices to keep.

        For a level ``l`` about ``len(self.layers) * l`` layers are selected
        by striding over the layer indices from 0; the index of the last
        layer is always part of the selection.

        Args:
            levels (iterable of float): fractions of the layer stack to keep

        Returns:
            dict: maps each level to an ascending list of layer indices
            (an empty list when the decoder has no layers)
        """
        num_layers = len(self.layers)
        schedule = {}
        for level in levels:
            if num_layers == 0:
                # nothing to select from; avoids indexing an empty list
                schedule[level] = []
                continue
            # On shallow stacks a small level truncates to zero target
            # layers (e.g. int(6 * 0.1) == 0), which used to raise
            # ZeroDivisionError below; always keep at least one layer.
            target_layers = max(1, int(num_layers * level))
            # A level above 1 could round the stride to 0, which range()
            # rejects; clamp it as well.
            stride = max(1, round(num_layers / target_layers))
            kept = list(range(0, num_layers, stride))
            # append last layer
            if kept[-1] != num_layers - 1:
                kept.append(num_layers - 1)
            schedule[level] = kept
        return schedule

    def base_jump_or_not(self):
        """Placeholder skip decision that always answers ``True``.

        The variants sketched here earlier (stopping at a target layer
        during training, an inference-only exit for the three-loss setup
        and a LayerDrop-style random draw) are all disabled.
        """
        decision = True
        return decision

    def layer_drop(self, i):
        """Randomly decide whether the layer at zero-based index *i* is dropped.

        One uniform sample is drawn and compared against a threshold: the
        plain ``self.layer_drop_rate``, or, when ``self.Formula`` is set,
        ``((i + 1) / len(self.layers)) * (1 - layer_drop_rate)``.

        Returns:
            bool: ``True`` when the sample is less than or equal to the
            threshold
        """
        rate = self.layer_drop_rate
        depth = i + 1  # one-based position of the layer in the stack
        draw = random.random()
        if self.Formula:
            threshold = float((depth / len(self.layers)) * (1. - rate))
        else:
            threshold = rate
        return draw <= threshold

    def scale_whole_layer(self, i):
        """Return the factor used to scale the output of layer *i*.

        The factor is ``1 - ((i + 1) / len(self.layers)) * (1 - rate)``
        with ``rate = self.layer_drop_rate``.
        """
        depth = i + 1  # one-based position of the layer in the stack
        rate = self.layer_drop_rate
        reduction = float((depth / len(self.layers)) * (1. - rate))
        return 1 - reduction

    def forward(
            self,
            prev_output_tokens,
            encoder_out,
            incremental_state=None,
            level=1.):
        """
        Run source states and target tokens through the same layer stack.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict): must contain ``'encoder_out'`` (the source
                states that are passed through the layers) and
                ``'encoder_padding_mask'`` (may be ``None``); an optional
                ``'level'`` entry replaces the *level* argument
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            level (float, optional): skip level. It is only overwritten from
                ``encoder_out['level']`` here; the code that would consume it
                is commented out below.

        Returns:
            tuple:
                - the output after projection with the output embedding,
                  presumably of shape `(batch, tgt_len, vocab)`
                - a dict with ``'attn'`` (second value returned by the last
                  executed target-side layer call, ``None`` if no layer ran)
                  and ``'inner_states'`` (the embedded input followed by the
                  source/target outputs appended after each layer call)
        """
        # measured before the incremental slice below, so this is the full
        # width of prev_output_tokens
        tgt_len = prev_output_tokens.size(1)

        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        # with an incremental state only the newest token is fed forward
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        # language embedding
        if self.embed_language is not None:
            # reshaped to (1, 1, C) so it broadcasts over batch and time
            lang_emb = self.embed_scale * self.embed_language.view(1, 1, -1)
            x += lang_emb

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]
        source = encoder_out['encoder_out']
        # the source is only pushed through the layers when there is no
        # incremental state or that state is still empty
        process_source = incremental_state is None or len(
            incremental_state) == 0

        # extended padding mask
        source_padding_mask = encoder_out['encoder_padding_mask']
        if source_padding_mask is not None:
            # all-zero block for the target positions, appended after the
            # source padding mask along dim 1
            target_padding_mask = source_padding_mask.new_zeros(
                (source_padding_mask.size(0), tgt_len))
            self_attn_padding_mask = torch.cat(
                (source_padding_mask, target_padding_mask), dim=1)
        else:
            self_attn_padding_mask = None

        # inference time
        if 'level' in encoder_out:
            level = encoder_out['level']
            # fix all the batches's level to be easy
            # level = 0.33

        # transformer layers
        for i, layer in enumerate(self.layers):
            # training with dropout - normal way
            # (a dropped layer is skipped for both the source and the target)
            if self.training and self.layer_drop(i):
                continue

            # skipping inference
            # if not self.training and i not in self.skip_layers[level]:
            #     continue
            #
            if self.kernel_size_list is not None:
                target_mask = self.local_mask(
                    x, self.kernel_size_list[i], causal=True, tgt_len=tgt_len)
            elif incremental_state is None:
                target_mask = self.buffered_future_mask(x)
            else:
                target_mask = None

            if target_mask is not None:
                # prepend zero columns, one per entry of source dim 0
                # (presumably the source length), in front of the target mask
                zero_mask = target_mask.new_zeros(
                    (target_mask.size(0), source.size(0)))
                self_attn_mask = torch.cat((zero_mask, target_mask), dim=1)
            else:
                self_attn_mask = None

            # if self.source_target_drop and not self.layer_drop(i) or i == 0 or not self.training:
            state = incremental_state
            if process_source:
                # without an incremental state a fresh dict is created for
                # this layer and shared by the source and target calls below
                # NOTE(review): presumably the layer caches the source
                # keys/values in it for the target call - confirm in the layer
                if state is None:
                    state = {}
                if self.kernel_size_list is not None:
                    source_mask = self.local_mask(
                        source, self.kernel_size_list[i], causal=False)
                else:
                    source_mask = None
                source, attn = layer(
                    source,
                    None,
                    None,
                    state,
                    self_attn_mask=source_mask,
                    self_attn_padding_mask=source_padding_mask
                )
                inner_states.append(source)

            # if self.source_target_drop and not self.layer_drop(i) or not self.training:
            x, attn = layer(
                x,
                None,
                None,
                state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask
            )
            # x scaling
            if self.scaling:
                # training
                x = x * self.scale_whole_layer(i)
            inner_states.append(x)

        # self.layer_norm only exists when self.normalize is true
        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

        pred = x
        info = {'attn': attn, 'inner_states': inner_states}
        return pred, info

        # # 3 loss ways in layer 4,9,14 -> 0 + 8 + 10 + 10 == 8,18,28
        # if self.training:
        #     step = [8, 18, 28]
        #     # step = [28]
        # else:
        #     step = [len(inner_states) - 1]
        # loss_output = []
        # for i in step:
        #     x = inner_states[i]
        #     #
        #     if self.normalize:
        #         x = self.layer_norm(x)
        #
        #     # T x B x C -> B x T x C
        #     x = x.transpose(0, 1)
        #
        #     if self.project_out_dim is not None:
        #         x = self.project_out_dim(x)
        #
        #     # project back to size of vocabulary
        #     if self.share_input_output_embed:
        #         x = F.linear(x, self.embed_tokens.weight)
        #     else:
        #         x = F.linear(x, self.embed_out)
        #
        #     pred = x
        #     info = {'attn': attn, 'inner_states': inner_states}
        #     loss_output.append((pred, info))
        #
        # # return pred, info
        # if not self.training and len(loss_output) == 1:
        #     return loss_output[0]
        # return loss_output

    def max_positions(self):
        """Maximum output length supported by the decoder.

        This is ``self.max_target_positions``, further capped by the
        positional embedding's own limit when such an embedding exists.
        """
        positional = self.embed_positions
        if positional is not None:
            return min(self.max_target_positions, positional.max_positions())
        return self.max_target_positions

    def buffered_future_mask(self, tensor):
        """Cached future mask.

        Returns the top-left ``(dim, dim)`` corner of a cached mask, with
        ``dim = tensor.size(0)``. The cache is created when missing or when
        it sits on another device than *tensor*, and it is resized in place
        and refilled when it is too small.
        """
        size = tensor.size(0)
        mask = getattr(self, '_future_mask', None)
        if mask is None or mask.device != tensor.device:
            mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(size, size)), 1)
            self._future_mask = mask
        if mask.size(0) < size:
            # grow the existing storage, then rebuild the triangle on it
            grown = mask.resize_(size, size)
            mask = torch.triu(utils.fill_with_neg_inf(grown), 1)
            self._future_mask = mask
        return mask[:size, :size]

    def local_mask(self, tensor, kernel_size, causal, tgt_len=None):
        """Locality constraint mask.

        Builds a ``(rows, cols)`` additive mask with ``rows = tensor.size(0)``
        and ``cols = tgt_len`` (or ``rows`` when *tgt_len* is ``None``).
        Entries are the fill value of ``utils.fill_with_neg_inf`` on and
        above diagonal ``diag_u`` and on and below diagonal ``-diag_l``,
        and zero in between.

        Args:
            tensor: tensor whose first dimension gives the number of rows
            kernel_size (int): width of the local window
            causal (bool): if true the upper bound is diagonal 1 and the lower
                bound is ``kernel_size``; otherwise the window is split
                around the main diagonal
            tgt_len (int, optional): number of mask columns
        """
        rows = tensor.size(0)
        cols = rows if tgt_len is None else tgt_len

        if causal and rows == 1:
            # single query row: only the last ``kernel_size`` columns are open
            single_row = utils.fill_with_neg_inf(tensor.new(1, cols))
            single_row[0, -kernel_size:] = 0
            return single_row

        if causal:
            upper, lower = 1, kernel_size
        elif kernel_size % 2 == 1:
            upper = lower = (kernel_size + 1) // 2
        else:
            upper, lower = kernel_size // 2, kernel_size // 2 + 1

        above = torch.triu(
            utils.fill_with_neg_inf(tensor.new(rows, cols)), upper)
        below = torch.tril(
            utils.fill_with_neg_inf(tensor.new(rows, cols)), -lower)
        return above + below
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        final_norm (bool, optional): apply layer norm to the output of the
            final decoder layer (default: True).
    """
    def __init__(self,
                 args,
                 dictionary,
                 embed_tokens,
                 no_encoder_attn=False,
                 final_norm=True):
        """Build the embeddings, the decoder layers and the output head.

        Args:
            args (argparse.Namespace): parsed command-line arguments
            dictionary (~fairseq.data.Dictionary): decoding dictionary
            embed_tokens (torch.nn.Embedding): output embedding
            no_encoder_attn (bool, optional): forwarded to every
                ``TransformerDecoderLayer`` (default: False)
            final_norm (bool, optional): together with
                ``args.decoder_normalize_before`` decides whether a final
                ``LayerNorm`` is created (default: True)
        """
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        # todo: try with input_embed_dim
        self.embed_scale = math.sqrt(embed_dim)

        # only project when the embedding width differs from the model width
        if embed_dim != input_embed_dim:
            self.project_in_dim = Linear(
                input_embed_dim, embed_dim, bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )

        decoder_layers = [
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.adaptive_softmax = None

        if (embed_dim != self.output_embed_dim
                and not args.tie_adaptive_weights):
            self.project_out_dim = Linear(
                embed_dim, self.output_embed_dim, bias=False)
        else:
            self.project_out_dim = None

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=(
                    embed_tokens if args.tie_adaptive_weights else None),
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # separate output embedding, initialised like the input one
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim))
            nn.init.normal_(
                self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)
        self.onnx_trace = False

        # settings read by gumbel_softmax
        self.decoder_max_order = args.decoder_max_order
        self.clamp_value = getattr(args, 'clamp_value', 0.01)
        self.gs_clamp = args.gs_clamp

    def set_perm_order(self, perm_order=0):
        """Forward *perm_order* to ``set_perm_order`` of every decoder layer.

        Args:
            perm_order (int, optional): value in ``0..5`` (checked with an
                assertion before any layer is touched)
        """
        in_range = isinstance(perm_order, int) and 0 <= perm_order <= 5
        assert in_range
        for decoder_layer in self.layers:
            decoder_layer.set_perm_order(perm_order)

    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None,
                **unused):
        """
        Extract features and pass them through the output layer.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - whatever :meth:`output_layer` returns for the features
                - a dictionary with any model-specific outputs
        """
        features, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state)
        projected = self.output_layer(features, encoder_out)
        return projected, extra

    def extract_features(self,
                         prev_output_tokens,
                         encoder_out=None,
                         incremental_state=None,
                         **unused):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with ``'attn'`` (from the last layer) and
                  ``'inner_states'`` (the embedded input plus every layer
                  output)
        """
        incremental = incremental_state is not None

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )

        # with an incremental state only the newest token is fed forward
        if incremental:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)
        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]

        # decoder layers
        for layer in self.layers:
            if encoder_out is None:
                enc_states = None
                enc_padding_mask = None
            else:
                enc_states = encoder_out['encoder_out']
                enc_padding_mask = encoder_out['encoder_padding_mask']
            # the future mask is only needed when decoding all steps at once
            future_mask = None if incremental else self.buffered_future_mask(x)
            x, attn = layer(
                x,
                enc_states,
                enc_padding_mask,
                incremental_state,
                self_attn_mask=future_mask,
            )
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {'attn': attn, 'inner_states': inner_states}

    def output_layer(self, features, encoder_out, **kwargs):
        """Project features to the vocabulary size.

        Without adaptive softmax the result is always the two-element list
        ``[logits, encoder_out['encoder_pred_order']]``, which is the layout
        :meth:`get_normalized_probs` unpacks via ``net_output[0][0]`` and
        ``net_output[0][1]``. Previously the branch with a separate output
        embedding returned a bare tensor, so that unpacking would have sliced
        the batch dimension instead.

        Args:
            features (Tensor): decoder features to project
            encoder_out (dict): must contain ``'encoder_pred_order'`` unless
                adaptive softmax is used

        Returns:
            list or Tensor: ``[logits, pred_order]``, or *features* unchanged
            when adaptive softmax is configured
        """
        if self.adaptive_softmax is not None:
            # the adaptive softmax consumes raw features later on
            return features

        # project back to size of vocabulary
        if self.share_input_output_embed:
            output_weight = self.embed_tokens.weight
        else:
            output_weight = self.embed_out
        return [
            F.linear(features, output_weight),
            encoder_out['encoder_pred_order'],
        ]

    def max_positions(self):
        """Maximum output length supported by the decoder.

        This is ``self.max_target_positions``, further capped by the
        positional embedding's own limit when such an embedding exists.
        """
        positional = self.embed_positions
        if positional is not None:
            return min(self.max_target_positions, positional.max_positions())
        return self.max_target_positions

    def buffered_future_mask(self, tensor):
        """Return a cached square future mask sized by ``tensor.size(0)``.

        The cache is created when missing or when it sits on another device
        than *tensor*, and it is resized in place and refilled when it is
        too small. The caller receives the top-left corner of the cache.
        """
        size = tensor.size(0)
        mask = getattr(self, '_future_mask', None)
        if mask is None or mask.device != tensor.device:
            mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(size, size)), 1)
            self._future_mask = mask
        if mask.size(0) < size:
            # grow the existing storage, then rebuild the triangle on it
            grown = mask.resize_(size, size)
            mask = torch.triu(utils.fill_with_neg_inf(grown), 1)
            self._future_mask = mask
        return mask[:size, :size]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Args:
            state_dict (dict): checkpoint entries, modified in place
            name (str): prefix of this module's keys inside *state_dict*

        Returns:
            dict: the same *state_dict* object
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # drop any stored sinusoidal table and (re)create the placeholder
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            float_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[float_key] = torch.FloatTensor(1)

        # old checkpoints kept the norms in a per-layer ``layer_norms`` list
        renamed_norms = (
            ('0', 'self_attn_layer_norm'),
            ('1', 'encoder_attn_layer_norm'),
            ('2', 'final_layer_norm'),
        )
        for layer_idx in range(len(self.layers)):
            for old_idx, new_name in renamed_norms:
                for param in ('weight', 'bias'):
                    old_key = '{}.layers.{}.layer_norms.{}.{}'.format(
                        name, layer_idx, old_idx, param)
                    if old_key not in state_dict:
                        continue
                    new_key = '{}.layers.{}.{}.{}'.format(
                        name, layer_idx, new_name, param)
                    state_dict[new_key] = state_dict.pop(old_key)

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict

    def get_normalized_probs(self,
                             net_output,
                             log_probs,
                             sample,
                             gs_tau=0.5,
                             gs_hard=False):
        """Get normalized probabilities (or log probs) from a net's output.

        With adaptive softmax, ``net_output[0]`` is handed to it and a single
        tensor is returned. Otherwise ``net_output[0]`` is unpacked into
        ``(logits, orders)`` and the result is a tuple: the (log-)softmax of
        the logits followed by the three values of :meth:`gumbel_softmax`
        computed on the orders.
        """
        adaptive = getattr(self, 'adaptive_softmax', None)
        if adaptive is not None:
            target = None
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            out = adaptive.get_log_prob(net_output[0], target=target)
            return out if log_probs else out.exp_()

        logits = net_output[0][0]
        orders = net_output[0][1]
        normalize = utils.log_softmax if log_probs else utils.softmax
        token_probs = normalize(logits, dim=-1, onnx_trace=self.onnx_trace)
        order_outputs = self.gumbel_softmax(
            orders, gs_tau=gs_tau, gs_hard=gs_hard, dim=-1)
        return (token_probs, *order_outputs)

    def gumbel_softmax(self, logits, gs_tau=0.5, gs_hard=False, dim=-1):
        """Sample (soft) or pick (hard) an order distribution from *logits*.

        Args:
            logits (Tensor): unnormalized scores; the last dimension is
                normalized
            gs_tau (float, optional): temperature passed to
                ``F.gumbel_softmax`` in the soft case
            gs_hard (bool, optional): if true return a one-hot argmax of the
                logits instead of a Gumbel-softmax sample
            dim (int, optional): accepted but not used; every operation
                works on the last dimension

        Returns:
            tuple: ``(gs, prob, prob_clamp)`` - the sample or one-hot tensor,
            the softmax of the logits, and that softmax clamped to
            ``[clamp_value, 1 - (decoder_max_order - 1) * clamp_value]``
        """
        prob = utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        upper_bound = 1. - (self.decoder_max_order - 1) * self.clamp_value
        prob_clamp = torch.clamp(prob, self.clamp_value, upper_bound)

        if gs_hard:
            max_idx = torch.argmax(logits, -1, keepdim=True)
            gs = logits.new_zeros(logits.size()).scatter(-1, max_idx, 1)
        else:
            # the sample is drawn from the (optionally clamped) log-probs
            logprob = torch.log(prob_clamp if self.gs_clamp else prob)
            gs = F.gumbel_softmax(
                logprob,
                tau=gs_tau,
                hard=False,
            )
        return gs, prob, prob_clamp
class HierarchicalTransformerEncoder(FairseqEncoder):
    """
    Hierarchical transformer encoder consisting of *args.encoder_layers*
    layers. Each layer is a :class:`TransformerLayer` that is fed the token,
    block (sentence) and document states together.

    Unless positional embeddings are disabled, three of them are concatenated
    and added to the token embeddings: the token position inside its block
    (``embed_dim // 2`` channels), the block position inside its document and
    a per-document embedding (``embed_dim // 4`` channels each).

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding

    Raises:
        ValueError: if positional embeddings are enabled and the embedding
            dimension is not divisible by 4.
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        if args.no_token_positional_embeddings:
            self.embed_positions = None
            self.embed_positions2 = None
        else:
            # The three positional parts (1/2 + 1/4 + 1/4) must add up to
            # exactly embed_dim, otherwise they cannot be added to the token
            # embeddings in forward().
            if embed_dim % 4 != 0:
                raise ValueError(
                    'embedding dimension must be divisible by 4 to build the '
                    'hierarchical positional embeddings, got {}'.format(
                        embed_dim))
            # Integer division: a float size is not a valid embedding width.
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                embed_dim // 2,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            # Shared by the block-level and the document-level positions.
            self.embed_positions2 = PositionalEmbedding(
                args.max_source_positions,
                embed_dim // 4,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [TransformerLayer(args) for _ in range(args.encoder_layers)])

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)
            self.sentence_norm = LayerNorm(embed_dim)
            self.doc_norm = LayerNorm(embed_dim)

    @staticmethod
    def _collate_embedding(values, pad_idx, size):
        """Stack 2d tensors ``(len_i, dim)`` into one ``(len(values), size, dim)``
        tensor; rows beyond ``size`` are dropped, missing rows are filled with
        ``pad_idx``."""
        res = values[0].new_full((len(values), size, values[0].size(1)),
                                 pad_idx)
        for i, v in enumerate(values):
            n_rows = min(v.size(0), size)
            res[i, :n_rows].copy_(v[:n_rows])
        return res

    def _combined_positional_embedding(self, src_tokens, doc_block_mask,
                                       n_docs):
        """Build the concatenated token/block/document positional embedding.

        Returns a tensor of shape ``(batch, n_blocks, n_tokens, C)`` where
        ``C`` is the sum of the three embedding widths.
        """
        batch_size, n_blocks, n_tokens = src_tokens.size()

        # token position inside its block
        local_pos_emb = self.embed_positions(
            src_tokens.view(batch_size * n_blocks, n_tokens))
        local_pos_emb = local_pos_emb.view(batch_size, n_blocks, n_tokens, -1)

        doc_sentence_lengths = torch.sum(doc_block_mask, 2)  # (batch, n_docs)
        # One host transfer instead of one per (sample, document) pair.
        lengths = doc_sentence_lengths.tolist()

        # NOTE(review): the positional embedding is fed per-block token sums
        # (and, below, per-document block counts) rather than real tokens;
        # presumably only the shape / non-padding pattern matters to
        # PositionalEmbedding -- confirm against its implementation.
        block_pos_emb = self.embed_positions2(torch.sum(
            src_tokens, 2))  # (batch, n_blocks, embed_dim)
        # Block numbering restarts at the beginning of every document: for
        # each document take the first `length` block positions.
        block_pos_emb = self._collate_embedding([
            torch.cat([
                block_pos_emb[i, :lengths[i][j]]
                for j in range(n_docs) if lengths[i][j] != 0
            ], 0) for i in range(block_pos_emb.size(0))
        ], 0, n_blocks)  # (batch, n_blocks, embed_dim)
        block_pos_emb = block_pos_emb.unsqueeze(2).expand(-1, -1, n_tokens, -1)

        doc_pos_emb = self.embed_positions2(
            doc_sentence_lengths)  # (batch, n_docs, embed_dim)
        # Every block inherits the embedding of the document it belongs to.
        doc_pos_emb = self._collate_embedding([
            torch.cat([
                doc_pos_emb[i, j].unsqueeze(0).repeat(lengths[i][j], 1)
                for j in range(doc_pos_emb.size(1)) if lengths[i][j] != 0
            ], 0) for i in range(doc_pos_emb.size(0))
        ], 0, n_blocks)  # (batch, n_blocks, embed_dim)
        doc_pos_emb = doc_pos_emb.unsqueeze(2).expand(-1, -1, n_tokens, -1)

        return torch.cat([local_pos_emb, block_pos_emb, doc_pos_emb], -1)

    def forward(self, src_tokens, src_lengths, block_mask, doc_lengths,
                doc_block_mask):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, n_blocks, n_tokens)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not used by this method
            block_mask (torch.LongTensor): block mask of the source sentences of shape
                `(batch, n_blocks, n_blocks)`; passed through to the layers
            doc_lengths (torch.LongTensor): number of documents of each
                sample, of shape `(batch)`
            doc_block_mask (torch.LongTensor): doc mask of the source sentences of shape
                `(batch, n_docs, n_blocks)`

        Returns:
            dict:
                - **encoder_out** (Tensor): token states with padding tokens
                  removed and re-padded with zeros at the end, of shape
                  `(max_len, batch, embed_dim)` where `max_len` is the
                  largest number of non-padding tokens of a sample
                - **encoder_padding_mask** (Tensor): the positions of
                  padding elements of shape `(batch, max_len)`
                - **sentence_out** (Tensor): block states returned by the
                  last layer (initialised as `(n_blocks, batch, embed_dim)`)
                - **sentence_padding_mask** (Tensor): blocks that contain
                  only padding, of shape `(batch, n_blocks)`
                - **doc_out** (Tensor): document states returned by the last
                  layer (initialised as `(n_docs, batch, embed_dim)`)
                - **doc_padding_mask** (ByteTensor): document slots beyond
                  `doc_lengths`, of shape `(batch, n_docs)`
        """
        # Follow the input's device instead of assuming CUDA.
        device = src_tokens.device
        batch_size, n_blocks, n_tokens = src_tokens.size()
        embed_dim = self.embed_tokens.embedding_dim

        # `ge` builds the padding mask directly; `1 - mask` is rejected for
        # bool tensors by newer PyTorch releases.
        doc_positions = torch.arange(int(doc_lengths.max()),
                                     dtype=doc_lengths.dtype,
                                     device=doc_lengths.device)
        doc_padding_mask = doc_positions.unsqueeze(0).ge(
            doc_lengths.unsqueeze(1)).byte().to(device)
        n_docs = doc_padding_mask.size(1)

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self._combined_positional_embedding(src_tokens,
                                                     doc_block_mask, n_docs)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # compute padding masks
        non_pad = src_tokens.ne(self.padding_idx)
        local_padding_mask = src_tokens.eq(self.padding_idx).view(
            batch_size * n_blocks, n_tokens)
        # a block is padding when it holds no real token at all
        block_padding_mask = non_pad.long().sum(-1) == 0

        # (batch, n_blocks, n_tokens, C) -> (n_tokens, batch * n_blocks, C)
        x = x.view(batch_size * n_blocks, n_tokens, -1).transpose(0, 1)

        block_vec = torch.zeros(n_blocks, batch_size, embed_dim,
                                device=device)
        doc_vec = torch.zeros(n_docs, batch_size, embed_dim, device=device)

        # encoder local layers
        for layer in self.layers:
            x, block_vec, doc_vec = layer(x, block_vec, doc_vec,
                                          local_padding_mask,
                                          block_padding_mask, doc_padding_mask,
                                          block_mask, doc_block_mask,
                                          batch_size, n_blocks)

        if self.normalize:
            x = self.layer_norm(x)
            block_vec = self.sentence_norm(block_vec)
            doc_vec = self.doc_norm(doc_vec)

        # (n_tokens, batch * n_blocks, C) -> (batch, n_blocks * n_tokens, C)
        features = x.transpose(0, 1).contiguous().view(
            batch_size, n_blocks * n_tokens, -1)

        # Drop the padding tokens scattered between blocks so that every
        # sample becomes one dense sequence of its real tokens.
        keep = non_pad.view(batch_size, n_blocks * n_tokens)
        unpadded = [features[i][keep[i]] for i in range(batch_size)]

        seq_lengths = [p.size(0) for p in unpadded]
        max_l = max(seq_lengths)

        # True at the zero-padded tail of each re-packed sequence.
        positions = torch.arange(max_l, device=device)
        encoder_padding_mask = positions.unsqueeze(0).ge(
            torch.tensor(seq_lengths, device=device).unsqueeze(1))

        # Right-pad with zeros of matching dtype/device and stack along the
        # batch dimension: (max_l, batch, C).
        x = torch.stack([
            torch.cat([p, p.new_zeros(max_l - p.size(0), p.size(-1))])
            for p in unpadded
        ], 1)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
            'sentence_out': block_vec,  # T x B x C
            'sentence_padding_mask': block_padding_mask,  # B x T
            'doc_out': doc_vec,  # T x B x C
            'doc_padding_mask': doc_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order* (the dict is
            updated in place and returned)
        """
        # (key, batch dimension) for every entry produced by forward()
        batch_dims = (
            ('encoder_out', 1),
            ('encoder_padding_mask', 0),
            ('sentence_out', 1),
            ('sentence_padding_mask', 0),
            ('doc_out', 1),
            ('doc_padding_mask', 0),
        )
        for key, dim in batch_dims:
            if encoder_out[key] is not None:
                encoder_out[key] = encoder_out[key].index_select(
                    dim, new_order)
        return encoder_out

    def reorder_encoder_input(self, encoder_input, new_order):
        """
        Reorder the encoder *input* according to *new_order*.

        Args:
            encoder_input (dict): keyword inputs of ``forward()``; must hold
                the keys ``src_tokens``, ``src_lengths``, ``block_mask``,
                ``doc_block_mask`` and ``doc_lengths``
            new_order (LongTensor): desired order

        Returns:
            *encoder_input* with every non-``None`` entry rearranged along
            its first dimension (the dict is updated in place and returned)
        """
        for key in ('src_tokens', 'src_lengths', 'block_mask',
                    'doc_block_mask', 'doc_lengths'):
            if encoder_input[key] is not None:
                encoder_input[key] = encoder_input[key].index_select(
                    0, new_order)
        return encoder_input

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions,
                   self.embed_positions.max_positions())

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(
                name)] = torch.FloatTensor(1)
        for i, layer in enumerate(self.layers):
            # update layer norms
            layer.upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers;
            # forward() skips all three final norms once `normalize` is off,
            # so drop all of them rather than only the token-level one.
            self.layer_norm = None
            self.sentence_norm = None
            self.doc_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #17
0
class HybridEncoder(FairseqEncoder):
    """
    Encoder whose layers run a dynamic-convolution branch and a
    self-attention branch in parallel on the same input, blend the two
    outputs with a learned per-layer ratio and pass the blend through a
    feed-forward layer.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout = args.embed_dropout

        self.embed_tokens = embed_tokens
        self.embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(self.embed_dim)
        self.padding_idx = embed_tokens.padding_idx

        self.max_source_positions = args.max_source_positions
        self.embed_positions = PositionalEmbedding(
            self.max_source_positions,
            self.embed_dim,
            self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.encoder_layers = args.encoder_layers
        self.convlayers = nn.ModuleList([
            DynamicConvEncoderLayer(args.encoder_embed_dim,
                                    args.encoder_embed_dim,
                                    args.encoder_attention_heads,
                                    args.encoder_kernel_size_list[i],
                                    input_dropout=args.conv_input_dropout,
                                    weight_dropout=args.conv_weight_dropout,
                                    dropout=args.conv_output_dropout)
            for i in range(self.encoder_layers)
        ])
        self.attnlayers = nn.ModuleList([
            AttentionEncoderLayer(args.encoder_embed_dim,
                                  args.encoder_attention_heads,
                                  self_attention=True,
                                  attention_dropout=args.attn_weight_dropout,
                                  dropout=args.attn_output_dropout)
            for _ in range(self.encoder_layers)
        ])
        self.fflayers = nn.ModuleList([
            FFLayer(args.encoder_embed_dim,
                    args.encoder_ffn_embed_dim,
                    relu_dropout=args.ff_relu_dropout,
                    dropout=args.ff_output_dropout)
            for _ in range(self.encoder_layers)
        ])
        # One learnable mixing weight per layer, starting at an even blend
        # of the convolution and the attention branch.
        self.ratios = nn.Parameter(
            torch.full((self.encoder_layers, 1), 0.5, dtype=torch.float),
            requires_grad=True)

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(self.embed_dim)

    def forward(self, src_tokens, **unused):
        """
        Args:
            src_tokens (LongTensor): source tokens, presumably of shape
                `(batch, src_len)`

        Returns:
            dict:
                - **encoder_x** (Tensor): output of the last layer, with the
                  first two dimensions of the embedded input swapped
                - **encoder_padding_mask** (Tensor): positions of padding
                  elements, same shape as `src_tokens`; always a tensor,
                  even when nothing is padded
                - **encoder_lstm_states**: the ``"lstm_hidden_state"`` entry
                  of the last layer state that has one, else ``None``
        """
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = x.transpose(0, 1)

        # The mask is deliberately kept even when no position is padded.
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = []
        for i in range(self.encoder_layers):
            # both branches read the same layer input
            x1, state1 = self.convlayers[i](
                x, encoder_padding_mask=encoder_padding_mask)
            x2, state2 = self.attnlayers[i](
                x, encoder_padding_mask=encoder_padding_mask)
            if state1 is not None:
                encoder_states.append(state1)
            if state2 is not None:
                encoder_states.append(state2)
            ratio = self.ratios[i]
            x = x1 * ratio + x2 * (1 - ratio)
            x, _ = self.fflayers[i](x,
                                    encoder_padding_mask=encoder_padding_mask)

        if self.normalize:
            x = self.layer_norm(x)

        return {
            'encoder_x': x,
            'encoder_padding_mask': encoder_padding_mask,
            'encoder_lstm_states': self.get_lstm_states(encoder_states)
        }

    def construct_encoder_layer(self, gene):
        """Instantiate the encoder layer described by *gene*.

        Args:
            gene (dict): ``{'type': <layer type>, 'param': <constructor
                keyword arguments>}``

        Returns:
            the newly built layer

        Raises:
            NotImplementedError: if ``gene['type']`` is not one of
                ``'recurrent'``, ``'lightconv'``, ``'dynamicconv'``,
                ``'self-attention'`` or ``'ff'``.
        """
        gene_type = gene['type']
        if gene_type == 'recurrent':
            return LSTMEncoderLayer(**gene['param'])
        if gene_type == 'lightconv':
            return LightConvEncoderLayer(**gene['param'])
        if gene_type == 'dynamicconv':
            return DynamicConvEncoderLayer(**gene['param'])
        if gene_type == 'self-attention':
            return AttentionEncoderLayer(**gene['param'], self_attention=True)
        if gene_type == 'ff':
            return FFLayer(**gene['param'])
        # This is the encoder: the message used to say "Decoder" and did not
        # name the offending type.
        raise NotImplementedError(
            'Unknown Encoder Gene Type: {!r}'.format(gene_type))

    def get_lstm_states(self, encoder_states):
        """Return the ``"lstm_hidden_state"`` of the last state in
        *encoder_states* that has one, or ``None`` if there is none."""
        # only return the state of the topmost lstm layer
        final_state = None
        for state in encoder_states:
            if state is not None and "lstm_hidden_state" in state:
                final_state = state["lstm_hidden_state"]
        return final_state

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order* (the dict is
            updated in place and returned)
        """
        if encoder_out['encoder_x'] is not None:
            encoder_out['encoder_x'] = \
                encoder_out['encoder_x'].index_select(1, new_order)
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(0, new_order)
        if encoder_out['encoder_lstm_states'] is not None:
            hiddens, cells = encoder_out['encoder_lstm_states']
            encoder_out['encoder_lstm_states'] = (
                hiddens.index_select(1, new_order),
                cells.index_select(1, new_order),
            )
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions,
                   self.embed_positions.max_positions())
Example #18
0
class TaLKConvDecoder(FairseqIncrementalDecoder):
    """
    Decoder built from *args.decoder_layers* :class:`TaLKConvDecoderLayer`
    layers, one kernel size per layer taken from
    *args.decoder_kernel_size_list*.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
            Default: ``False``
        final_norm (bool, optional): together with
            *args.decoder_normalize_before*, enables a layer norm after the
            last layer. Default: ``True``
    """
    def __init__(self,
                 args,
                 dictionary,
                 embed_tokens,
                 no_encoder_attn=False,
                 final_norm=True):
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        output_embed_dim = args.decoder_output_dim

        pad_index = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        # todo: try with input_embed_dim
        self.embed_scale = math.sqrt(embed_dim)

        # Only project the embeddings when their width differs from the
        # model width.
        if embed_dim != input_embed_dim:
            self.project_in_dim = Linear(input_embed_dim,
                                         embed_dim,
                                         bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                pad_index,
                learned=args.decoder_learned_pos,
            )

        self.layers = nn.ModuleList([
            TaLKConvDecoderLayer(
                args,
                no_encoder_attn,
                kernel_size=args.decoder_kernel_size_list[layer_idx])
            for layer_idx in range(args.decoder_layers)
        ])

        self.adaptive_softmax = None

        if embed_dim != output_embed_dim and not args.tie_adaptive_weights:
            self.project_out_dim = Linear(embed_dim,
                                          output_embed_dim,
                                          bias=False)
        else:
            self.project_out_dim = None

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=(embed_tokens
                                 if args.tie_adaptive_weights else None),
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # separate, randomly initialised output projection
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), output_embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim**-0.5)

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)

        self.acts_reg = []

    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None,
                **kwargs):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): encoder output; its
                ``'encoder_out'`` and ``'encoder_padding_mask'`` entries are
                handed to every layer
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder output, batch first; projected to the
                  vocabulary unless an adaptive softmax is configured
                - a dict with ``'attn'`` (whatever the last layer returned
                  as attention) and ``'inner_states'`` (the embedded input
                  followed by each layer's output, time first)
        """
        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )

        # Incremental decoding only feeds the newest step through the stack.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # The encoder tensors are the same for every layer; look them up once.
        if encoder_out is None:
            enc_states = None
            enc_padding_mask = None
        else:
            enc_states = encoder_out['encoder_out']
            enc_padding_mask = encoder_out['encoder_padding_mask']

        attn = None
        inner_states = [x]

        # decoder layers
        for layer in self.layers:
            x, attn = layer(x, enc_states, enc_padding_mask,
                            incremental_state)
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            out_weight = (self.embed_tokens.weight
                          if self.share_input_output_embed else self.embed_out)
            x = F.linear(x, out_weight)

        return x, {'attn': attn, 'inner_states': inner_states}

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is not None:
            return min(self.max_target_positions,
                       self.embed_positions.max_positions())
        return self.max_target_positions

    def buffered_future_mask(self, tensor):
        """Return a cached ``(dim, dim)`` mask, ``dim = tensor.size(0)``,
        holding ``-inf`` strictly above the diagonal and zeros elsewhere.

        The cache is rebuilt when it is missing or lives on another device
        and is enlarged when it is too small.
        """
        dim = tensor.size(0)
        mask = getattr(self, '_future_mask', None)
        if mask is None or mask.device != tensor.device:
            mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if mask.size(0) < dim:
            mask = torch.triu(
                utils.fill_with_neg_inf(mask.resize_(dim, dim)), 1)
        self._future_mask = mask
        return mask[:dim, :dim]
Example #19
0
class LightConvEncoder(FairseqEncoder):
    """
    LightConv encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`LightConvEncoderLayer`.

    Optionally the outputs of all layers are merged into the final state,
    either with learned per-layer weights
    (*args.encoder_linear_combination*) or with weights predicted from the
    concatenated layer outputs (*args.encoder_dynamic_combination*). The two
    options are mutually exclusive.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            embed_dim,
            self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            LightConvEncoderLayer(args,
                                  kernel_size=args.encoder_kernel_size_list[i])
            for i in range(args.encoder_layers)
        ])
        self.encoder_dynamic_combination = args.encoder_dynamic_combination
        self.encoder_linear_combination = args.encoder_linear_combination
        assert not (self.encoder_dynamic_combination
                    and self.encoder_linear_combination), \
            'encoder_dynamic_combination and encoder_linear_combination ' \
            'are mutually exclusive'
        if self.encoder_linear_combination or self.encoder_dynamic_combination:
            # residual feed-forward applied to the combined layer outputs
            self.weight_ffn = nn.Sequential(
                nn.Linear(embed_dim, args.encoder_ffn_embed_dim),
                nn.ReLU(),
                nn.Linear(args.encoder_ffn_embed_dim, embed_dim),
            )
        if self.encoder_dynamic_combination:
            # one gate network per layer, fed the concatenation of all layer
            # outputs
            self.proj = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(embed_dim * args.encoder_layers, embed_dim * 2),
                    nn.ReLU(),
                    nn.Linear(embed_dim * 2, embed_dim),
                ) for _ in range(args.encoder_layers)
            ])
        if self.encoder_linear_combination:
            # one learned channel-wise weight per layer
            self.weights = nn.ParameterList([
                nn.Parameter(torch.randn(1, 1, embed_dim), requires_grad=True)
                for _ in range(args.encoder_layers)
            ])
        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)

    def forward(self, src_tokens, **unused):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output
                  (or the combination of all layer outputs when a
                  combination mode is enabled) of shape
                  `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (Tensor): the positions of
                  padding elements of shape `(batch, src_len)`, or ``None``
                  when the batch holds no padding
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        combine = (self.encoder_dynamic_combination
                   or self.encoder_linear_combination)
        hiddens = [] if combine else None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            if combine:
                hiddens.append(x)

        # The former `assert torch.equal(x, hiddens[-1])` checks compared a
        # tensor with itself: they forced a device sync and could only fail
        # on NaN values, so they are gone.
        if self.encoder_dynamic_combination:
            # `hiddens` is a plain list; torch.cat takes it directly (a list
            # has no `.unbind()`).
            catted_hidden = torch.cat(hiddens, -1)
            acc_x = torch.zeros_like(x)
            for gate, hidden in zip(self.proj, hiddens):
                acc_x += gate(catted_hidden) * hidden
            x = acc_x + self.weight_ffn(acc_x)

        if self.encoder_linear_combination:
            acc_x = torch.zeros_like(x)
            for weight, hidden in zip(self.weights, hiddens):
                acc_x += weight * hidden
            x = acc_x + self.weight_ffn(acc_x)

        if self.normalize:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order* (the dict is
            updated in place and returned)
        """
        if encoder_out['encoder_out'] is not None:
            encoder_out['encoder_out'] = \
                encoder_out['encoder_out'].index_select(1, new_order)
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions,
                   self.embed_positions.max_positions())
class transformer_with_copyDecoder(FairseqIncrementalDecoder):
    """
    transformer_with_copy decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`transformer_with_copyDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        final_norm (bool, optional): apply layer norm to the output of the
            final decoder layer (default: True).
    """
    def __init__(self,
                 args,
                 dictionary,
                 embed_tokens,
                 no_encoder_attn=False,
                 final_norm=True):
        """
        Build the embeddings, the layer stack, the copy head and the output
        projection.

        Args:
            args (argparse.Namespace): parsed command-line arguments
            dictionary (~fairseq.data.Dictionary): decoding dictionary
            embed_tokens (torch.nn.Embedding): output embedding
            no_encoder_attn (bool, optional): forwarded to every decoder layer
                (default: False).
            final_norm (bool, optional): together with
                ``args.decoder_normalize_before``, enables the layer norm
                applied after the last layer (default: True).
        """
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.max_target_positions = args.max_target_positions

        token_dim = embed_tokens.embedding_dim
        model_dim = args.decoder_embed_dim
        out_dim = args.decoder_output_dim

        self.embed_tokens = embed_tokens
        # todo: try with input_embed_dim
        self.embed_scale = math.sqrt(model_dim)

        # a projection is only needed when the embedding width differs from
        # the width the layers work with
        if model_dim != token_dim:
            self.project_in_dim = Linear(token_dim, model_dim, bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions,
                model_dim,
                embed_tokens.padding_idx,
                learned=args.decoder_learned_pos,
            )

        self.layers = nn.ModuleList([
            transformer_with_copyDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])

        # copy head: a single-head attention whose weights are returned by
        # forward() under 'attn', and a sigmoid gate returned under
        # 'copy_or_generate'
        self.copy_attention = MultiheadOnlyAttention(
            model_dim,
            1,
            dropout=0,
        )
        self.copy_or_generate = nn.Sequential(nn.Linear(model_dim, 1),
                                              nn.Sigmoid())

        if model_dim != out_dim and not args.tie_adaptive_weights:
            self.project_out_dim = Linear(model_dim, out_dim, bias=False)
        else:
            self.project_out_dim = None

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                out_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens
                if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        else:
            self.adaptive_softmax = None
            if not self.share_input_output_embed:
                # separate output embedding used by forward() for the
                # vocabulary projection
                self.embed_out = nn.Parameter(
                    torch.Tensor(len(dictionary), out_dim))
                nn.init.normal_(self.embed_out, mean=0, std=out_dim**-0.5)

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(model_dim)

    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict, optional): output from the encoder; the keys
                ``'encoder_out'`` and ``'encoder_padding_mask'`` are read
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the last decoder layer's output, projected to the vocabulary
                  unless adaptive softmax is configured
                - a dict with ``'attn'`` (weights returned by the copy
                  attention), ``'inner_states'`` (the embedded input followed
                  by every layer's output) and ``'copy_or_generate'`` (the
                  sigmoid gate, batch-first)
        """
        incremental = incremental_state is not None

        # embed positions (computed from the full prefix, sliced below)
        if self.embed_positions is None:
            positions = None
        else:
            positions = self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )

        if incremental:
            # only the newest step is fed through the layers
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)
        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # the same two encoder tensors feed every layer and the copy attention
        if encoder_out is None:
            enc_states = None
            enc_pad_mask = None
        else:
            enc_states = encoder_out['encoder_out']
            enc_pad_mask = encoder_out['encoder_padding_mask']

        # decoder layers
        inner_states = [x]
        for layer in self.layers:
            if incremental:
                future_mask = None
            else:
                future_mask = self.buffered_future_mask(x)
            x, _ = layer(
                x,
                enc_states,
                enc_pad_mask,
                incremental_state,
                self_attn_mask=future_mask,
            )
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        # copy distribution: only the attention weights are kept
        _, copy = self.copy_attention(
            query=x,
            key=enc_states,
            value=enc_states,
            key_padding_mask=enc_pad_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=True,
        )

        # gate is computed on the time-first states, then made batch-first
        gate = self.copy_or_generate(x).transpose(0, 1)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            out_weight = (self.embed_tokens.weight
                          if self.share_input_output_embed else self.embed_out)
            x = F.linear(x, out_weight)

        extra = {
            'attn': copy,
            'inner_states': inner_states,
            'copy_or_generate': gate
        }
        return x, extra

    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output.

        Without adaptive softmax the result is a mixture of generating from
        the vocabulary and copying a source token::

            p = gate * softmax(logits) + scatter_add((1 - gate) * attn)

        where the copy mass of every source position is added onto the
        vocabulary entry of the token found at that position.

        Args:
            net_output (tuple): ``(logits, extra)`` as returned by
                ``forward()``; *extra* must hold ``'attn'`` and
                ``'copy_or_generate'``
            log_probs (bool): return log-probabilities instead of
                probabilities
            sample (dict): the batch. Source tokens are read from
                ``sample['net_input']['src_tokens']`` when ``'net_input'`` is
                present, otherwise from ``sample['src_tokens']``. With
                adaptive softmax it may be ``None``; if given it must contain
                ``'target'``.

        Returns:
            Tensor: probabilities, or log-probabilities when *log_probs* is
            true
        """
        if getattr(self, 'adaptive_softmax', None) is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        # Source ids are only needed for the copy mixture, so they are looked
        # up after the adaptive-softmax path (which tolerates sample=None).
        if 'net_input' in sample:
            enc_seq_ids = sample['net_input']['src_tokens']
        else:
            enc_seq_ids = sample['src_tokens']

        logits, extra = net_output[0], net_output[1]
        gate = extra['copy_or_generate']
        copy_attn = extra['attn']

        # The mixture is always built in probability space; previously the
        # log_probs=False branch mixed log-softmax values with copy
        # probabilities, which is neither a probability nor a log-probability.
        generate = utils.softmax(
            logits, dim=-1, onnx_trace=self.onnx_trace) * gate
        copy = copy_attn * (1 - gate)

        # one row of source ids per target step, matching the layout of `copy`
        enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(1, copy_attn.size(1), 1)
        final = generate.scatter_add(2, enc_seq_ids, copy)

        if log_probs:
            # small epsilon keeps log() finite for zero-probability entries
            return torch.log(final + 1e-15)
        return final

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        limit = self.max_target_positions
        if self.embed_positions is not None:
            # bounded by whichever is smaller: the configured maximum or what
            # the positional embedding reports
            limit = min(limit, self.embed_positions.max_positions())
        return limit

    def buffered_future_mask(self, tensor):
        """Return a cached ``dim x dim`` future mask, ``dim = tensor.size(0)``.

        The mask is the strict upper triangle of a matrix produced by
        ``utils.fill_with_neg_inf`` (presumably filled with -inf, as the name
        says); everything on and below the diagonal is zero. The cache lives
        in ``self._future_mask`` and is rebuilt when it is missing or on
        another device, and regrown when it is too small.
        """
        size = tensor.size(0)
        cached = getattr(self, '_future_mask', None)
        if cached is None or cached.device != tensor.device:
            fresh = utils.fill_with_neg_inf(tensor.new(size, size))
            self._future_mask = torch.triu(fresh, 1)
        if self._future_mask.size(0) < size:
            # cache is too small: resize the buffer in place and refill it
            grown = utils.fill_with_neg_inf(
                self._future_mask.resize_(size, size))
            self._future_mask = torch.triu(grown, 1)
        return self._future_mask[:size, :size]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Args:
            state_dict (dict): checkpoint entries, modified in place
            name (str): prefix under which this module's keys are stored

        Returns:
            the same *state_dict*
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # drop the stored sinusoid table and register the placeholder
            # buffer in its place
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            float_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[float_key] = torch.FloatTensor(1)

        # update layer norms: legacy index inside ``layer_norms`` -> the
        # attribute name the key is moved to
        renamed_norms = (
            ('0', 'self_attn_layer_norm'),
            ('1', 'encoder_attn_layer_norm'),
            ('2', 'final_layer_norm'),
        )
        for layer_idx in range(len(self.layers)):
            for old_idx, new_attr in renamed_norms:
                for param in ('weight', 'bias'):
                    legacy_key = '{}.layers.{}.layer_norms.{}.{}'.format(
                        name, layer_idx, old_idx, param)
                    if legacy_key not in state_dict:
                        continue
                    current_key = '{}.layers.{}.{}.{}'.format(
                        name, layer_idx, new_attr, param)
                    state_dict[current_key] = state_dict.pop(legacy_key)

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
Example #21
0
File: bgt.py Project: jwcmu/bgt
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        """
        Build the embeddings, the encoder layers and the latent projections.

        Args:
            args (argparse.Namespace): parsed command-line arguments
            dictionary (~fairseq.data.Dictionary): encoding dictionary
            embed_tokens (torch.nn.Embedding): input embedding
        """
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))
        self.args = args
        self.dropout = args.dropout
        self.bgt_setting = args.bgt_setting

        width = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(width)

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                width,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)])

        self.layer_norm = (LayerNorm(width)
                           if args.encoder_normalize_before else None)

        latent = args.latent_size
        # pooled encoder state -> latent mean (used in every setting)
        self.hidden2mean = nn.Linear(width, latent, bias=False)

        if self.bgt_setting == "bgt":
            # log-variance head and the map from the latent back to model width
            self.hidden2logv = nn.Linear(width, latent, bias=False)
            self.latent2hidden = nn.Linear(latent, width, bias=False)

    def forward(self, src_tokens, src_lengths, generate=False):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not read by this method
            generate (bool, optional): when true, no noise is sampled even in
                the ``"bgt"`` setting (default: False)

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **sent_emb**, **mean**, **logv**: the values returned by
                  :meth:`get_sentence_embs`
                - **z**: the sampled noise of shape `(batch, latent_size)`,
                  or ``None`` when nothing was sampled
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask (kept even when nothing is padded)
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.layer_norm:
            x = self.layer_norm(x)

        # sample z: standard-normal noise, one row per batch entry
        sample_latent = self.bgt_setting == "bgt" and not generate
        if sample_latent:
            z = torch.randn([x.size()[1], self.args.latent_size])
        else:
            z = None

        sent_emb, mean, logv = self.get_sentence_embs(
            x, encoder_padding_mask, z)

        out = {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T,
            'sent_emb': sent_emb,
            'mean': mean,
            'logv': logv,
            'z': z,
        }
        return out

    def get_sentence_embs(self, encoder_out, encoder_padding_mask, z=None):
        """
        Pool the encoder states into one sentence embedding per batch entry.

        The states are averaged over the non-padded time steps and projected
        with ``hidden2mean``. In the ``"bgt"`` setting the latent is
        ``z * exp(0.5 * logv) + mean`` when noise *z* is given, or just
        *mean* otherwise, and is then mapped through ``latent2hidden``.

        All tensors are moved to the device of *encoder_out*, so the method
        does not depend on ``args.cpu`` or on CUDA being available.

        Args:
            encoder_out (Tensor): encoder states of shape
                `(src_len, batch, embed_dim)`
            encoder_padding_mask (Tensor): padding positions of shape
                `(batch, src_len)`
            z (Tensor, optional): noise of shape `(batch, latent_size)`

        Returns:
            tuple: ``(sent_emb, mean, logv)``; *logv* is ``None`` outside the
            ``"bgt"`` setting, where *sent_emb* is *mean* itself
        """
        # B x T x C in float32 with the padded steps zeroed, so they add
        # nothing to the sum below
        states = encoder_out.transpose(1, 0).float()
        pad_mask = encoder_padding_mask.to(states.device)
        mean_pool = states.masked_fill(pad_mask.unsqueeze(2),
                                       0).type_as(encoder_out)

        # number of real (non-padded) tokens per sentence
        den = pad_mask.size()[1] - pad_mask.sum(dim=1)
        mean_pool = mean_pool.sum(dim=1) / den.float().unsqueeze(1)

        mean = self.hidden2mean(mean_pool)
        logv = None
        if self.bgt_setting == "bgt":
            logv = self.hidden2logv(mean_pool)
            if z is not None:
                std = torch.exp(0.5 * logv)
                # scale and shift the supplied noise by the predicted
                # standard deviation and mean
                z = z.to(std.device) * std + mean
                sent_emb = self.latent2hidden(z)
            else:
                sent_emb = self.latent2hidden(mean)
        else:
            sent_emb = mean

        return sent_emb, mean, logv

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Permute the batch entries of an encoder output.

        Args:
            encoder_out (dict): output from the ``forward()`` method; all six
                keys listed below must be present, any value may be ``None``
            new_order (LongTensor): indices to pick along the batch axis

        Returns:
            the same *encoder_out* dict, updated in place
        """
        # (key, axis handed to index_select); only 'encoder_out' is selected
        # along axis 1 (it is T x B x C), everything else along axis 0
        batched_entries = (
            ('encoder_out', 1),
            ('encoder_padding_mask', 0),
            ('sent_emb', 0),
            ('mean', 0),
            ('logv', 0),
            ('z', 0),
        )
        for key, batch_axis in batched_entries:
            entry = encoder_out[key]
            if entry is not None:
                encoder_out[key] = entry.index_select(batch_axis, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        positional = self.embed_positions
        if positional is None:
            return self.max_source_positions
        # the tighter of the configured maximum and what the positional
        # embedding reports
        supported = positional.max_positions()
        return min(self.max_source_positions, supported)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Args:
            state_dict (dict): checkpoint entries, modified in place
            name (str): prefix under which this module's keys are stored

        Returns:
            the same *state_dict*
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # drop the stored sinusoid table and register the placeholder
            # buffer in its place
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            float_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[float_key] = torch.FloatTensor(1)

        # update layer norms: every layer upgrades its own keys
        for idx, layer in enumerate(self.layers):
            layer.upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, idx))

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #22
0
class JointAttentionDecoder(FairseqIncrementalDecoder):
    """
    JointAttention decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`ProtectedTransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        left_pad (bool, optional): whether the input is left-padded. Default:
            ``False``
    """
    def __init__(self,
                 args,
                 dictionary,
                 embed_tokens,
                 left_pad=False,
                 final_norm=True):
        """
        Build the embeddings, the layer stack and the output projection.

        Args:
            args (argparse.Namespace): parsed command-line arguments
            dictionary (~fairseq.data.Dictionary): decoding dictionary
            embed_tokens (torch.nn.Embedding): output embedding
            left_pad (bool, optional): accepted but not read by this
                constructor (default: False).
            final_norm (bool, optional): together with
                ``args.decoder_normalize_before``, enables the layer norm
                applied after the last layer (default: True).
        """
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.kernel_size_list = args.kernel_size_list
        self.max_target_positions = args.max_target_positions

        token_dim = embed_tokens.embedding_dim
        model_dim = args.decoder_embed_dim
        out_dim = args.decoder_output_dim

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(model_dim)

        # a projection is only needed when the embedding width differs from
        # the width the layers work with
        if model_dim != token_dim:
            self.project_in_dim = Linear(token_dim, model_dim, bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions,
                model_dim,
                embed_tokens.padding_idx,
                learned=args.decoder_learned_pos,
            )

        if args.language_embeddings:
            self.embed_language = LanguageEmbedding(model_dim)
        else:
            self.embed_language = None

        # every layer is built without encoder attention
        self.layers = nn.ModuleList([
            ProtectedTransformerDecoderLayer(args, no_encoder_attn=True)
            for _ in range(args.decoder_layers)
        ])

        if model_dim != out_dim and not args.tie_adaptive_weights:
            self.project_out_dim = Linear(model_dim, out_dim, bias=False)
        else:
            self.project_out_dim = None

        if not self.share_input_output_embed:
            # separate output embedding used by forward() for the vocabulary
            # projection
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), out_dim))
            nn.init.normal_(self.embed_out, mean=0, std=out_dim**-0.5)

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(model_dim)

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        """
        Run the layer stack over the source states and the target prefix.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict): output from the encoder; the keys
                ``'encoder_out'`` and ``'encoder_padding_mask'`` are read
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the last decoder layer's output projected to the vocabulary,
                  of shape `(batch, tgt_len, vocab)`
                - a dict with ``'attn'`` (the second value returned by the
                  last layer call) and ``'inner_states'`` (the embedded input
                  followed by every layer output, source passes included)
        """
        # read before the incremental slicing below, so this is the length of
        # the whole target prefix
        tgt_len = prev_output_tokens.size(1)

        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            # only the newest step is fed through the layers
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        # language embedding
        if self.embed_language is not None:
            # reshaped to 1 x 1 x C so it broadcasts over batch and time
            lang_emb = self.embed_scale * self.embed_language.view(1, 1, -1)
            x += lang_emb

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]
        source = encoder_out['encoder_out']
        # the source is pushed through the layers whenever there is no
        # incremental state or the state dict is still empty
        process_source = incremental_state is None or len(
            incremental_state) == 0

        # extended padding mask
        source_padding_mask = encoder_out['encoder_padding_mask']
        if source_padding_mask is not None:
            # target positions get an all-zero mask and are appended after the
            # source positions
            target_padding_mask = source_padding_mask.new_zeros(
                (source_padding_mask.size(0), tgt_len))
            self_attn_padding_mask = torch.cat(
                (source_padding_mask, target_padding_mask), dim=1)
        else:
            self_attn_padding_mask = None

        # transformer layers
        for i, layer in enumerate(self.layers):

            # mask over target positions: local window, cached future mask, or
            # none while decoding incrementally
            if self.kernel_size_list is not None:
                target_mask = self.local_mask(x,
                                              self.kernel_size_list[i],
                                              causal=True,
                                              tgt_len=tgt_len)
            elif incremental_state is None:
                target_mask = self.buffered_future_mask(x)
            else:
                target_mask = None

            if target_mask is not None:
                # zero columns for the source positions are put in front of
                # the target mask, matching the source-then-target order of
                # the padding mask above
                zero_mask = target_mask.new_zeros(
                    (target_mask.size(0), source.size(0)))
                self_attn_mask = torch.cat((zero_mask, target_mask), dim=1)
            else:
                self_attn_mask = None

            state = incremental_state
            if process_source:
                if state is None:
                    # a fresh dict per layer iteration; it is shared by the
                    # source call and the target call right below
                    # NOTE(review): presumably the layer stores the source
                    # keys/values here so the target pass can attend to
                    # them -- confirm against ProtectedTransformerDecoderLayer
                    state = {}
                if self.kernel_size_list is not None:
                    source_mask = self.local_mask(source,
                                                  self.kernel_size_list[i],
                                                  causal=False)
                else:
                    source_mask = None
                source, attn = layer(
                    source,
                    None,
                    None,
                    state,
                    self_attn_mask=source_mask,
                    self_attn_padding_mask=source_padding_mask)
                inner_states.append(source)

            # target pass; its attn overwrites the one from the source pass
            x, attn = layer(x,
                            None,
                            None,
                            state,
                            self_attn_mask=self_attn_mask,
                            self_attn_padding_mask=self_attn_padding_mask)
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

        pred = x
        info = {'attn': attn, 'inner_states': inner_states}

        return pred, info

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        positional = self.embed_positions
        if positional is None:
            return self.max_target_positions
        # the tighter of the configured maximum and what the positional
        # embedding reports
        supported = positional.max_positions()
        return min(self.max_target_positions, supported)

    def buffered_future_mask(self, tensor):
        """Cached future mask of size ``dim x dim``, ``dim = tensor.size(0)``.

        The mask is the strict upper triangle of a matrix produced by
        ``utils.fill_with_neg_inf`` (presumably filled with -inf, as the name
        says); everything on and below the diagonal is zero. The cache lives
        in ``self._future_mask`` and is rebuilt when it is missing or on
        another device, and regrown when it is too small.
        """
        size = tensor.size(0)
        #pylint: disable=access-member-before-definition, attribute-defined-outside-init
        cached = getattr(self, '_future_mask', None)
        if cached is None or cached.device != tensor.device:
            fresh = utils.fill_with_neg_inf(tensor.new(size, size))
            self._future_mask = torch.triu(fresh, 1)
        if self._future_mask.size(0) < size:
            # cache is too small: resize the buffer in place and refill it
            grown = utils.fill_with_neg_inf(
                self._future_mask.resize_(size, size))
            self._future_mask = torch.triu(grown, 1)
        return self._future_mask[:size, :size]

    def local_mask(self, tensor, kernel_size, causal, tgt_len=None):
        """Locality constraint mask.

        Builds a ``rows x cols`` mask with ``rows = tensor.size(0)`` and
        ``cols = tgt_len`` (or ``rows`` when *tgt_len* is ``None``). Blocked
        entries come from ``utils.fill_with_neg_inf`` (presumably -inf);
        allowed entries are zero.

        Args:
            tensor (Tensor): supplies the row count and is used to allocate
                the mask via ``tensor.new``
            kernel_size (int): width of the allowed window
            causal (bool): when true the window ends at the diagonal,
                otherwise it is placed around the diagonal
            tgt_len (int, optional): number of columns

        Returns:
            Tensor: the mask
        """
        rows = tensor.size(0)
        cols = rows if tgt_len is None else tgt_len

        if causal and rows == 1:
            # single query row: only the last `kernel_size` columns stay open
            single = utils.fill_with_neg_inf(tensor.new(1, cols))
            single[0, -kernel_size:] = 0
            return single

        # diagonal offsets from which the upper / lower blocked regions start
        if causal:
            upper, lower = 1, kernel_size
        elif kernel_size % 2 == 1:
            upper = lower = (kernel_size + 1) // 2
        else:
            upper, lower = kernel_size // 2, kernel_size // 2 + 1

        blocked_above = torch.triu(
            utils.fill_with_neg_inf(tensor.new(rows, cols)), upper)
        blocked_below = torch.tril(
            utils.fill_with_neg_inf(tensor.new(rows, cols)), -lower)

        return blocked_above + blocked_below
class TransformerDecoderPerm(FairseqIncrementalDecoder):
    """Transformer decoder that scores vocabulary items or encoder slots.

    Every decoder input is the scaled token embedding plus a positional
    embedding; each position after the first additionally receives the
    encoder state selected by the token id found at that position.  Depending
    on ``args.predict_arch`` the final features are projected onto the
    vocabulary (``'seq2seq'``) or compared with the encoder states plus one
    EOS embedding (``'pointer_net'``).

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary: decoding dictionary; ``len(dictionary)`` sizes the
            separate output embedding and ``dictionary.eos()`` is looked up
            in the pointer-network branch
        embed_tokens (torch.nn.Embedding): token embedding
        left_pad (bool, optional): accepted for interface compatibility; not
            used by this class (default: False).
    """
    def __init__(self, args, dictionary, embed_tokens, left_pad=False):
        super().__init__(dictionary)
        if not isinstance(args.shorten_decoder_perm, bool):
            # SECURITY NOTE(review): eval() executes arbitrary Python taken
            # from the argument value.  Kept so existing invocations parse
            # exactly as before; restrict the accepted values if this option
            # can ever come from an untrusted source.
            args.shorten_decoder_perm = eval(args.shorten_decoder_perm)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        self.embed_dim = embed_dim = embed_tokens.embedding_dim
        self.padding_idx = padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        # The number of positions is fixed at 1024 rather than read from args.
        self.embed_positions = PositionalEmbedding(
            1024,
            embed_dim,
            padding_idx,
            learned=args.decoder_learned_pos,
        )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderPermLayer(args)
            for _ in range(args.decoder_perm_layers)
        ])

        self.sentence_transformer_arch = args.sentence_transformer_arch
        self.predict_arch = args.predict_arch
        self.pointer_net_attn_type = args.pointer_net_attn_type

        # A separate output embedding is only needed when projecting onto the
        # vocabulary without sharing the input embedding.
        if not self.share_input_output_embed and self.predict_arch == 'seq2seq':
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=embed_dim**-0.5)

        if self.predict_arch == 'pointer_net':
            if self.pointer_net_attn_type == 'perceptron':
                # score = mapping_vector . tanh(W_enc * enc + W_dec * dec)
                self.pointer_encoder_embed_weight = nn.Parameter(
                    torch.Tensor(embed_dim, embed_dim))
                self.pointer_decoder_embed_weight = nn.Parameter(
                    torch.Tensor(embed_dim, embed_dim))
                self.mapping_vector = nn.Parameter(torch.Tensor(1, embed_dim))
                nn.init.normal_(self.pointer_encoder_embed_weight,
                                mean=0,
                                std=embed_dim**-0.5)
                nn.init.normal_(self.pointer_decoder_embed_weight,
                                mean=0,
                                std=embed_dim**-0.5)
                nn.init.normal_(self.mapping_vector,
                                mean=0,
                                std=embed_dim**-0.5)
            elif self.pointer_net_attn_type == 'general':
                # score = dec * W * enc^T
                self.pointer_attn_weight = nn.Parameter(
                    torch.Tensor(args.decoder_embed_dim,
                                 args.encoder_embed_dim))
                nn.init.normal_(self.pointer_attn_weight,
                                mean=0,
                                std=embed_dim**-0.5)
            elif self.pointer_net_attn_type == 'dot':
                # Plain dot product: no extra parameters.
                pass
            else:
                raise RuntimeError(
                    "pointer-net-attn-type doesn't support {} yet !".format(
                        self.pointer_net_attn_type))

    def buffered_future_mask(self, tensor):
        """Return a cached ``dim x dim`` mask, ``dim = tensor.size(0)``.

        The mask is the strictly-upper-triangular part of a square tensor
        filled by ``utils.fill_with_neg_inf``.  It is rebuilt whenever the
        cached copy is missing, lives on another device, or is too small.
        """
        dim = tensor.size(0)
        if (not hasattr(self, '_future_mask') or self._future_mask is None
                or self._future_mask.device != tensor.device
                or self._future_mask.size(0) < dim):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        """Run the decoder and return the output scores.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs, indexed
                as ``(batch, tgt_len)``.  From the second position on, the
                values are also used as indices into the encoder states of
                the same batch row.
            encoder_out (dict): must provide ``'encoder_out'`` (indexed here
                with the batch as first dimension) and
                ``'encoder_padding_mask'``
            incremental_state (dict, optional): when given, only the last
                token and position are processed

        Returns:
            Tensor: for ``'seq2seq'`` the features projected with the output
            embedding; for ``'pointer_net'`` one score per encoder state plus
            one for the appended EOS embedding.

        Raises:
            RuntimeError: if ``predict_arch`` or ``pointer_net_attn_type``
                is not one of the supported values.
        """
        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        x += positions

        # add the sent embedding to x
        # NOTE(review): when incremental_state is given only one token is
        # kept above, so every `[:, 1:]` slice below is empty and no encoder
        # state is added -- confirm this is intended.
        # Padding ids are mapped to index 0 so the gather stays in range;
        # those slots are overwritten with the padding embedding right after.
        prev_output_tokens_temp = prev_output_tokens.masked_fill(
            prev_output_tokens == self.padding_idx, 0)
        sents_embedding = torch.stack([
            encoder_out['encoder_out'][i, prev_output_tokens_temp[:, 1:][i]]
            for i in range(prev_output_tokens_temp.shape[0])
        ])
        sents_embedding[prev_output_tokens[:, 1:] ==
                        self.padding_idx] = self.embed_tokens(
                            torch.LongTensor([self.padding_idx]).to(x.device))
        x[:, 1:] += sents_embedding
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        if self.sentence_transformer_arch == 'bert':
            encoder_out_embedding = encoder_out['encoder_out'].transpose(0, 1)
        else:
            encoder_out_embedding = encoder_out['encoder_out']

        # decoder layers (the per-layer attention output is not used)
        self_attn_mask = self.buffered_future_mask(x)
        for layer in self.layers:
            x, _ = layer(
                x,
                encoder_out_embedding,
                encoder_out['encoder_padding_mask'],
                incremental_state,
                self_attn_mask,
            )

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.predict_arch == 'seq2seq':
            # project back to size of vocabulary
            if self.share_input_output_embed:
                return F.linear(x, self.embed_tokens.weight)
            return F.linear(x, self.embed_out)

        if self.predict_arch == 'pointer_net':
            bsz = prev_output_tokens.shape[0]
            # Candidates are the encoder states followed by one EOS embedding.
            eos_embedding = self.embed_tokens(
                torch.LongTensor([self.dictionary.eos()]).to(
                    x.device)).expand([bsz, 1, self.embed_dim])
            pointer_keys = torch.cat(
                [encoder_out['encoder_out'], eos_embedding], dim=1)

            if self.pointer_net_attn_type == 'perceptron':
                # Broadcast every decoder step against every candidate.
                mixed = F.linear(
                    pointer_keys, self.pointer_encoder_embed_weight
                ).unsqueeze(dim=1) + F.linear(
                    x, self.pointer_decoder_embed_weight).unsqueeze(dim=2)
                # torch.tanh replaces the deprecated F.tanh alias.
                mixed = torch.tanh(mixed)
                return F.linear(mixed, self.mapping_vector).squeeze(dim=-1)
            if self.pointer_net_attn_type == 'general':
                return x.matmul(self.pointer_attn_weight).bmm(
                    pointer_keys.transpose(-1, -2))
            if self.pointer_net_attn_type == 'dot':
                return x.bmm(pointer_keys.transpose(-1, -2))
            raise RuntimeError(
                "pointer-net-attn-type doesn't support {} yet !".format(
                    self.pointer_net_attn_type))

        # Previously this fell through to an unbound local at `return out`.
        raise RuntimeError("predict-arch doesn't support {} yet !".format(
            self.predict_arch))

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def upgrade_state_dict(self, state_dict):
        """Rewrite an older checkpoint's keys in place and return the dict.

        Only keys under the hard-coded ``decoder_perm.`` prefix are touched:
        sinusoidal position weights are dropped and fused self-attention
        ``in_proj_*`` tensors are split into separate q/k/v projections.
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if 'decoder_perm.embed_positions.weights' in state_dict:
                del state_dict['decoder_perm.embed_positions.weights']
            state_dict[
                'decoder_perm.embed_positions._float_tensor'] = torch.FloatTensor(
                    1)

        # in_proj_weight -> q_proj.weight, k_proj.weight, v_proj.weight
        # in_proj_bias -> q_proj.bias, k_proj.bias, v_proj.bias
        def transform_params(idx, suffix):
            prefix = 'decoder_perm.layers.{}.self_attn.'.format(idx)
            fused = state_dict.pop(prefix + 'in_proj_{}'.format(suffix))
            q_part, k_part, v_part = fused.chunk(3, dim=0)
            state_dict[prefix + 'q_proj.{}'.format(suffix)] = q_part
            state_dict[prefix + 'k_proj.{}'.format(suffix)] = k_part
            state_dict[prefix + 'v_proj.{}'.format(suffix)] = v_part

        # The presence of layer 0's fused tensor decides for all layers.
        for suffix in ('weight', 'bias'):
            probe = 'decoder_perm.layers.0.self_attn.in_proj_{}'.format(suffix)
            if probe in state_dict:
                for idx in range(len(self.layers)):
                    transform_params(idx, suffix)

        return state_dict
Example #24
0
class TransformerEncoder(nn.Module):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        vocab: encoding vocabulary; only stored on the instance as
            ``self.vocab``
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, vocab, embed_tokens):
        super().__init__()
        self.vocab = vocab

        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            embed_dim,
            self.padding_idx,
            learned=args.encoder_learned_pos,
        )

        # Each layer is told its own index within the stack.
        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args, i)
            for i in range(args.encoder_layers)
        ])

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self,
                src_tokens,
                data_holder=None,
                mask=None,
                encoder_mode="soft",
                encoder_temperature=-1,
                need_weights=False,
                **unused):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            data_holder (optional): side-channel object. If its
                ``permute_embed`` is not None, the embedding at that position
                has its channels randomly permuted for every batch row; if
                its ``keep_grads`` is true, the token embedding is stored on
                ``data_holder.embedding`` with ``retain_grad()`` enabled. It
                is also handed to every layer.
            mask (optional): passed to every layer as ``attn_mask``
            encoder_mode (str, optional): forwarded unchanged to every layer
                (default: ``"soft"``)
            encoder_temperature (optional): forwarded unchanged to every
                layer (default: -1)
            need_weights (bool, optional): forwarded unchanged to every layer
                (default: False)

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (Tensor): the positions of
                  padding elements of shape `(batch, src_len)`, or None when
                  the batch contains no padding
                - **encoder_attn_data_list** (list): the second value
                  returned by each layer, in layer order
        """
        # embed tokens and positions
        embedding = self.embed_tokens(src_tokens)
        if data_holder is not None:
            if data_holder.permute_embed is not None:
                bsz, _, dim = embedding.shape
                for i in range(bsz):
                    # A fresh channel permutation is drawn per batch row and
                    # written back in place.
                    perm = torch.randperm(dim)
                    embedding[i, data_holder.permute_embed] = embedding[
                        i, data_holder.permute_embed, perm]

            if data_holder.keep_grads:
                # Expose the (non-leaf) embedding so its gradient can be read
                # after backward.
                data_holder.embedding = embedding
                data_holder.embedding.retain_grad()

        x = self.embed_scale * embedding
        x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # encoder layers
        attn_data_list = []
        for layer in self.layers:
            x, attn_data = layer(x,
                                 data_holder=data_holder,
                                 encoder_padding_mask=encoder_padding_mask,
                                 attn_mask=mask,
                                 encoder_mode=encoder_mode,
                                 encoder_temperature=encoder_temperature,
                                 need_weights=need_weights)
            attn_data_list.append(attn_data)

        # Explicit None test instead of relying on module truthiness.
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
            'encoder_attn_data_list': attn_data_list
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Only ``'encoder_out'`` (along dim 1) and ``'encoder_padding_mask'``
        (along dim 0) are reordered; entries that are None and any other keys
        are left as they are. The dict is modified in place.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        for key, batch_dim in (('encoder_out', 1),
                               ('encoder_padding_mask', 0)):
            if encoder_out[key] is not None:
                encoder_out[key] = encoder_out[key].index_select(
                    batch_dim, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions,
                   self.embed_positions.max_positions())
Example #25
0
class TransformerYmaskEncoder(FairseqEncoder):
    """
    Transformer encoder that also emits a distribution over lengths.

    A stack of *args.encoder_layers* :class:`TransformerEncoderLayer` modules
    runs over the source tokens with one extra slot prepended. The final
    state of that slot is scored against the rows of ``embed_lengths`` and
    returned as log-probabilities under ``predicted_lengths``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        dim = embed_tokens.embedding_dim
        self.dropout = args.dropout
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(dim)

        # Positional embeddings can be switched off from the command line.
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)])

        self.layer_norm = (LayerNorm(dim)
                           if args.encoder_normalize_before else None)

        # One row per candidate length; row 0 doubles as the prepended slot.
        self.embed_lengths = nn.Embedding(args.max_target_positions, dim)
        nn.init.normal_(self.embed_lengths.weight, mean=0, std=0.02)

    def forward(self, src_tokens, src_lengths, **unused):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not read by this method

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`, with the prepended
                  slot removed
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or None when
                  nothing is padded
                - **predicted_lengths** (Tensor): log-softmax over the rows of
                  ``embed_lengths``, one distribution per batch element
        """
        batch = src_tokens.size(0)

        tokens = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            tokens += self.embed_positions(src_tokens)

        # add length prediction part: index 0 of embed_lengths goes in front
        length_slot = self.embed_lengths(src_tokens.new(batch, 1).fill_(0))
        x = torch.cat([length_slot, tokens], dim=1)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask, widened by one never-padded column so that it
        # stays consistent with x
        token_padding = src_tokens.eq(self.padding_idx)
        slot_padding = token_padding.new(batch, 1).fill_(0)
        encoder_padding_mask = torch.cat([slot_padding, token_padding], dim=1)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # encoder layers
        for encoder_layer in self.layers:
            x = encoder_layer(x, encoder_padding_mask)

        if self.layer_norm:
            x = self.layer_norm(x)

        # Score the prepended slot against every length embedding.
        length_table = self.embed_lengths.weight.transpose(0, 1)
        length_logits = torch.matmul(x[0, :, :], length_table).float()
        length_logits[:, 0] += float('-inf')  # Cannot predict the len_token
        predicted_lengths = F.log_softmax(length_logits, dim=-1)

        # Drop the prepended slot from both the states and the mask.
        x = x[1:, :, :]
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask[:, 1:]

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
            'predicted_lengths': predicted_lengths,  # B x L
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        ``'encoder_out'`` is reordered along dim 1, ``'encoder_padding_mask'``
        and ``'predicted_lengths'`` along dim 0; entries that are None are
        left alone. The dict is modified in place.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        batch_dims = (
            ('encoder_out', 1),
            ('encoder_padding_mask', 0),
            ('predicted_lengths', 0),
        )
        for key, batch_dim in batch_dims:
            if encoder_out[key] is not None:
                encoder_out[key] = encoder_out[key].index_select(
                    batch_dim, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        limit = self.max_source_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions())
        return limit

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # Drop any stored sinusoidal table and keep only a placeholder.
            stale_key = '{}.embed_positions.weights'.format(name)
            if stale_key in state_dict:
                del state_dict[stale_key]
            placeholder_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[placeholder_key] = torch.FloatTensor(1)

        # update layer norms: every layer upgrades its own keys
        for index, encoder_layer in enumerate(self.layers):
            encoder_layer.upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, index))

        version_key = '{}.version'.format(name)
        stored = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #26
0
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Stack of *args.decoder_layers* :class:`TransformerDecoderLayer` modules
    with optional input/output projections, LayerDrop and an output layer
    that is either a linear projection or left to an adaptive softmax.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        self.decoder_layerdrop = args.decoder_layerdrop
        self.share_input_output_embed = args.share_decoder_input_output_embed

        token_dim = embed_tokens.embedding_dim
        model_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens

        if args.no_scale_embedding:
            self.embed_scale = 1.0
        else:
            self.embed_scale = math.sqrt(model_dim)

        # Bridge the embedding width to the model width only when they differ.
        if model_dim != token_dim:
            self.project_in_dim = Linear(token_dim, model_dim, bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions,
                model_dim,
                self.padding_idx,
                learned=args.decoder_learned_pos,
            )

        self.cross_self_attention = getattr(args, 'cross_self_attention',
                                            False)
        self.layer_wise_attention = getattr(args, 'layer_wise_attention',
                                            False)

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args, no_encoder_attn, layer_id=layer_id)
            for layer_id in range(args.decoder_layers)
        ])

        self.adaptive_softmax = None

        if (model_dim != self.output_embed_dim
                and not args.tie_adaptive_weights):
            self.project_out_dim = Linear(model_dim,
                                          self.output_embed_dim,
                                          bias=False)
        else:
            self.project_out_dim = None

        if args.adaptive_softmax_cutoff is not None:
            tied_inputs = embed_tokens if args.tie_adaptive_weights else None
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=tied_inputs,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # Stand-alone output embedding, normally initialised.
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim))
            nn.init.normal_(self.embed_out,
                            mean=0,
                            std=self.output_embed_dim**-0.5)

        wants_final_norm = args.decoder_normalize_before and not getattr(
            args, 'no_decoder_final_norm', False)
        self.layer_norm = LayerNorm(model_dim) if wants_final_norm else None

        if getattr(args, 'layernorm_embedding', False):
            self.layernorm_embedding = LayerNorm(model_dim)
        else:
            self.layernorm_embedding = None

    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None,
                features_only=False,
                **extra_args):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            **extra_args: passed through to :meth:`extract_features`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        features, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            **extra_args)
        if features_only:
            return features, extra
        return self.output_layer(features), extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out=None,
        incremental_state=None,
        full_context_alignment=False,
        alignment_layer=None,
        alignment_heads=None,
        **unused,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        if alignment_layer is None:
            alignment_layer = len(self.layers) - 1

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )

        # Incremental decoding only processes the newest step.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # A padding mask is built when cross-self-attention is on or when the
        # batch actually contains padding.
        if self.cross_self_attention or prev_output_tokens.eq(
                self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
        else:
            self_attn_padding_mask = None

        # decoder layers
        attn = None
        inner_states = [x]
        for idx, layer in enumerate(self.layers):
            if encoder_out is None:
                encoder_state = None
            elif self.layer_wise_attention:
                encoder_state = encoder_out.encoder_states[idx]
            else:
                encoder_state = encoder_out.encoder_out

            use_future_mask = (incremental_state is None
                               and not full_context_alignment)
            self_attn_mask = (self.buffered_future_mask(x)
                              if use_future_mask else None)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and dropout_probability <= self.decoder_layerdrop:
                # This layer is skipped for the current forward pass.
                continue

            wants_alignment = (idx == alignment_layer)
            if encoder_out is not None:
                encoder_padding = encoder_out.encoder_padding_mask
            else:
                encoder_padding = None
            x, layer_attn = layer(
                x,
                encoder_state,
                encoder_padding,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=wants_alignment,
                need_head_weights=wants_alignment,
            )
            inner_states.append(x)
            if layer_attn is not None and wants_alignment:
                attn = layer_attn.float()

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {'attn': attn, 'inner_states': inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size.

        When an adaptive softmax is configured the features are returned
        untouched.
        """
        if self.adaptive_softmax is not None:
            return features
        # project back to size of vocabulary
        if self.share_input_output_embed:
            projection = self.embed_tokens.weight
        else:
            projection = self.embed_out
        return F.linear(features, projection)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        limit = self.max_target_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions())
        return limit

    def buffered_future_mask(self, tensor):
        """Return a cached ``dim x dim`` mask, ``dim = tensor.size(0)``.

        The mask is the strictly-upper-triangular part of a square tensor
        filled by ``utils.fill_with_neg_inf``; it is rebuilt when the cached
        copy is missing, on another device, or too small.
        """
        dim = tensor.size(0)
        cached = getattr(self, '_future_mask', None)
        if (cached is None or cached.device != tensor.device
                or cached.size(0) < dim):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # Drop any stored sinusoidal table and keep only a placeholder.
            stale_key = '{}.embed_positions.weights'.format(name)
            if stale_key in state_dict:
                del state_dict[stale_key]
            placeholder_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[placeholder_key] = torch.FloatTensor(1)

        # update layer norms: numbered `layer_norms.N` keys become named ones
        layer_norm_map = {
            '0': 'self_attn_layer_norm',
            '1': 'encoder_attn_layer_norm',
            '2': 'final_layer_norm'
        }
        for layer_idx in range(len(self.layers)):
            for old, new in layer_norm_map.items():
                for kind in ('weight', 'bias'):
                    old_key = '{}.layers.{}.layer_norms.{}.{}'.format(
                        name, layer_idx, old, kind)
                    if old_key not in state_dict:
                        continue
                    new_key = '{}.layers.{}.{}.{}'.format(
                        name, layer_idx, new, kind)
                    state_dict[new_key] = state_dict.pop(old_key)

        version_key = '{}.version'.format(name)
        stored = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
Example #27
0
class GraphTransformerEncoder(FairseqEncoder):
    """
    Transformer encoder with an additional stack of graph layers.

    Tokens pass through *args.encoder_layers*
    :class:`TransformerEncoderLayer` modules and then through
    *args.graph_layers* :class:`GraphTransformerEncoderLayer` modules. The
    state entering the graph stack and the state leaving each graph layer
    are fed, position by position, through a bidirectional LSTM whose final
    states form the encoder output.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        embed_edges: embedding module applied to the edge ids
    """

    def __init__(self, args, dictionary, embed_tokens, embed_edges):
        super().__init__(dictionary)
        self.dropout = args.dropout

        self.embed_dim = embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_edges = embed_edges
        self.embed_scale = math.sqrt(embed_dim)
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        def make_gate(in_features):
            # Linear -> SELU -> Dropout block projecting to the model width
            return nn.Sequential(
                nn.Linear(in_features, embed_dim),
                nn.SELU(),
                nn.Dropout(self.dropout),
            )

        # NOTE: the construction order below is kept fixed on purpose; it
        # determines the order in which sub-modules are registered.
        self.l1_gate = make_gate(embed_dim)
        self.l2_gate = make_gate(embed_dim)
        self.e1_gate = make_gate(embed_dim * 3)
        self.e2_gate = make_gate(embed_dim * 3)

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args)
             for _ in range(args.encoder_layers)])

        self.graph_layers = nn.ModuleList(
            [GraphTransformerEncoderLayer(args)
             for _ in range(args.graph_layers)])

        # runs over the per-position sequence of layer states in forward()
        self.rnn = nn.LSTM(
            input_size=embed_dim,
            hidden_size=embed_dim,
            num_layers=1,
            batch_first=False,
            bidirectional=True,
        )

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            # the layer norms only exist when normalization is enabled
            self.layer_norm = LayerNorm(embed_dim)
            self.graph_layer_norm = LayerNorm(embed_dim)

    def forward(self, src_tokens, src_lengths, enc_edge_ids, enc_edge_links1,
                enc_edge_links2, graph_mask, graph_mask_rev):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not read by this method
            enc_edge_ids: edge ids, embedded with ``embed_edges``; entries
                equal to the token padding index are marked as padding
            enc_edge_links1: per-edge index tensor, repeated along a new last
                dimension of size ``embed_dim`` and given to the graph layers
            enc_edge_links2: second per-edge index tensor, treated like
                *enc_edge_links1*
            graph_mask: passed unchanged to every graph layer
            graph_mask_rev: passed unchanged to every graph layer

        Returns:
            dict:
                - **encoder_out** (Tensor): the encoder output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or ``None``
                  when no token is padding
        """
        # token (+ position) embeddings, then dropout
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # padding masks; the token mask is dropped when nothing is padded
        token_is_pad = src_tokens.eq(self.padding_idx)
        encoder_padding_mask = token_is_pad if token_is_pad.any() else None
        enc_edge_padding_mask = enc_edge_ids.eq(self.padding_idx)

        # edge embeddings with dropout, B x L x C -> L x B x C
        enc_edge_emb = F.dropout(
            self.embed_scale * self.embed_edges(enc_edge_ids),
            p=self.dropout,
            training=self.training,
        ).transpose(0, 1)

        def spread_links(links):
            # B x L -> L x B x C (every index repeated across the C axis)
            repeated = links.unsqueeze(-1).repeat(1, 1, self.embed_dim)
            return repeated.transpose(0, 1)

        idx_pairs1 = spread_links(enc_edge_links1)
        idx_pairs2 = spread_links(enc_edge_links2)

        # plain transformer encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.normalize:
            x = self.layer_norm(x)

        # graph layers; remember the state before and after every layer
        enc_states = [x]
        for graph_layer in self.graph_layers:
            x = graph_layer(x, self.l1_gate, self.l2_gate, self.e1_gate,
                            self.e2_gate, idx_pairs1, idx_pairs2,
                            encoder_padding_mask, enc_edge_padding_mask,
                            graph_mask, graph_mask_rev, enc_edge_emb)
            enc_states.append(x)

        num_states = len(enc_states)  # number of graph layers + 1

        # N x T x B x C -> N x B x T x C
        stacked = torch.stack(enc_states, dim=0).transpose(1, 2)
        bsz, seq_len = stacked.size(1), stacked.size(2)

        # every (batch, position) pair becomes one LSTM sequence of length N
        stacked = stacked.reshape(num_states, bsz * seq_len, self.embed_dim)

        _, (h_n, c_n) = self.rnn(stacked)

        # NOTE(review): with one bidirectional layer, ``[::2]`` keeps index 0
        # only, so this adds the first entry of the final hidden state to the
        # first entry of the final *cell* state. Confirm that summing the two
        # directions of the hidden state was not what was intended.
        x = h_n[::2] + c_n[::2]

        # (1, B*T, C) -> B x T x C -> T x B x C
        x = x.reshape(bsz, -1, self.embed_dim).transpose(0, 1)

        if self.normalize:
            x = self.graph_layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        The dictionary is updated in place and also returned.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # (key, axis that indexes the batch) for every reordered entry
        batch_axes = (('encoder_out', 1), ('encoder_padding_mask', 0))
        for key, axis in batch_axes:
            value = encoder_out[key]
            if value is not None:
                encoder_out[key] = value.index_select(axis, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        limit = self.max_source_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions())
        return limit

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Args:
            state_dict (dict): checkpoint entries; modified in place
            name (str): key prefix of this module inside *state_dict*

        Returns:
            dict: the same *state_dict* object, after the rewrites
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # discard any stored ``weights`` entry and (re)set the
            # ``_float_tensor`` entry to a fresh one-element tensor
            state_dict.pop(f"{name}.embed_positions.weights", None)
            state_dict[f"{name}.embed_positions._float_tensor"] = \
                torch.FloatTensor(1)

        # each encoder layer upgrades its own entries
        for idx, layer in enumerate(self.layers):
            layer.upgrade_state_dict_named(state_dict, f"{name}.layers.{idx}")

        version_key = f"{name}.version"
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example #28
0
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder made of *args.decoder_layers* layers, each one a
    :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): forwarded to every
            :class:`TransformerDecoderLayer` (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        # project the embeddings only when their width differs from the model
        if embed_dim != input_embed_dim:
            self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])

        self.adaptive_softmax = None

        needs_out_projection = (embed_dim != self.output_embed_dim
                                and not args.tie_adaptive_weights)
        if needs_out_projection:
            self.project_out_dim = Linear(
                embed_dim, self.output_embed_dim, bias=False)
        else:
            self.project_out_dim = None

        if args.adaptive_softmax_cutoff is not None:
            cutoff = options.eval_str_list(
                args.adaptive_softmax_cutoff, type=int)
            tied_inputs = embed_tokens if args.tie_adaptive_weights else None
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                cutoff,
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=tied_inputs,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # separate output projection matrix
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim))
            nn.init.normal_(
                self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

        use_final_norm = (
            args.decoder_normalize_before
            and not getattr(args, 'no_decoder_final_norm', False)
        )
        self.layer_norm = LayerNorm(embed_dim) if use_final_norm else None

        # settings read by the attention-saving hook in extract_features()
        self.save_attn = args.save_attn
        self.save_attn_path = args.save_attn_path

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        features, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state)
        return self.output_layer(features), extra

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with the entries ``attn`` (value returned by
                  the last layer) and ``inner_states`` (input of the first
                  layer followed by the output of every layer)
        """
        # positions are computed from the tokens before any slicing below
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
        else:
            positions = None

        incremental = incremental_state is not None
        if incremental:
            # only the newest step is processed
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        for layer in self.layers:
            if encoder_out is None:
                enc_states = enc_padding_mask = None
            else:
                enc_states = encoder_out['encoder_out']
                enc_padding_mask = encoder_out['encoder_padding_mask']
            # no future mask is passed while decoding incrementally
            future_mask = None if incremental else self.buffered_future_mask(x)
            x, attn = layer(
                x,
                enc_states,
                enc_padding_mask,
                incremental_state,
                self_attn_mask=future_mask,
            )
            inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        # optionally hand the last layer's attention to the save_attn helper
        if self.save_attn == True:
            target_text = self.dictionary.string(prev_output_tokens)
            save_attn(attn, self.save_attn_path, target_text,
                      encoder_out['src_tokens'])

        return x, {'attn': attn, 'inner_states': inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is not None:
            # features are returned as they are in this case
            return features
        # project back to size of vocabulary
        if self.share_input_output_embed:
            weight = self.embed_tokens.weight
        else:
            weight = self.embed_out
        return F.linear(features, weight)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        limit = self.max_target_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions())
        return limit

    def buffered_future_mask(self, tensor):
        """Return a cached square mask sized by the first dimension of *tensor*.

        A mask is created when none is cached or when the cached one lives on
        another device; a cached mask that is too small is resized in place
        and refilled. Masks are produced by ``utils.fill_with_neg_inf``
        followed by ``torch.triu(..., 1)``.
        """
        dim = tensor.size(0)
        mask = getattr(self, '_future_mask', None)
        if mask is None or mask.device != tensor.device:
            mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if mask.size(0) < dim:
            mask = torch.triu(utils.fill_with_neg_inf(mask.resize_(dim, dim)), 1)
        self._future_mask = mask
        return mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Args:
            state_dict (dict): checkpoint entries; modified in place
            name (str): key prefix of this module inside *state_dict*

        Returns:
            dict: the same *state_dict* object, after the rewrites
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # discard any stored ``weights`` entry and (re)set the
            # ``_float_tensor`` entry to a fresh one-element tensor
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)

        # entries stored under ``layer_norms.<index>`` are moved to the
        # attribute name that now holds that layer norm
        renamed_norms = (
            ('0', 'self_attn_layer_norm'),
            ('1', 'encoder_attn_layer_norm'),
            ('2', 'final_layer_norm'),
        )
        for layer_idx, _layer in enumerate(self.layers):
            for old_slot, new_attr in renamed_norms:
                for param in ('weight', 'bias'):
                    old_key = '{}.layers.{}.layer_norms.{}.{}'.format(name, layer_idx, old_slot, param)
                    if old_key not in state_dict:
                        continue
                    new_key = '{}.layers.{}.{}.{}'.format(name, layer_idx, new_attr, param)
                    state_dict[new_key] = state_dict.pop(old_key)

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
class SentTransformerDecoder(nn.Module):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments. This
            decoder takes no dictionary or token embedding: its inputs
            arrive already embedded, as ``prev_output_rep`` in ``forward``
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        final_norm (bool, optional): apply layer norm to the output of the
            final decoder layer (default: True).
    """
    def __init__(self, args, no_encoder_attn=False, final_norm=True):
        super(SentTransformerDecoder, self).__init__()
        self.dropout = args.dropout
        embed_dim = args.decoder_embed_dim
        self.max_target_positions = args.max_target_positions
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions,
            embed_dim,
            padding_idx=0,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args=args, no_encoder_attn=no_encoder_attn)
            for i in range(args.decoder_layers)
        ])
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)

    def forward(self,
                prev_output_tokens,
                prev_output_rep,
                encoder_out=None,
                incremental_state=None,
                **unused):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            prev_output_rep
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, dim)`
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(prev_output_tokens, prev_output_rep,
                                         encoder_out, incremental_state)
        # x = self.output_layer(x)
        return x, extra

    def extract_features(self,
                         prev_output_tokens,
                         prev_output_rep,
                         encoder_out=None,
                         incremental_state=None):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        # embed positions
        # (batch, tgt_len)
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            # incre decoding 时就取最后一个 前面的都已经缓存了
            prev_output_tokens = prev_output_tokens[:, -1:]
            prev_output_rep = prev_output_rep[:, -1, :]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        # x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        x = prev_output_rep
        # if self.project_in_dim is not None:
        #     x = self.project_in_dim(x)

        if positions is not None:
            # print('len x %s |len positions %s'%(x.size(),positions.size()))
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out['encoder_out']
                if encoder_out is not None else None,
                encoder_out['encoder_padding_mask']
                if encoder_out is not None else None,
                incremental_state,
                self_attn_mask=self.buffered_future_mask(x)
                if incremental_state is None else None,
            )
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        return x, {'attn': attn, 'inner_states': inner_states}

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions,
                   self.embed_positions.max_positions())

    def buffered_future_mask(self, tensor):
        """Return a cached square mask sized by the first dimension of *tensor*.

        A mask is created when none is cached or when the cached one lives on
        another device; a cached mask that is too small is resized in place
        and refilled. Masks are produced by ``utils.fill_with_neg_inf``
        followed by ``torch.triu(..., 1)``.

        Args:
            tensor (Tensor): tensor whose ``size(0)`` gives the mask size

        Returns:
            Tensor: the ``[:dim, :dim]`` slice of the cached mask
        """
        dim = tensor.size(0)
        mask = getattr(self, '_future_mask', None)
        if mask is None or mask.device != tensor.device:
            mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if mask.size(0) < dim:
            # grow the cached buffer in place, then rebuild its contents
            mask = torch.triu(
                utils.fill_with_neg_inf(mask.resize_(dim, dim)), 1)
        self._future_mask = mask
        return mask[:dim, :dim]