Code example #1
    def __init__(self, xlmr, embdim, hdim, numlayers:int=1, dropout=0.,
                 sentence_encoder:SequenceEncoder=None,
                 query_encoder:SequenceEncoder=None,
                 feedatt=False, store_attn=True, **kw):
        super(BasicGenModel, self).__init__(**kw)

        self.xlmr = xlmr
        encoder_dim = self.xlmr.args.encoder_embed_dim

        decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
        decoder_emb = TokenEmb(decoder_emb, rare_token_ids=query_encoder.vocab.rare_ids, rare_id=1)
        self.out_emb = decoder_emb

        # input feeding: with feedatt=True, the decoder RNN input is widened by encoder_dim so the attention context can be fed in alongside the embedded token
        dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
        decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
        self.out_rnn = decoder_rnn

        decoder_out = PtrGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
        decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab, str_action_re=None)
        self.out_lin = decoder_out

        self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

        # one Linear+Tanh bridge per decoder layer, projecting encoder states (encoder_dim) to the decoder's hidden size (hdim)
        self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, hdim),
            torch.nn.Tanh()
        ) for _ in range(numlayers)])

        self.feedatt = feedatt

        self.store_attn = store_attn

        self.reset_parameters()
Code example #2
File: geoquery_bert.py  Project: lukovnikov/funcparse
def create_model(
    embdim=768,
    hdim=100,
    dropout=0.,
    numlayers: int = 1,
    sentence_encoder: SentenceEncoder = None,
    query_encoder: FuncQueryEncoder = None,
    smoothing: float = 0.,
):
    bert = BertModel.from_pretrained("bert-base-uncased")
    decoder_emb = torch.nn.Embedding(
        query_encoder.vocab_tokens.number_of_ids(), embdim, padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb,
                           rare_token_ids=query_encoder.vocab_tokens.rare_ids,
                           rare_id=1)
    decoder_rnn = [torch.nn.LSTMCell(embdim, hdim * 2)]
    for i in range(numlayers - 1):
        decoder_rnn.append(torch.nn.LSTMCell(hdim * 2, hdim * 2))
    decoder_rnn = LSTMCellTransition(*decoder_rnn, dropout=dropout)
    decoder_out = PtrGenOutput(hdim * 2 + embdim, sentence_encoder,
                               query_encoder)
    attention = q.Attention(q.MatMulDotAttComp(hdim * 2, embdim))
    model = BertPtrGenModel(bert, decoder_emb, decoder_rnn, decoder_out,
                            attention)
    dec = TFActionSeqDecoder(model, smoothing=smoothing)
    return dec
Code example #3
File: test_attention.py  Project: nilesh-c/qelos
    def test_equivalent_to_qelos_masked(self):
        m = ScaledDotProductAttention(10, attn_dropout=0)
        refm = q.Attention().dot_gen().scale(10**0.5)

        Q = q.var(np.random.random((5, 1, 10)).astype("float32")).v
        K = q.var(np.random.random((5, 6, 10)).astype("float32")).v
        M = q.var(
            np.asarray([
                [1, 0, 0, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 1, 0, 0],
                [1, 1, 1, 1, 1, 1],
            ])).v
        V = q.var(np.random.random((5, 6, 11)).astype("float32")).v

        # ScaledDotProductAttention's attn_mask marks positions to drop, hence the inversion of M
        ctx, atn = m(Q, K, V, attn_mask=(-1 * M + 1).byte().data.unsqueeze(1))
        refatn = refm.attgen(K, Q, mask=M)
        refctx = refm.attcon(V, refatn)

        print(atn)
        print(refatn)

        self.assertTrue(np.allclose(atn.data.numpy(), refatn.data.numpy()))
        self.assertTrue(np.allclose(ctx.data.numpy(), refctx.data.numpy()))
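This test (and its unmasked variant in code example #10) exercises qelos' two-step attention API: attgen produces normalized weights over the keys, and attcon reduces the values with those weights. For reference, a minimal NumPy sketch of the masked scaled-dot-product computation being checked, under the test's conventions; this re-implements the math, not the qelos or ScaledDotProductAttention APIs:

import numpy as np

def attgen(K, Q, scale, mask=None):
    # K: (B, Tk, D), Q: (B, Tq, D), mask: (B, Tk) with 1 = attend; returns (B, Tq, Tk)
    scores = Q @ K.transpose(0, 2, 1) / scale
    if mask is not None:
        scores = np.where(mask[:, None, :] == 1, scores, -np.inf)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attcon(V, weights):
    # V: (B, Tk, Dv), weights: (B, Tq, Tk); returns the context (B, Tq, Dv)
    return weights @ V

Q = np.random.random((5, 1, 10)).astype("float32")
K = np.random.random((5, 6, 10)).astype("float32")
V = np.random.random((5, 6, 11)).astype("float32")
M = np.asarray([[1, 0, 0, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 1, 0, 0],
                [1, 1, 1, 1, 1, 1]])
ctx = attcon(V, attgen(K, Q, 10 ** 0.5, mask=M))  # (5, 1, 11)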
Code example #4
File: test_attention.py  Project: nilesh-c/qelos
    def test_att_splitter(self):
        batsize, seqlen, datadim, critdim, attdim = 5, 3, 4, 3, 7
        crit = torch.FloatTensor(np.random.random((batsize, critdim)))
        data = torch.FloatTensor(np.random.random((batsize, seqlen, datadim)))
        att = q.Attention().forward_gen(datadim, critdim, attdim).split_data()
        attgendata = att.attgen.data_selector(data).numpy()
        attcondata = att.attcon.data_selector(data).numpy()
        recdata = np.concatenate([attgendata, attcondata], axis=2)
        self.assertTrue(np.allclose(data.numpy(), recdata))
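The test above relies on split_data dividing the data tensor's feature axis between the weight-generation (attgen) and context (attcon) halves. A minimal sketch of that convention, assuming an even split; this re-implements the idea, not the qelos data_selector objects:

import numpy as np

data = np.random.random((5, 3, 4)).astype("float32")  # (batch, seq, feat)
half = data.shape[-1] // 2
attgen_data = data[..., :half]   # features used to score positions
attcon_data = data[..., half:]   # features summarized into the context
assert np.allclose(data, np.concatenate([attgen_data, attcon_data], axis=2))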
Code example #5
    def __init__(self, embdim, hdim, numlayers:int=1, dropout=0., zdim=None,
                 sentence_encoder:SequenceEncoder=None,
                 query_encoder:SequenceEncoder=None,
                 feedatt=False, store_attn=True,
                 minkl=0.05, **kw):
        super(BasicGenModel, self).__init__(**kw)

        self.minkl = minkl

        self.embdim, self.hdim, self.numlayers, self.dropout = embdim, hdim, numlayers, dropout
        self.zdim = embdim if zdim is None else zdim

        inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
        inpemb = TokenEmb(inpemb, rare_token_ids=sentence_encoder.vocab.rare_ids, rare_id=1)
        # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
        #                                                  p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
        # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
        self.inp_emb = inpemb

        encoder_dim = hdim
        encoder = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
        # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
        self.inp_enc = encoder

        self.out_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)

        dec_rnn_in_dim = embdim + self.zdim + (encoder_dim if feedatt else 0)
        decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
        self.out_rnn = decoder_rnn
        self.out_emb_vae = torch.nn.Embedding(query_encoder.vocab.number_of_ids(), embdim, padding_idx=0)
        self.out_enc = LSTMEncoder(embdim, hdim // 2, num_layers=numlayers, dropout=dropout, bidirectional=True)
        # self.out_mu = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
        # self.out_logvar = torch.nn.Sequential(torch.nn.Linear(embdim, hdim), torch.nn.Tanh(), torch.nn.Linear(hdim, self.zdim))
        self.out_mu = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))
        self.out_logvar = torch.nn.Sequential(torch.nn.Linear(hdim, self.zdim))

        decoder_out = BasicGenOutput(hdim + encoder_dim, vocab=query_encoder.vocab)
        # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
        self.out_lin = decoder_out

        self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim), dropout=min(0.1, dropout))

        self.enc_to_dec = torch.nn.ModuleList([torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, hdim),
            torch.nn.Tanh()
        ) for _ in range(numlayers)])

        self.feedatt = feedatt
        self.nocopy = True

        self.store_attn = store_attn

        self.reset_parameters()
Code example #6
def create_model(embdim=100,
                 hdim=100,
                 dropout=0.,
                 numlayers: int = 1,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False,
                 nocopy=False):
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(),
                                embdim,
                                padding_idx=0)
    inpemb = TokenEmb(inpemb,
                      rare_token_ids=sentence_encoder.vocab.rare_ids,
                      rare_id=1)
    encoder_dim = hdim
    encoder = LSTMEncoder(embdim,
                          hdim // 2,
                          numlayers,
                          bidirectional=True,
                          dropout=dropout)
    # encoder = PytorchSeq2SeqWrapper(
    #     torch.nn.LSTM(embdim, hdim, num_layers=numlayers, bidirectional=True, batch_first=True,
    #                   dropout=dropout))
    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(),
                                     embdim,
                                     padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb,
                           rare_token_ids=query_encoder.vocab.rare_ids,
                           rare_id=1)
    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = LSTMTransition(dec_rnn_in_dim, hdim, numlayers, dropout=dropout)
    # decoder_out = BasicGenOutput(hdim + encoder_dim, query_encoder.vocab)
    decoder_out = PtrGenOutput(hdim + encoder_dim,
                               out_vocab=query_encoder.vocab)
    decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
    attention = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim),
                            dropout=min(0.0, dropout))  # min(0.0, dropout) is always 0., so attention dropout is effectively disabled here
    # attention = q.Attention(q.DotAttComp(), dropout=min(0.0, dropout))
    enctodec = torch.nn.ModuleList([
        torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim),
                            torch.nn.Tanh()) for _ in range(numlayers)
    ])
    model = BasicGenModel(inpemb,
                          encoder,
                          decoder_emb,
                          decoder_rnn,
                          decoder_out,
                          attention,
                          enc_to_dec=enctodec,
                          feedatt=feedatt,
                          nocopy=nocopy)
    return model
Code example #7
File: test_attention.py  Project: nilesh-c/qelos
    def test_forward_attgen(self):
        batsize, seqlen, datadim, critdim, attdim = 5, 3, 4, 3, 7
        crit = Variable(torch.FloatTensor(np.random.random(
            (batsize, critdim))))
        data = Variable(
            torch.FloatTensor(np.random.random((batsize, seqlen, datadim))))
        att = q.Attention().forward_gen(datadim, critdim, attdim)
        m = att.attgen
        pred = m(data, crit)
        pred = pred.data.numpy()
        self.assertEqual(pred.shape, (batsize, seqlen))
        self.assertTrue(
            np.allclose(np.sum(pred, axis=1), np.ones((pred.shape[0], ))))
Code example #8
File: seq2seq.py  Project: nilesh-c/qelos
def make_decoder(emb, lin, ctxdim=100, embdim=100, dim=100,
                 attmode="bilin", decsplit=False, **kw):
    """ makes decoder
    # attention cell decoder that accepts VNT !!!
    """
    ctxdim = ctxdim if not decsplit else ctxdim // 2
    coreindim = embdim + ctxdim     # if ctx_to_decinp is True else embdim

    coretocritdim = dim if not decsplit else dim // 2
    critdim = dim + embdim          # if decinp_to_att is True else dim

    if attmode == "bilin":
        attention = q.Attention().bilinear_gen(ctxdim, critdim)
    elif attmode == "fwd":
        attention = q.Attention().forward_gen(ctxdim, critdim)
    else:
        raise q.SumTingWongException()

    attcell = q.AttentionDecoderCell(attention=attention,
                                     embedder=emb,
                                     core=q.RecStack(
                                         q.GRUCell(coreindim, dim),
                                         q.GRUCell(dim, dim),
                                     ),
                                     smo=q.Stack(
                                         q.argsave.spec(mask={"mask"}),
                                         lin,
                                         q.argmap.spec(0, mask=["mask"]),
                                         q.LogSoftmax(),
                                         q.argmap.spec(0),
                                     ),
                                     ctx_to_decinp=True,
                                     ctx_to_smo=True,
                                     state_to_smo=True,
                                     decinp_to_att=True,
                                     state_split=decsplit)
    return attcell.to_decoder()
Code example #9
def create_model(embdim=100,
                 hdim=100,
                 dropout=0.,
                 numlayers: int = 1,
                 sentence_encoder: SentenceEncoder = None,
                 query_encoder: SentenceEncoder = None,
                 smoothing: float = 0.,
                 feedatt=False):
    inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(),
                                embdim,
                                padding_idx=0)
    inpemb = TokenEmb(inpemb,
                      rare_token_ids=sentence_encoder.vocab.rare_ids,
                      rare_id=1)
    encoder_dim = hdim
    encoder = q.LSTMEncoder(embdim,
                            *([encoder_dim // 2] * numlayers),
                            bidir=True,
                            dropout_in=dropout)
    # encoder = PytorchSeq2SeqWrapper(
    #     torch.nn.LSTM(embdim, hdim, num_layers=numlayers, bidirectional=True, batch_first=True,
    #                   dropout=dropout))
    decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(),
                                     embdim,
                                     padding_idx=0)
    decoder_emb = TokenEmb(decoder_emb,
                           rare_token_ids=query_encoder.vocab.rare_ids,
                           rare_id=1)
    dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
    decoder_rnn = [torch.nn.LSTMCell(dec_rnn_in_dim, hdim)]
    for i in range(numlayers - 1):
        decoder_rnn.append(torch.nn.LSTMCell(hdim, hdim))
    decoder_rnn = LSTMCellTransition(*decoder_rnn, dropout=dropout)
    decoder_out = BasicGenOutput(hdim + encoder_dim, sentence_encoder.vocab,
                                 query_encoder.vocab)
    attention = q.Attention(q.MatMulDotAttComp(hdim, encoder_dim))
    enctodec = torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim),
                                   torch.nn.Tanh())
    model = BasicPtrGenModel(inpemb,
                             encoder,
                             decoder_emb,
                             decoder_rnn,
                             decoder_out,
                             attention,
                             dropout=dropout,
                             enc_to_dec=enctodec,
                             feedatt=feedatt)
    dec = TFTokenSeqDecoder(model, smoothing=smoothing)
    return dec
Code example #10
File: test_attention.py  Project: nilesh-c/qelos
    def test_equivalent_to_qelos(self):
        m = ScaledDotProductAttention(10, attn_dropout=0)
        refm = q.Attention().dot_gen().scale(10**0.5)

        Q = q.var(np.random.random((5, 4, 10)).astype("float32")).v
        K = q.var(np.random.random((5, 6, 10)).astype("float32")).v
        V = q.var(np.random.random((5, 6, 11)).astype("float32")).v

        ctx, atn = m(Q, K, V)
        refatn = refm.attgen(K, Q)
        refctx = refm.attcon(V, refatn)

        print(atn)
        print(refatn)

        self.assertTrue(np.allclose(atn.data.numpy(), refatn.data.numpy()))
        self.assertTrue(np.allclose(ctx.data.numpy(), refctx.data.numpy()))
Code example #11
File: test_decoder.py  Project: nilesh-c/qelos
    def test_shapes(self):
        batsize, seqlen, inpdim = 5, 7, 8
        vocsize, embdim, encdim = 20, 9, 10
        ctxtoinitff = q.Forward(inpdim, encdim)
        coreff = q.Forward(encdim, encdim)
        # initial decoder state: project the last timestep of the context through ctxtoinitff and coreff
        initstategen = q.Lambda(lambda *x, **kw: coreff(ctxtoinitff(x[1][:, -1, :])), register_modules=coreff)

        decoder_cell = q.AttentionDecoderCell(
            attention=q.Attention().forward_gen(inpdim, encdim+embdim, encdim),
            embedder=nn.Embedding(vocsize, embdim),
            core=q.RecStack(
                q.GRUCell(embdim + inpdim, encdim),
                q.GRUCell(encdim, encdim),
                coreff
            ),
            smo=q.Stack(
                q.Forward(encdim+inpdim, encdim),
                q.Forward(encdim, vocsize),
                q.Softmax()
            ),
            init_state_gen=initstategen,
            ctx_to_decinp=True,
            ctx_to_smo=True,
            state_to_smo=True,
            decinp_to_att=True
        )
        decoder = decoder_cell.to_decoder()

        ctx = np.random.random((batsize, seqlen, inpdim))
        ctx = Variable(torch.FloatTensor(ctx))
        ctxmask = np.ones((batsize, seqlen))
        ctxmask[:, -2:] = 0
        ctxmask[[0, 1], -3:] = 0
        ctxmask = Variable(torch.FloatTensor(ctxmask))
        inp = np.random.randint(0, vocsize, (batsize, seqlen))
        inp = Variable(torch.LongTensor(inp))

        decoded = decoder(inp, ctx, ctxmask)

        self.assertEqual((batsize, seqlen, vocsize), decoded.size())
        self.assertTrue(np.allclose(
            np.sum(decoded.data.numpy(), axis=-1),
            np.ones_like(np.sum(decoded.data.numpy(), axis=-1))))
        print(decoded.size())
Code example #12
File: test_attention.py  Project: nilesh-c/qelos
    def test_forward_attgen_w_mask(self):
        batsize, seqlen, datadim, critdim, attdim = 5, 6, 4, 3, 7
        crit = Variable(torch.FloatTensor(np.random.random(
            (batsize, critdim))))
        data = Variable(
            torch.FloatTensor(np.random.random((batsize, seqlen, datadim))))
        maskstarts = np.random.randint(1, seqlen, (batsize, ))
        mask = np.ones((batsize, seqlen), dtype="int32")
        for i in range(batsize):
            mask[i, maskstarts[i]:] = 0
        mask = Variable(torch.FloatTensor(mask * 1.))
        att = q.Attention().forward_gen(datadim, critdim, attdim)
        m = att.attgen
        pred = m(data, crit, mask=mask)
        pred = pred.data.numpy()
        print(pred)
        self.assertEqual(pred.shape, (batsize, seqlen))
        self.assertTrue(np.allclose(mask.data.numpy(), pred > 0))
        self.assertTrue(
            np.allclose(np.sum(pred, axis=1), np.ones((pred.shape[0], ))))
Code example #13
File: aiayn.py  Project: nilesh-c/qelos
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(OldOriginalMultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))

        self.attention = q.Attention().dot_gen().scale(d_model**0.5).dropout(
            dropout)  #ScaledDotProductAttention(d_model)
        self.layer_norm = q.LayerNormalization(d_model)
        self.proj = Linear(n_head * d_v, d_model)

        self.dropout = nn.Dropout(dropout)

        init.xavier_normal(self.w_qs)
        init.xavier_normal(self.w_ks)
        init.xavier_normal(self.w_vs)
Code example #14
File: aiayn.py  Project: nilesh-c/qelos
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Parameter(
            torch.FloatTensor(d_model, d_k * n_head)
        )  # changes xavier init but probably no difference in the end
        self.w_ks = nn.Parameter(torch.FloatTensor(d_model, d_k * n_head))
        self.w_vs = nn.Parameter(torch.FloatTensor(d_model, d_v * n_head))

        self.attention = q.Attention().dot_gen().scale(d_model**0.5).dropout(
            dropout)  #ScaledDotProductAttention(d_model)
        self.layer_norm = q.LayerNormalization(d_model)
        self.proj = Linear(n_head * d_v, d_model)

        self.dropout = nn.Dropout(dropout)

        init.xavier_normal(self.w_qs)
        init.xavier_normal(self.w_ks)
        init.xavier_normal(self.w_vs)
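The only difference from code example #13 is the weight layout: a fused (d_model, d_k * n_head) matrix yields the same per-head projections as n_head separate (d_model, d_k) matrices after a reshape, which is why the inline comment expects no difference beyond Xavier-initialization statistics. A minimal sketch of that equivalence with hypothetical dimensions (shape bookkeeping only, independent of the qelos module):

import torch

B, T, n_head, d_model, d_k = 2, 5, 4, 16, 8
x = torch.randn(B, T, d_model)
w_fused = torch.randn(d_model, d_k * n_head)

# fused projection, then split the last axis into heads: (B, T, n_head, d_k)
per_head = (x @ w_fused).view(B, T, n_head, d_k)

# the same projection done with per-head (n_head, d_model, d_k) weights
w_heads = w_fused.view(d_model, n_head, d_k).permute(1, 0, 2)
ref = torch.einsum("btd,hdk->bthk", x, w_heads)
assert torch.allclose(per_head, ref, atol=1e-5)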
Code example #15
File: lcquad_vib.py  Project: saist1993/parseq
    def __init__(self,
                 embdim,
                 hdim,
                 numlayers: int = 1,
                 dropout=0.,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False,
                 store_attn=True,
                 vib_init=False,
                 vib_enc=False,
                 **kw):
        super(BasicGenModel_VIB, self).__init__(**kw)

        inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(),
                                    embdim,
                                    padding_idx=0)

        # _, covered_word_ids = load_pretrained_embeddings(inpemb.emb, sentence_encoder.vocab.D,
        #                                                  p="../../data/glove/glove300uncased")  # load glove embeddings where possible into the inner embedding class
        # inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
        self.inp_emb = inpemb

        encoder_dim = hdim * 2
        encoder = GRUEncoder(embdim,
                             hdim,
                             num_layers=numlayers,
                             dropout=dropout,
                             bidirectional=True)
        # encoder = q.LSTMEncoder(embdim, *([encoder_dim // 2] * numlayers), bidir=True, dropout_in=dropout)
        self.inp_enc = encoder

        decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(),
                                         embdim,
                                         padding_idx=0)
        self.out_emb = decoder_emb

        dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
        decoder_rnn = GRUTransition(dec_rnn_in_dim,
                                    hdim,
                                    numlayers,
                                    dropout=dropout)
        self.out_rnn = decoder_rnn

        decoder_out = BasicGenOutput(hdim + encoder_dim,
                                     vocab=query_encoder.vocab)
        # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
        self.out_lin = decoder_out

        self.att = q.Attention(q.SimpleFwdAttComp(hdim, encoder_dim, hdim),
                               dropout=min(0.1, dropout))

        self.enc_to_dec = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Linear(encoder_dim, hdim),
                                torch.nn.Tanh()) for _ in range(numlayers)
        ])

        self.feedatt = feedatt
        self.nocopy = True

        self.store_attn = store_attn

        # VIBs
        self.vib_init = torch.nn.ModuleList(
            [VIB(encoder_dim) for _ in range(numlayers)]) if vib_init else None
        self.vib_enc = VIB_seq(encoder_dim) if vib_enc else None

        self.reset_parameters()
Code example #16
File: encdec_toy.py  Project: nilesh-c/qelos
def main(
    lr=0.5,
    epochs=30,
    batsize=32,
    embdim=90,
    encdim=90,
    mode="cell",  # "fast" or "cell"
    wreg=0.0001,
    cuda=False,
    gpu=1,
):
    if cuda:
        torch.cuda.set_device(gpu)
    usecuda = cuda
    vocsize = 50
    # create datasets tensor
    tt.tick("loading data")
    sequences = np.random.randint(0, vocsize, (batsize * 100, 16))
    # wrap in dataset
    dataset = q.TensorDataset(sequences[:batsize * 80],
                              sequences[:batsize * 80])
    validdataset = q.TensorDataset(sequences[batsize * 80:],
                                   sequences[batsize * 80:])
    dataloader = DataLoader(dataset=dataset, batch_size=batsize, shuffle=True)
    validdataloader = DataLoader(dataset=validdataset,
                                 batch_size=batsize,
                                 shuffle=False)
    tt.tock("data loaded")
    # model
    tt.tick("building model")
    embedder = nn.Embedding(vocsize, embdim)

    encoder = q.RecurrentStack(
        embedder,
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer().return_final(),
    )
    if mode == "fast":
        decoder = q.AttentionDecoder(
            attention=q.Attention().forward_gen(encdim, encdim, encdim),
            embedder=embedder,
            core=q.RecurrentStack(q.GRULayer(embdim, encdim)),
            smo=q.Stack(nn.Linear(encdim + encdim, vocsize), q.LogSoftmax()),
            return_att=True)
    else:
        decoder = q.AttentionDecoderCell(
            attention=q.Attention().forward_gen(encdim, encdim + embdim,
                                                encdim),
            embedder=embedder,
            core=q.RecStack(
                q.GRUCell(embdim + encdim,
                          encdim,
                          use_cudnn_cell=False,
                          rec_batch_norm=None,
                          activation="crelu")),
            smo=q.Stack(nn.Linear(encdim + encdim, vocsize), q.LogSoftmax()),
            att_after_update=False,
            ctx_to_decinp=True,
            decinp_to_att=True,
            return_att=True,
        ).to_decoder()

    m = EncDec(encoder, decoder, mode=mode)

    losses = q.lossarray(q.SeqNLLLoss(ignore_index=None),
                         q.SeqAccuracy(ignore_index=None),
                         q.SeqElemAccuracy(ignore_index=None))
    validlosses = q.lossarray(q.SeqNLLLoss(ignore_index=None),
                              q.SeqAccuracy(ignore_index=None),
                              q.SeqElemAccuracy(ignore_index=None))

    optimizer = torch.optim.Adadelta(m.parameters(), lr=lr, weight_decay=wreg)
    tt.tock("model built")

    q.train(m).cuda(usecuda).train_on(dataloader, losses)\
        .set_batch_transformer(lambda x, y: (x, y[:, :-1], y[:, 1:]))\
        .valid_on(validdataloader, validlosses)\
        .optimizer(optimizer).clip_grad_norm(2.)\
        .train(epochs)

    testdat = np.random.randint(0, vocsize, (batsize, 20))
    testdata = q.var(torch.from_numpy(testdat)).cuda(usecuda).v
    testdata_out = q.var(torch.from_numpy(testdat)).cuda(usecuda).v
    if mode == "cell" and False:
        inv_idx = torch.arange(testdata.size(1) - 1, -1, -1).long()
        testdata = testdata.index_select(1, inv_idx)
    probs, attw = m(testdata, testdata_out[:, :-1])

    def plot(x):
        sns.heatmap(x)
        plt.show()

    embed()
Code example #17
File: overnight_basic.py  Project: saist1993/parseq
    def __init__(self,
                 embdim,
                 hdim,
                 numlayers: int = 1,
                 dropout=0.,
                 sentence_encoder: SequenceEncoder = None,
                 query_encoder: SequenceEncoder = None,
                 feedatt=False,
                 store_attn=True,
                 **kw):
        super(BasicGenModel, self).__init__(**kw)

        inpemb = torch.nn.Embedding(sentence_encoder.vocab.number_of_ids(),
                                    300,
                                    padding_idx=0)
        inpemb = TokenEmb(inpemb,
                          adapt_dims=(300, embdim),
                          rare_token_ids=sentence_encoder.vocab.rare_ids,
                          rare_id=1)
        _, covered_word_ids = load_pretrained_embeddings(
            inpemb.emb,
            sentence_encoder.vocab.D,
            p="../../data/glove/glove300uncased"
        )  # load glove embeddings where possible into the inner embedding class
        inpemb._do_rare(inpemb.rare_token_ids - covered_word_ids)
        self.inp_emb = inpemb

        encoder_dim = hdim
        encoder = q.LSTMEncoder(embdim,
                                *([encoder_dim // 2] * numlayers),
                                bidir=True,
                                dropout_in=dropout)
        self.inp_enc = encoder

        decoder_emb = torch.nn.Embedding(query_encoder.vocab.number_of_ids(),
                                         embdim,
                                         padding_idx=0)
        decoder_emb = TokenEmb(decoder_emb,
                               rare_token_ids=query_encoder.vocab.rare_ids,
                               rare_id=1)
        self.out_emb = decoder_emb

        dec_rnn_in_dim = embdim + (encoder_dim if feedatt else 0)
        decoder_rnn = [torch.nn.LSTMCell(dec_rnn_in_dim, hdim)]
        for i in range(numlayers - 1):
            decoder_rnn.append(torch.nn.LSTMCell(hdim, hdim))
        decoder_rnn = LSTMCellTransition(*decoder_rnn, dropout=dropout)
        self.out_rnn = decoder_rnn

        decoder_out = BasicGenOutput(hdim + encoder_dim,
                                     vocab=query_encoder.vocab)
        # decoder_out.build_copy_maps(inp_vocab=sentence_encoder.vocab)
        self.out_lin = decoder_out

        self.att = q.Attention(q.MatMulDotAttComp(hdim, encoder_dim))

        self.enc_to_dec = torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, hdim), torch.nn.Tanh())

        self.feedatt = feedatt
        self.nocopy = True

        self.store_attn = store_attn