import math

import torch
import torch.nn as nn

# `PositionalEncoding` and the global config object `cfg` are assumed to be defined
# or imported elsewhere in this module.


class BSpanDecoder(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, reader_, params, dropout=0.5, embedding=None):
        """
        Args:
            ntoken: vocab size
            ninp: embedding dimension
            nhead: number of heads
            nhid: hidden layer size
            nlayers: number of layers
            reader_: instance of `Reader`
            params: dict with additional model parameters
            dropout: dropout rate
            embedding: optional shared embedding layer; a new one is created if None
        """
        super().__init__()
        from torch.nn import TransformerDecoder, TransformerDecoderLayer
        self.model_type = 'TransformerDecoder'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        decoder_layers = TransformerDecoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, ninp) if embedding is None else embedding
        self.ninp = ninp
        self.linear = nn.Linear(ninp, ntoken)
        self.reader_ = reader_
        self.params = params
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        # self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def train(self, t=True):
        self.transformer_decoder.train(t)

    def _generate_square_subsequent_mask(self, sz):
        """
        Build the causal mask that makes the model autoregressive:
        when decoding position t, look only at positions 0...t-1.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt, memory):
        """
        Call the decoder; during generation it should be called repeatedly.

        Args:
            tgt: input to transformer_decoder, shape: (seq_len, batch)
            memory: output from the encoder

        Returns:
            output from the linear layer over the vocabulary, pre-softmax
        """
        tgt = tgt.long()
        # GO_2 token has index 3; prepend it along the sequence-length axis
        go_tokens = torch.zeros((1, tgt.size(1)), dtype=tgt.dtype, device=tgt.device) + 3
        tgt = torch.cat([go_tokens, tgt], dim=0)
        # key padding mask: 0 corresponds to <pad>
        mask = tgt.eq(0).transpose(0, 1)
        # scale embeddings by sqrt(d_model), as in the standard Transformer
        tgt = self.embedding(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0))
        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
        output = self.linear(output)
        return output
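# Illustration only (not used by the model): a helper of our own to inspect the causal
# mask that `BSpanDecoder._generate_square_subsequent_mask` produces. For sz=4 the
# result is a lower-triangular matrix of zeros with -inf above the diagonal, i.e.
# position t may attend only to positions <= t.
def _show_causal_mask(sz=4):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
    # e.g. for sz=4:
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])
    return mask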
class ResponseDecoder(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, reader_, params, dropout=0.5, embedding=None):
        """
        Args:
            ntoken: vocab size
            ninp: embedding dimension
            nhead: number of heads
            nhid: hidden layer size
            nlayers: number of layers
            reader_: instance of `Reader`
            params: dict with additional model parameters (e.g. 'bspan_size')
            dropout: dropout rate
            embedding: optional shared embedding layer; a new one is created if None
        """
        super().__init__()
        from torch.nn import TransformerDecoder, TransformerDecoderLayer
        self.model_type = 'TransformerDecoder'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        decoder_layers = TransformerDecoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, ninp) if embedding is None else embedding
        self.ninp = ninp
        self.linear = nn.Linear(ninp, ntoken)
        self.reader_ = reader_
        self.params = params
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        # self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def train(self, t=True):
        self.transformer_decoder.train(t)

    def _generate_square_subsequent_mask(self, sz, bspan_size):
        """
        Causal mask that leaves the first positions unmasked
        (1 for the degree vector, 1 for the <go> token and `bspan_size` for the bspan).
        """
        mask = (torch.triu(torch.ones(sz + 1, sz + 1), diagonal=-(bspan_size + 1)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, tgt, memory, bspan, degree):
        """
        Call the decoder.

        Args:
            tgt: input to transformer_decoder, shape: (seq_len, batch)
            memory: output from the encoder
            bspan: belief span tokens, shape: (bspan_size, batch)
            degree: the 'output from the database', shape: (batch, cfg.degree_size)

        Returns:
            output from the linear layer over the vocabulary, pre-softmax
        """
        # GO token has index 1
        go_tokens = torch.ones((1, tgt.size(1)), dtype=tgt.dtype, device=tgt.device)
        degree_reshaped = torch.zeros((1, tgt.size(1), cfg.embedding_size), dtype=torch.float32, device=tgt.device)
        # concatenate bspan, the GO token and the response tokens along the sequence-length axis
        tgt = torch.cat([bspan, go_tokens, tgt], dim=0)
        # TODO: pad `tgt`, but also account for `degree`, which is prepended later
        # key padding mask: 0 corresponds to <pad>; the extra row of ones covers the degree position
        mask = torch.cat([torch.ones((1, tgt.size(1)), dtype=torch.int64, device=tgt.device), tgt]).eq(0).transpose(0, 1)
        # TODO: the dimensions are wrong
        # TODO: the final `tgt` length should be cfg.max_ts (128); right now it is 128
        #       *before* the bspan is concatenated with it
        # scale embeddings by sqrt(d_model), as in the standard Transformer
        tgt = self.embedding(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        # The decoder input is laid out as, e.g.:
        #   [cheap restaurant EOS_Z1 EOS_Z2 PAD2 ... PAD2 | 01000  | GO1 | mask mask mask ...]
        #   [-------------- bspan + padding ------------- | degree | go  | ---- masking -----]
        # BIG TODO: the size of `tgt` has to take the size of `bspan` (+1 for degree, +1 for go) into account
        bspan_size = self.params['bspan_size']  # always the same
        # tgt.size(1) is the batch size (dim=1 because nn.Transformer expects (seq, batch, emb))
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0), bspan_size)
        # add one more timestep: the first position carries the one-hot degree vector
        degree_reshaped[0, :, :cfg.degree_size] = degree.transpose(0, 1)
        tgt = torch.cat([degree_reshaped, tgt], dim=0)  # concat along the sequence-length axis
        # THE ERROR: src_len is 150 and key_padding_mask.size(1) is 149;
        # both are wrong and should be 128 (currently max_len)
        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
        output = self.linear(output)
        return output
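# Illustration only (not used by the model): a helper of our own to inspect the mask
# built by `ResponseDecoder._generate_square_subsequent_mask`. Shifting the triu
# boundary with diagonal=-(bspan_size + 1) moves the -inf region bspan_size + 1
# columns to the right, so the degree/bspan/<go> prefix that is prepended to the
# response stays visible from every decoding position. sz and bspan_size below are
# toy values.
def _show_prefix_mask(sz=4, bspan_size=1):
    mask = (torch.triu(torch.ones(sz + 1, sz + 1), diagonal=-(bspan_size + 1)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
    # e.g. for sz=4, bspan_size=1:
    # tensor([[0., 0., 0., -inf, -inf],
    #         [0., 0., 0.,   0., -inf],
    #         [0., 0., 0.,   0.,   0.],
    #         [0., 0., 0.,   0.,   0.],
    #         [0., 0., 0.,   0.,   0.]])
    return mask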