def __init__(self, xlmr, embdim, hdim, numlayers: int = 1, dropout=0.,
             sentence_encoder: SequenceEncoder = None,
             query_encoder: SequenceEncoder = None,
             feedatt=False, store_attn=True, **kw):
    super(BasicGenModel, self).__init__(**kw)
    # pretrained XLM-R encoder provides the source-side representations
    self.xlmr = xlmr
    encoder_dim = self.xlmr.args.encoder_embed_dim

    # output-side embedding with special handling of rare tokens
    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)
    self.out_emb = decoder_emb

    # decoder RNN; with attention feeding, the previous context vector is appended to its input
    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    self.out_rnn = decoder_rnn

    # pointer-generator output over the concatenated decoder state and context vector
    decoder_out = PtrGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab, str_action_re=None)
    self.out_lin = decoder_out

    self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

    # map the encoder summary to the decoder's initial hidden state, one mapping per layer
    self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
        torch.nn.Linear(encoder_dim, hdim),
        torch.nn.Tanh()
    ) for _ in range(numlayers)])

    self.feedatt = feedatt
    self.store_attn = store_attn
    self.reset_parameters()
def create_model(embdim=768, hdim=100, dropout=0., numlayers: int = 1,
                 sentence_encoder: SentenceEncoder = None,
                 query_encoder: FuncQueryEncoder = None,
                 smoothing: float = 0.):
    bert = BertModel.from_pretrained("bert-base-uncased")

    decoder_emb = torch.nn.Embedding(query_encoder.vocab_tokens.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab_tokens.rare_ids, rare_id=1)

    decoder_rnn = [torch.nn.LSTMCell(embdim, hdim * 2)]
    for i in range(numlayers - 1):
        decoder_rnn.append(torch.nn.LSTMCell(hdim * 2, hdim * 2))
    decoder_rnn = LSTMCellTransition(*decoder_rnn, dropout=dropout)

    decoder_out = PtrGenOutput(hdim * 2 + embdim, sentence_encoder, query_encoder)
    attention = q.Attention(q.MatMulDotAttComp(hdim * 2, embdim))

    model = BertPtrGenModel(bert, decoder_emb, decoder_rnn, decoder_out, attention)
    dec = TFActionSeqDecoder(model, smoothing=smoothing)
    return dec
def test_equivalent_to_qelos_masked(self):
    m = ScaledDotProductAttention(10, attn_dropout=0)
    refm = q.Attention().dot_gen().scale(10 ** 0.5)
    Q = q.var(np.random.random((5, 1, 10)).astype("float32")).v
    K = q.var(np.random.random((5, 6, 10)).astype("float32")).v
    M = q.var(np.asarray([
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1],
    ])).v
    V = q.var(np.random.random((5, 6, 11)).astype("float32")).v
    ctx, atn = m(Q, K, V, attn_mask=(-1 * M + 1).byte().data.unsqueeze(1))
    refatn = refm.attgen(K, Q, mask=M)
    refctx = refm.attcon(V, refatn)
    print(atn)
    print(refatn)
    self.assertTrue(np.allclose(atn.data.numpy(), refatn.data.numpy()))
    self.assertTrue(np.allclose(ctx.data.numpy(), refctx.data.numpy()))
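# A minimal reference sketch (an assumption, not the library implementation) of the masked,
# scaled dot-product attention the two modules above are expected to agree on: positions
# where the mask is 0 are pushed to -inf before the softmax, so they receive zero weight.
import torch

def ref_masked_sdp_attention(Q, K, V, mask, scale):
    scores = torch.bmm(Q, K.transpose(1, 2)) / scale                 # (batsize, qlen, klen)
    scores = scores.masked_fill(mask.unsqueeze(1) == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)                          # rows sum to 1 over unmasked positions
    return torch.bmm(weights, V), weights

Q, K, V = torch.rand(5, 1, 10), torch.rand(5, 6, 10), torch.rand(5, 6, 11)
mask = torch.tensor([[1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1]])
ctx, atn = ref_masked_sdp_attention(Q, K, V, mask, 10 ** 0.5)        # ctx: (5, 1, 11), atn: (5, 1, 6)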
def test_att_splitter(self):
    batsize, seqlen, datadim, critdim, attdim = 5, 3, 4, 3, 7
    crit = torch.FloatTensor(np.random.random((batsize, critdim)))
    data = torch.FloatTensor(np.random.random((batsize, seqlen, datadim)))
    att = q.Attention().forward_gen(datadim, critdim, attdim).split_data()
    attgendata = att.attgen.data_selector(data).numpy()
    attcondata = att.attcon.data_selector(data).numpy()
    recdata = np.concatenate([attgendata, attcondata], axis=2)
    self.assertTrue(np.allclose(data.numpy(), recdata))
def __init__(self, embdim, hdim, numlayers: int = 1, dropout=0., zdim=None,
             sentence_encoder: SequenceEncoder = None,
             query_encoder: SequenceEncoder = None,
             feedatt=False, store_attn=True, minkl=0.05, **kw):
    super(BasicGenModel, self).__init__(**kw)
    self.minkl = minkl
    self.embdim, self.hdim, self.numlayers, self.dropout = embdim, hdim, numlayers, dropout
    self.zdim = embdim if zdim is None else zdim

    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)
    # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
    #     p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
    # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
    self.inp_emb = inpemb

    encoder_dim = hdim
    encoder = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
    # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
    self.inp_enc = encoder

    self.out_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)

    # decoder input: token embedding, latent code z, and optionally the previous attention context
    dec_rnn_in_dim = embdim + self.zdim + (encoder_dim if feedatt else 0)
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    self.out_rnn = decoder_rnn

    # posterior network: encode the output sequence and map it to the parameters of q(z|x, y)
    self.out_emb_vae = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    self.out_enc = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
    # self.out_mu = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
    # self.out_logvar = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
    self.out_mu = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))
    self.out_logvar = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))

    decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    self.out_lin = decoder_out

    self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

    self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
        torch.nn.Linear(encoder_dim, hdim),
        torch.nn.Tanh()
    ) for _ in range(numlayers)])

    self.feedatt = feedatt
    self.nocopy = True
    self.store_attn = store_attn
    self.reset_parameters()
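# Sketch (an assumption: the model's forward pass is not shown here) of how out_mu, out_logvar
# and minkl are typically used in a VAE-style decoder: sample z with the reparameterization
# trick and clamp the per-dimension KL at a floor ("free bits") so it cannot collapse below minkl.
import torch

def sample_z_and_kl(enc_summary, out_mu, out_logvar, minkl=0.05):
    mu, logvar = out_mu(enc_summary), out_logvar(enc_summary)      # (batsize, zdim) each
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)        # z ~ q(z | x, y)
    kl = 0.5 * (mu ** 2 + logvar.exp() - logvar - 1)               # KL(q || N(0, I)) per dimension
    kl = torch.clamp(kl, min=minkl)                                # free-bits floor
    return z, kl.sum(-1)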
def create_model(embdim=100, hdim=100, dropout=0., numlayers: int = 1,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False, nocopy=False):
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)

    encoder_dim = hdim
    encoder = LSTMEncoder(embdim, hdim // 2, numlayers, bidirectional=True, dropout=dropout)
    # encoder = PytorchSeq2SeqWrapper(
    #     torch.nn.LSTM(embdim, hdim, num_layers=numlayers, bidirectional=True, batch_first=True,
    #                   dropout=dropout))

    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)

    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, dropout=dropout)

    # decoder_out = BasicGenOutput(hdim + encoder_dim, query_encoder.vocab)
    decoder_out = PtrGenOutput(hdim + encoder_dim, out_vocab=query_encoder.vocab)
    decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)

    attention = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.0, dropout))
    # attention = q.Attention(q.DotAttComp(), dropout=min(0.0, dropout))

    enctodec = torch.nn.ModuleList([
        torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim), torch.nn.Tanh())
        for _ in range(numlayers)
    ])

    model = BasicGenModel(inpemb, encoder, decoder_emb, decoder_rnn, decoder_out, attention,
                          enc_to_dec=enctodec, feedatt=feedatt, nocopy=nocopy)
    return model
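# The PtrGenOutput internals are not shown above; as a sketch (a common pointer-generator
# formulation, not necessarily this library's implementation), the output distribution mixes
# a generation softmax over the output vocabulary with the encoder attention weights,
# scattered onto output-vocab ids through a copy map, using a learned gate p_gen.
# gen_lin, gate_lin and inp_to_out_map are hypothetical helpers for illustration.
import torch

def ptr_gen_mix(out_feats, attn_weights, inp_to_out_map, gen_lin, gate_lin):
    p_gen = torch.sigmoid(gate_lin(out_feats))                    # (batsize, 1): generate vs. copy
    gen_probs = torch.softmax(gen_lin(out_feats), dim=-1)         # (batsize, outvocsize)
    copy_probs = torch.zeros_like(gen_probs)
    copy_probs.scatter_add_(1, inp_to_out_map, attn_weights)      # map source positions to output ids
    return p_gen * gen_probs + (1 - p_gen) * copy_probs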
def test_forward_attgen(self):
    batsize, seqlen, datadim, critdim, attdim = 5, 3, 4, 3, 7
    crit = Variable(torch.FloatTensor(np.random.random((batsize, critdim))))
    data = Variable(torch.FloatTensor(np.random.random((batsize, seqlen, datadim))))
    att = q.Attention().forward_gen(datadim, critdim, attdim)
    m = att.attgen
    pred = m(data, crit)
    pred = pred.data.numpy()
    self.assertEqual(pred.shape, (batsize, seqlen))
    self.assertTrue(np.allclose(np.sum(pred, axis=1), np.ones((pred.shape[0],))))
def make_decoder(emb, lin, ctxdim=100, embdim=100, dim=100, attmode="bilin", decsplit=False, **kw):
    """ Makes a decoder: an attention cell decoder that accepts VNT. """
    ctxdim = ctxdim if not decsplit else ctxdim // 2
    coreindim = embdim + ctxdim            # if ctx_to_decinp is True else embdim
    coretocritdim = dim if not decsplit else dim // 2
    critdim = dim + embdim                 # if decinp_to_att is True else dim
    if attmode == "bilin":
        attention = q.Attention().bilinear_gen(ctxdim, critdim)
    elif attmode == "fwd":
        attention = q.Attention().forward_gen(ctxdim, critdim)
    else:
        raise q.SumTingWongException()
    attcell = q.AttentionDecoderCell(
        attention=attention,
        embedder=emb,
        core=q.RecStack(
            q.GRUCell(coreindim, dim),
            q.GRUCell(dim, dim),
        ),
        smo=q.Stack(
            q.argsave.spec(mask={"mask"}),
            lin,
            q.argmap.spec(0, mask=["mask"]),
            q.LogSoftmax(),
            q.argmap.spec(0),
        ),
        ctx_to_decinp=True,
        ctx_to_smo=True,
        state_to_smo=True,
        decinp_to_att=True,
        state_split=decsplit)
    return attcell.to_decoder()
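# Sketch (an assumption in plain torch, not the library's exact modules) of the two score
# functions that "bilin" and "fwd" above select between: a bilinear score crit^T W data_i
# versus an additive/feedforward score v^T tanh(W [data_i; crit]).
import torch

def bilinear_scores(data, crit, W):        # data: (B, T, D), crit: (B, C), W: nn.Linear(C, D)
    return torch.bmm(data, W(crit).unsqueeze(2)).squeeze(2)                    # (B, T)

def forward_scores(data, crit, W, v):      # W: nn.Linear(D + C, A), v: nn.Linear(A, 1)
    crit_exp = crit.unsqueeze(1).expand(-1, data.size(1), -1)
    return v(torch.tanh(W(torch.cat([data, crit_exp], -1)))).squeeze(-1)       # (B, T)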
def create_model(embdim=100, hdim=100, dropout=0., numlayers: int = 1,
                 sentence_encoder: SentenceEncoder = None,
                 query_encoder: SentenceEncoder = None,
                 smoothing: float = 0., feedatt=False):
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)

    encoder_dim = hdim
    encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
    # encoder = PytorchSeq2SeqWrapper(
    #     torch.nn.LSTM(embdim, hdim, num_layers=numlayers, bidirectional=True, batch_first=True,
    #                   dropout=dropout))

    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)

    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = [torch.nn.LSTMCell(dec_rnn_in_dim, hdim)]
    for i in range(numlayers - 1):
        decoder_rnn.append(torch.nn.LSTMCell(hdim, hdim))
    decoder_rnn = LSTMCellTransition(*decoder_rnn, dropout=dropout)

    decoder_out = BasicGenOutput(hdim + encoder_dim, sentence_encoder.vocab, query_encoder.vocab)
    attention = q.Attention(q.MatMulDotAttComp(hdim, encoder_dim))
    enctodec = torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim), torch.nn.Tanh())

    model = BasicPtrGenModel(inpemb, encoder, decoder_emb, decoder_rnn, decoder_out, attention,
                             dropout=dropout, enc_to_dec=enctodec, feedatt=feedatt)
    dec = TFTokenSeqDecoder(model, smoothing=smoothing)
    return dec
def test_equivalent_to_qelos(self):
    m = ScaledDotProductAttention(10, attn_dropout=0)
    refm = q.Attention().dot_gen().scale(10 ** 0.5)
    Q = q.var(np.random.random((5, 4, 10)).astype("float32")).v
    K = q.var(np.random.random((5, 6, 10)).astype("float32")).v
    V = q.var(np.random.random((5, 6, 11)).astype("float32")).v
    ctx, atn = m(Q, K, V)
    refatn = refm.attgen(K, Q)
    refctx = refm.attcon(V, refatn)
    print(atn)
    print(refatn)
    self.assertTrue(np.allclose(atn.data.numpy(), refatn.data.numpy()))
    self.assertTrue(np.allclose(ctx.data.numpy(), refctx.data.numpy()))
def test_shapes(self):
    batsize, seqlen, inpdim = 5, 7, 8
    vocsize, embdim, encdim = 20, 9, 10
    ctxtoinitff = q.Forward(inpdim, encdim)
    coreff = q.Forward(encdim, encdim)
    initstategen = q.Lambda(lambda *x, **kw: coreff(ctxtoinitff(x[1][:, -1, :])),
                            register_modules=coreff)
    decoder_cell = q.AttentionDecoderCell(
        attention=q.Attention().forward_gen(inpdim, encdim + embdim, encdim),
        embedder=nn.Embedding(vocsize, embdim),
        core=q.RecStack(
            q.GRUCell(embdim + inpdim, encdim),
            q.GRUCell(encdim, encdim),
            coreff
        ),
        smo=q.Stack(
            q.Forward(encdim + inpdim, encdim),
            q.Forward(encdim, vocsize),
            q.Softmax()
        ),
        init_state_gen=initstategen,
        ctx_to_decinp=True,
        ctx_to_smo=True,
        state_to_smo=True,
        decinp_to_att=True
    )
    decoder = decoder_cell.to_decoder()

    ctx = np.random.random((batsize, seqlen, inpdim))
    ctx = Variable(torch.FloatTensor(ctx))
    ctxmask = np.ones((batsize, seqlen))
    ctxmask[:, -2:] = 0
    ctxmask[[0, 1], -3:] = 0
    ctxmask = Variable(torch.FloatTensor(ctxmask))
    inp = np.random.randint(0, vocsize, (batsize, seqlen))
    inp = Variable(torch.LongTensor(inp))

    decoded = decoder(inp, ctx, ctxmask)
    self.assertEqual((batsize, seqlen, vocsize), decoded.size())
    self.assertTrue(np.allclose(
        np.sum(decoded.data.numpy(), axis=-1),
        np.ones_like(np.sum(decoded.data.numpy(), axis=-1))))
    print(decoded.size())
def test_forward_attgen_w_mask(self):
    batsize, seqlen, datadim, critdim, attdim = 5, 6, 4, 3, 7
    crit = Variable(torch.FloatTensor(np.random.random((batsize, critdim))))
    data = Variable(torch.FloatTensor(np.random.random((batsize, seqlen, datadim))))
    maskstarts = np.random.randint(1, seqlen, (batsize,))
    mask = np.ones((batsize, seqlen), dtype="int32")
    for i in range(batsize):
        mask[i, maskstarts[i]:] = 0
    mask = Variable(torch.FloatTensor(mask * 1.))
    att = q.Attention().forward_gen(datadim, critdim, attdim)
    m = att.attgen
    pred = m(data, crit, mask=mask)
    pred = pred.data.numpy()
    print(pred)
    self.assertEqual(pred.shape, (batsize, seqlen))
    self.assertTrue(np.allclose(mask.data.numpy(), pred > 0))
    self.assertTrue(np.allclose(np.sum(pred, axis=1), np.ones((pred.shape[0],))))
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
    super(OldOriginalMultiHeadAttention, self).__init__()
    self.n_head = n_head
    self.d_k = d_k
    self.d_v = d_v
    # separate projection matrices per head: (n_head, d_model, d_k) and (n_head, d_model, d_v)
    self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
    self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
    self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))
    self.attention = q.Attention().dot_gen().scale(d_model ** 0.5).dropout(dropout)  # ScaledDotProductAttention(d_model)
    self.layer_norm = q.LayerNormalization(d_model)
    self.proj = Linear(n_head * d_v, d_model)
    self.dropout = nn.Dropout(dropout)
    init.xavier_normal(self.w_qs)
    init.xavier_normal(self.w_ks)
    init.xavier_normal(self.w_vs)
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
    super(MultiHeadAttention, self).__init__()
    self.n_head = n_head
    self.d_k = d_k
    self.d_v = d_v
    # fused projection matrices for all heads: (d_model, d_k * n_head) and (d_model, d_v * n_head)
    self.w_qs = nn.Parameter(torch.FloatTensor(d_model, d_k * n_head))  # changes xavier init but probably no difference in the end
    self.w_ks = nn.Parameter(torch.FloatTensor(d_model, d_k * n_head))
    self.w_vs = nn.Parameter(torch.FloatTensor(d_model, d_v * n_head))
    self.attention = q.Attention().dot_gen().scale(d_model ** 0.5).dropout(dropout)  # ScaledDotProductAttention(d_model)
    self.layer_norm = q.LayerNormalization(d_model)
    self.proj = Linear(n_head * d_v, d_model)
    self.dropout = nn.Dropout(dropout)
    init.xavier_normal(self.w_qs)
    init.xavier_normal(self.w_ks)
    init.xavier_normal(self.w_vs)
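# The two constructors above differ only in how the head projections are stored: one weight
# tensor per head, (n_head, d_model, d_k), versus a single fused matrix, (d_model, n_head * d_k).
# A small sketch (plain torch, independent of the surrounding classes) showing that the fused
# projection reshapes into exactly the per-head projections:
import torch

n_head, d_model, d_k, batsize = 4, 16, 8, 3
w_per_head = torch.randn(n_head, d_model, d_k)
w_fused = w_per_head.permute(1, 0, 2).reshape(d_model, n_head * d_k)   # same parameters, fused layout
x = torch.randn(batsize, d_model)

per_head = torch.einsum("bd,hdk->bhk", x, w_per_head)                  # (batsize, n_head, d_k)
fused = (x @ w_fused).view(batsize, n_head, d_k)                       # same result via one fused matmul
assert torch.allclose(per_head, fused, atol=1e-6)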
def __init__(self, embdim, hdim, numlayers: int = 1, dropout=0.,
             sentence_encoder: SequenceEncoder = None,
             query_encoder: SequenceEncoder = None,
             feedatt=False, store_attn=True,
             vib_init=False, vib_enc=False, **kw):
    super(BasicGenModel_VIB, self).__init__(**kw)
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
    #     p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
    # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
    self.inp_emb = inpemb

    encoder_dim = hdim * 2
    encoder = GRUEncoder(embdim, hdim, num_layers=numlayers, dropout=dropout, bidirectional=True)
    # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
    self.inp_enc = encoder

    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    self.out_emb = decoder_emb

    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = GRUTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    self.out_rnn = decoder_rnn

    decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    self.out_lin = decoder_out

    self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

    self.enc_to_dec = torch.nn.ModuleList([
        torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim), torch.nn.Tanh())
        for _ in range(numlayers)
    ])

    self.feedatt = feedatt
    self.nocopy = True
    self.store_attn = store_attn

    # VIBs: variational information bottlenecks on the decoder init states and/or the encoding sequence
    self.vib_init = torch.nn.ModuleList(
        [VIB(encoder_dim) for _ in range(numlayers)]) if vib_init else None
    self.vib_enc = VIB_seq(encoder_dim) if vib_enc else None

    self.reset_parameters()
def main(lr=0.5,
         epochs=30,
         batsize=32,
         embdim=90,
         encdim=90,
         mode="cell",       # "fast" or "cell"
         wreg=0.0001,
         cuda=False,
         gpu=1):
    if cuda:
        torch.cuda.set_device(gpu)
    usecuda = cuda

    vocsize = 50

    # create datasets tensor
    tt.tick("loading data")
    sequences = np.random.randint(0, vocsize, (batsize * 100, 16))
    # wrap in dataset
    dataset = q.TensorDataset(sequences[:batsize * 80], sequences[:batsize * 80])
    validdataset = q.TensorDataset(sequences[batsize * 80:], sequences[batsize * 80:])
    dataloader = DataLoader(dataset=dataset, batch_size=batsize, shuffle=True)
    validdataloader = DataLoader(dataset=validdataset, batch_size=batsize, shuffle=False)
    tt.tock("data loaded")

    # model
    tt.tick("building model")
    embedder = nn.Embedding(vocsize, embdim)
    encoder = q.RecurrentStack(
        embedder,
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer().return_final(),
    )

    if mode == "fast":
        decoder = q.AttentionDecoder(
            attention=q.Attention().forward_gen(encdim, encdim, encdim),
            embedder=embedder,
            core=q.RecurrentStack(q.GRULayer(embdim, encdim)),
            smo=q.Stack(nn.Linear(encdim + encdim, vocsize), q.LogSoftmax()),
            return_att=True)
    else:
        decoder = q.AttentionDecoderCell(
            attention=q.Attention().forward_gen(encdim, encdim + embdim, encdim),
            embedder=embedder,
            core=q.RecStack(
                q.GRUCell(embdim + encdim, encdim, use_cudnn_cell=False,
                          rec_batch_norm=None, activation="crelu")),
            smo=q.Stack(nn.Linear(encdim + encdim, vocsize), q.LogSoftmax()),
            att_after_update=False,
            ctx_to_decinp=True,
            decinp_to_att=True,
            return_att=True,
        ).to_decoder()

    m = EncDec(encoder, decoder, mode=mode)

    losses = q.lossarray(q.SeqNLLLoss(ignore_index=None),
                         q.SeqAccuracy(ignore_index=None),
                         q.SeqElemAccuracy(ignore_index=None))
    validlosses = q.lossarray(q.SeqNLLLoss(ignore_index=None),
                              q.SeqAccuracy(ignore_index=None),
                              q.SeqElemAccuracy(ignore_index=None))

    optimizer = torch.optim.Adadelta(m.parameters(), lr=lr, weight_decay=wreg)
    tt.tock("model built")

    q.train(m).cuda(usecuda).train_on(dataloader, losses)\
        .set_batch_transformer(lambda x, y: (x, y[:, :-1], y[:, 1:]))\
        .valid_on(validdataloader, validlosses)\
        .optimizer(optimizer).clip_grad_norm(2.)\
        .train(epochs)

    testdat = np.random.randint(0, vocsize, (batsize, 20))
    testdata = q.var(torch.from_numpy(testdat)).cuda(usecuda).v
    testdata_out = q.var(torch.from_numpy(testdat)).cuda(usecuda).v
    if mode == "cell" and False:
        inv_idx = torch.arange(testdata.size(1) - 1, -1, -1).long()
        testdata = testdata.index_select(1, inv_idx)
    probs, attw = m(testdata, testdata_out[:, :-1])

    def plot(x):
        sns.heatmap(x)
        plt.show()

    embed()
def __init__(self, embdim, hdim, numlayers: int = 1, dropout=0.,
             sentence_encoder: SequenceEncoder = None,
             query_encoder: SequenceEncoder = None,
             feedatt=False, store_attn=True, **kw):
    super(BasicGenModel, self).__init__(**kw)
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), 300, padding_idx=0)
    inpemb = TokenEmb(inpemb, adapt_dims=(300, embdim),
                      rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)
    _, covered_word_ids = load_pretrained_embeddings(
        inpemb.emb, sentence_encoder.vocab.D,
        p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
    inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
    self.inp_emb = inpemb

    encoder_dim = hdim
    encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
    self.inp_enc = encoder

    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)
    self.out_emb = decoder_emb

    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = [torch.nn.LSTMCell(dec_rnn_in_dim, hdim)]
    for i in range(numlayers - 1):
        decoder_rnn.append(torch.nn.LSTMCell(hdim, hdim))
    decoder_rnn = LSTMCellTransition(*decoder_rnn, dropout=dropout)
    self.out_rnn = decoder_rnn

    decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
    # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    self.out_lin = decoder_out

    self.att = q.Attention(q.MatMulDotAttComp(hdim, encoder_dim))
    self.enc_to_dec = torch.nn.Sequential(
        torch.nn.Linear(encoder_dim, hdim), torch.nn.Tanh())

    self.feedatt = feedatt
    self.nocopy = True
    self.store_attn = store_attn
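# The constructors above store feedatt but do not show the decoding loop. A sketch (an
# assumption about how these submodules are typically composed, not the actual forward) of
# one step of input-feeding attention decoding: the previous context vector is concatenated
# to the token embedding before the RNN, and the new state plus the new context feed the
# output layer. The helper signatures (att, out_rnn, out_lin) are hypothetical.
import torch

def decode_step(tok_emb, prev_ctx, rnn_state, enc_outs, enc_mask,
                att, out_rnn, out_lin, feedatt=True):
    rnn_in = torch.cat([tok_emb, prev_ctx], -1) if feedatt else tok_emb
    new_state, rnn_state = out_rnn(rnn_in, rnn_state)             # one RNN transition
    alphas = att(new_state, enc_outs, mask=enc_mask)              # attention weights over source positions
    ctx = torch.bmm(alphas.unsqueeze(1), enc_outs).squeeze(1)     # weighted sum of encoder states
    logits = out_lin(torch.cat([new_state, ctx], -1))             # scores over the output vocabulary
    return logits, ctx, rnn_state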