Example #1
class BSpanDecoder(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, reader_, params, dropout=0.5, embedding=None):
        """
        Args:
            ntoken: vocab size
            ninp: embedding dimension
            nhead: number of heads
            nhid: hidden layer size
            nlayers: number of layers
            reader_: instance of `Reader`
            dropout: dropout rate
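            params: dict of model hyperparameters
            embedding: optional pre-built embedding layer; a new nn.Embedding is created if None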
        """
        super().__init__()
        from torch.nn import TransformerDecoder, TransformerDecoderLayer
        self.model_type = 'TransformerDecoder'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        decoder_layers = TransformerDecoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, ninp) if embedding is None else embedding
        self.ninp = ninp
        self.linear = nn.Linear(ninp, ntoken)
        self.reader_ = reader_
        self.params = params

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        # self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def train(self, t=True):
        self.transformer_decoder.train(t)

    def _generate_square_subsequent_mask(self, sz):
        """ This makes the model autoregressive.
        When decoding position t, look only at positions 0...t-1 """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt, memory):
        """ Call decoder
        the decoder should be called repeatedly

        Args:
            tgt: input to transformer_decoder, shape: (seq, batch)
            memory: output from the encoder

        Returns:
            output of the linear layer: pre-softmax logits over the vocabulary

        """
        tgt = tgt.long()
        go_tokens = torch.zeros((1, tgt.size(1)), dtype=tgt.dtype, device=tgt.device) + 3  # GO_2 token has index 3

        tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along the sequence length axis


        mask = tgt.eq(0).transpose(0,1)  # 0 corresponds to <pad>
        tgt = self.embedding(tgt) * (self.ninp ** 0.5)  # scale embeddings by sqrt(d_model), as in the standard Transformer
        tgt = self.pos_encoder(tgt)
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0))
        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
        output = self.linear(output)
        return output
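For reference, the following is a minimal usage sketch, assuming BSpanDecoder from Example #1 is in scope together with the module's PositionalEncoding (both classes above depend on it) and the usual torch imports. The hyperparameter values and the reader_=None / params={} arguments are placeholders, not the project's real configuration.

import torch

# hypothetical hyperparameters -- the real values come from the project's config
vocab_size, d_model, n_head, d_ff, n_layers = 100, 64, 4, 256, 2
decoder = BSpanDecoder(vocab_size, d_model, n_head, d_ff, n_layers,
                       reader_=None, params={}, dropout=0.1)

seq_len, batch = 10, 2
tgt = torch.randint(1, vocab_size, (seq_len, batch))   # token ids, shape (seq, batch); id 0 would be treated as <pad>
memory = torch.randn(seq_len, batch, d_model)          # encoder output, shape (src_seq, batch, d_model)
logits = decoder(tgt, memory)                          # shape (seq_len + 1, batch, vocab_size); +1 for the prepended GO_2 token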
Example #2
class ResponseDecoder(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, reader_, params, dropout=0.5, embedding=None):
        """
        Args:
            ntoken: vocab size
            ninp: embedding dimension
            nhead: number of heads
            nhid: hidden layer size
            nlayers: number of layers
            reader_: instance of `Reader`
            dropout: dropout rate
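            params: dict of model hyperparameters (must contain 'bspan_size')
            embedding: optional pre-built embedding layer; a new nn.Embedding is created if None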
        """
        super().__init__()
        from torch.nn import TransformerDecoder, TransformerDecoderLayer
        self.model_type = 'TransformerDecoder'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        decoder_layers = TransformerDecoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, ninp) if embedding is None else embedding
        self.ninp = ninp
        self.linear = nn.Linear(ninp, ntoken)
        self.reader_ = reader_
        self.params = params

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        # self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def train(self, t=True):
        self.transformer_decoder.train(t)

    def _generate_square_subsequent_mask(self, sz, bspan_size):
        # the first positions are never masked: 1 for the degree vector, 1 for the <go> token
        # and `bspan_size` for the belief span; `sz + 1` accounts for the degree timestep,
        # which is prepended to `tgt` after this mask is built
        mask = (torch.triu(torch.ones(sz+1, sz+1), diagonal=-(bspan_size+1)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt, memory, bspan, degree):
        """ Call decoder

        Args:
            tgt: input to transformer_decoder, shape: (seq_len, batch)
            memory: output from the encoder
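            bspan: belief span tokens prepended to the target, shape: (bspan_size, batch)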
            degree: degree is the 'output from database', shape: (batch, cfg.degree_size)

        Returns:
            output of the linear layer: pre-softmax logits over the vocabulary

        """

        go_tokens = torch.ones((1, tgt.size(1)), dtype=tgt.dtype, device=tgt.device)  # GO token has index 1
        degree_reshaped = torch.zeros((1, tgt.size(1), cfg.embedding_size), dtype=torch.float32, device=tgt.device)

        tgt = torch.cat([bspan, go_tokens, tgt], dim=0)  # concat bspan, GO token and target tokens along the sequence length axis
        # TODO pad `tgt` but also think of `degree` which is added later


        # prepend a row of ones so the degree timestep (added below) is never treated as padding; 0 corresponds to <pad>
        mask = torch.cat([torch.ones((1, tgt.size(1)), dtype=torch.int64, device=tgt.device), tgt]).eq(0).transpose(0, 1)
        # TODO dimensions are wrong
        # TODO also, the final tgt dimension should be cfg.max_ts (128); however, it is 128 before bspan is concatenated with it
        # mask = torch.cat([torch.ones((mask.size(0), 1)).bool(), mask])
        tgt = self.embedding(tgt) * (self.ninp ** 0.5)  # scale embeddings by sqrt(d_model), as in the standard Transformer
        tgt = self.pos_encoder(tgt)

        #    eg. [cheap restaurant EOS_Z1 EOS_Z2 PAD2 .... PAD2 01000 GO1 mask mask mask ..... ]
        #    ... [            ...          bspan    ... padding degree go     ....     masking ]


        # A BIG TODO: the size of `tgt` has to take the size of `bspan` (+1+1 for degree, go)  into account
        bspan_size = self.params['bspan_size']  # always the same
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0), bspan_size)

        # tgt.size(1) is the batch size (nn.Transformer expects (seq, batch, feature) ordering)
        degree_reshaped[0, :, :cfg.degree_size] = degree.transpose(0,1)  # add 1 more timestep (the first one as one-hot degree)
        tgt = torch.cat([degree_reshaped, tgt], dim=0)  # concat along the sequence length axis

        # THE ERROR: src_len is 150 and key_padding_mask.size(1) is 149
        # BOTH are wrong and should be 128 (currently max_len)
        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
        output = self.linear(output)
        return output
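To make the shifted causal mask in _generate_square_subsequent_mask easier to inspect, here is a small self-contained sketch that rebuilds the same mask expression for toy sizes and prints it; sz and bspan_size below are made-up values, not the project's configuration.

import torch

# toy sizes: a 3-token belief span inside a length-8 target (before the degree timestep is prepended)
sz, bspan_size = 8, 3

# same expression as ResponseDecoder._generate_square_subsequent_mask
mask = (torch.triu(torch.ones(sz + 1, sz + 1), diagonal=-(bspan_size + 1)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

# row i is the query position, column j the key position: 0.0 = visible, -inf = masked
# row 0 (the degree timestep) already sees the degree, the whole belief span and the <go> token,
# and every row i can see up to column i + bspan_size + 1, i.e. bspan_size + 1 positions
# past the usual causal limit
print(mask)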