Code Example #1
File: test_rnn.py  Project: nilesh-c/qelos
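A shape test for a stack of recurrent cells: two `q.GRUCell`s (9→10 and 10→11) are combined with `q.RecStack`, per-cell initial states are set, and a single timestep of input is pushed through to check that the output has shape (batsize, 11). The snippets on this page are excerpts and omit their imports; they appear to rely on roughly the following (the `qelos as q` alias is an assumption inferred from the `q.` prefix and is not shown in the original files):

# Assumed imports for the excerpts on this page (not part of the original listing).
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable        # these snippets use the old (pre-0.4) Variable API
from torch.utils.data import DataLoader    # used in Code Example #5
import qelos as q                          # alias inferred from the "q." prefix
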
    def test_shapes(self):
        batsize = 5
        m = q.RecStack(q.GRUCell(9, 10), q.GRUCell(10, 11))
        x_t = Variable(torch.FloatTensor(np.random.random((batsize, 9))))
        h_tm1_a = Variable(torch.FloatTensor(np.random.random((batsize, 10))))
        h_tm1_b = Variable(torch.FloatTensor(np.random.random((batsize, 11))))
        m.set_init_states(h_tm1_a, h_tm1_b)
        y_t = m(x_t)
        self.assertEqual((batsize, 11), y_t.data.numpy().shape)
Code Example #2
File: test_rnn.py  Project: nilesh-c/qelos
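The same two-cell GRU stack tested with masking: `.to_layer().return_final().return_all()` turns the cell stack into a layer that consumes a whole (batsize, seqlen, 9) sequence, a binary mask marks the valid timesteps of each example, and the test asserts that the final output returned alongside the full sequence equals the last timestep of that sequence.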
    def test_masked_gru_stack(self):
        batsize = 3
        seqlen = 4

        m = q.RecStack(q.GRUCell(9, 10), q.GRUCell(10, 11))
        x = Variable(torch.FloatTensor(np.random.random((batsize, seqlen, 9))))
        h_tm1_a = Variable(torch.FloatTensor(np.random.random((batsize, 10))))
        h_tm1_b = Variable(torch.FloatTensor(np.random.random((batsize, 11))))
        m.set_init_states(h_tm1_a, h_tm1_b)
        m = m.to_layer().return_final().return_all()

        # binary mask over timesteps: the three examples have 3, 1, and 4 valid steps
        mask_val = np.asarray([[1, 1, 1, 0], [1, 0, 0, 0], [1, 1, 1, 1]])
        mask = Variable(torch.FloatTensor(mask_val))

        y_t, y = m(x, mask=mask)
        self.assertTrue(np.allclose(y_t.data.numpy(), y.data.numpy()[:, -1]))
Code Example #3
File: test_decoder.py  Project: nilesh-c/qelos
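A shape and normalization test for `q.AttentionDecoderCell`: the cell is assembled from forward attention, an `nn.Embedding` embedder, a two-GRU `q.RecStack` core, a softmax output stack, and a `q.Lambda`-based initial-state generator, then converted with `.to_decoder()`. Run on random context (with a context mask) and token ids, the decoder should produce a (batsize, seqlen, vocsize) tensor whose output distributions each sum to one.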
    def test_shapes(self):
        batsize, seqlen, inpdim = 5, 7, 8
        vocsize, embdim, encdim = 20, 9, 10
        ctxtoinitff = q.Forward(inpdim, encdim)
        coreff = q.Forward(encdim, encdim)
        initstategen = q.Lambda(lambda *x, **kw: coreff(ctxtoinitff(x[1][:, -1, :])), register_modules=coreff)

        decoder_cell = q.AttentionDecoderCell(
            attention=q.Attention().forward_gen(inpdim, encdim+embdim, encdim),
            embedder=nn.Embedding(vocsize, embdim),
            core=q.RecStack(
                q.GRUCell(embdim + inpdim, encdim),
                q.GRUCell(encdim, encdim),
                coreff
            ),
            smo=q.Stack(
                q.Forward(encdim+inpdim, encdim),
                q.Forward(encdim, vocsize),
                q.Softmax()
            ),
            init_state_gen=initstategen,
            ctx_to_decinp=True,
            ctx_to_smo=True,
            state_to_smo=True,
            decinp_to_att=True
        )
        decoder = decoder_cell.to_decoder()

        ctx = np.random.random((batsize, seqlen, inpdim))
        ctx = Variable(torch.FloatTensor(ctx))
        # context mask: zero out the last two positions for every example,
        # and the last three for the first two examples
        ctxmask = np.ones((batsize, seqlen))
        ctxmask[:, -2:] = 0
        ctxmask[[0, 1], -3:] = 0
        ctxmask = Variable(torch.FloatTensor(ctxmask))
        inp = np.random.randint(0, vocsize, (batsize, seqlen))
        inp = Variable(torch.LongTensor(inp))

        decoded = decoder(inp, ctx, ctxmask)

        self.assertEqual((batsize, seqlen, vocsize), decoded.size())
        self.assertTrue(np.allclose(
            np.sum(decoded.data.numpy(), axis=-1),
            np.ones_like(np.sum(decoded.data.numpy(), axis=-1))))
        print(decoded.size())
Code Example #4
File: seq2seq.py  Project: nilesh-c/qelos
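A decoder factory from a seq2seq script: `make_decoder` builds a `q.AttentionDecoderCell` around the given embedder `emb` and output layer `lin`, with either bilinear or feed-forward attention, a two-layer GRU core, and a log-softmax output stack whose `q.argsave`/`q.argmap` specs save an incoming `mask` keyword and re-inject it into the log-softmax; the cell is returned wrapped via `.to_decoder()`.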
def make_decoder(emb, lin, ctxdim=100, embdim=100, dim=100,
                 attmode="bilin", decsplit=False, **kw):
    """ makes decoder
    # attention cell decoder that accepts VNT !!!
    """
    ctxdim = ctxdim if not decsplit else ctxdim // 2
    coreindim = embdim + ctxdim     # if ctx_to_decinp is True else embdim

    coretocritdim = dim if not decsplit else dim // 2
    critdim = dim + embdim          # if decinp_to_att is True else dim

    if attmode == "bilin":
        attention = q.Attention().bilinear_gen(ctxdim, critdim)
    elif attmode == "fwd":
        attention = q.Attention().forward_gen(ctxdim, critdim)
    else:
        raise q.SumTingWongException()

    attcell = q.AttentionDecoderCell(attention=attention,
                                     embedder=emb,
                                     core=q.RecStack(
                                         q.GRUCell(coreindim, dim),
                                         q.GRUCell(dim, dim),
                                     ),
                                     smo=q.Stack(
                                         q.argsave.spec(mask={"mask"}),
                                         lin,
                                         q.argmap.spec(0, mask=["mask"]),
                                         q.LogSoftmax(),
                                         q.argmap.spec(0),
                                     ),
                                     ctx_to_decinp=True,
                                     ctx_to_smo=True,
                                     state_to_smo=True,
                                     decinp_to_att=True,
                                     state_split=decsplit)
    return attcell.to_decoder()
Code Example #5
File: encdec_toy.py  Project: nilesh-c/qelos
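A toy encoder-decoder (sequence copy) script: random integer sequences are split into training and validation `q.TensorDataset`s, an SRU-based `q.RecurrentStack` encodes them, and either a fast layer-mode `q.AttentionDecoder` or a cell-mode `q.AttentionDecoderCell` decodes with attention. The model is trained with Adadelta, sequence NLL and accuracy losses, and gradient clipping, then run on fresh random data so the attention weights can be inspected interactively. `tt`, `sns`, `plt`, and `embed` are presumably the script's tick/tock timer, seaborn, matplotlib, and IPython's `embed`, imported outside this excerpt.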
def main(
    lr=0.5,
    epochs=30,
    batsize=32,
    embdim=90,
    encdim=90,
    mode="cell",  # "fast" or "cell"
    wreg=0.0001,
    cuda=False,
    gpu=1,
):
    if cuda:
        torch.cuda.set_device(gpu)
    usecuda = cuda
    vocsize = 50
    # create datasets tensor
    tt.tick("loading data")
    sequences = np.random.randint(0, vocsize, (batsize * 100, 16))
    # wrap in dataset
    dataset = q.TensorDataset(sequences[:batsize * 80],
                              sequences[:batsize * 80])
    validdataset = q.TensorDataset(sequences[batsize * 80:],
                                   sequences[batsize * 80:])
    dataloader = DataLoader(dataset=dataset, batch_size=batsize, shuffle=True)
    validdataloader = DataLoader(dataset=validdataset,
                                 batch_size=batsize,
                                 shuffle=False)
    tt.tock("data loaded")
    # model
    tt.tick("building model")
    embedder = nn.Embedding(vocsize, embdim)

    encoder = q.RecurrentStack(
        embedder,
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer().return_final(),
    )
    if mode == "fast":
        decoder = q.AttentionDecoder(
            attention=q.Attention().forward_gen(encdim, encdim, encdim),
            embedder=embedder,
            core=q.RecurrentStack(q.GRULayer(embdim, encdim)),
            smo=q.Stack(nn.Linear(encdim + encdim, vocsize), q.LogSoftmax()),
            return_att=True)
    else:
        decoder = q.AttentionDecoderCell(
            attention=q.Attention().forward_gen(encdim, encdim + embdim,
                                                encdim),
            embedder=embedder,
            core=q.RecStack(
                q.GRUCell(embdim + encdim,
                          encdim,
                          use_cudnn_cell=False,
                          rec_batch_norm=None,
                          activation="crelu")),
            smo=q.Stack(nn.Linear(encdim + encdim, vocsize), q.LogSoftmax()),
            att_after_update=False,
            ctx_to_decinp=True,
            decinp_to_att=True,
            return_att=True,
        ).to_decoder()

    m = EncDec(encoder, decoder, mode=mode)

    losses = q.lossarray(q.SeqNLLLoss(ignore_index=None),
                         q.SeqAccuracy(ignore_index=None),
                         q.SeqElemAccuracy(ignore_index=None))
    validlosses = q.lossarray(q.SeqNLLLoss(ignore_index=None),
                              q.SeqAccuracy(ignore_index=None),
                              q.SeqElemAccuracy(ignore_index=None))

    optimizer = torch.optim.Adadelta(m.parameters(), lr=lr, weight_decay=wreg)
    tt.tock("model built")

    q.train(m).cuda(usecuda).train_on(dataloader, losses)\
        .set_batch_transformer(lambda x, y: (x, y[:, :-1], y[:, 1:]))\
        .valid_on(validdataloader, validlosses)\
        .optimizer(optimizer).clip_grad_norm(2.)\
        .train(epochs)

    testdat = np.random.randint(0, vocsize, (batsize, 20))
    testdata = q.var(torch.from_numpy(testdat)).cuda(usecuda).v
    testdata_out = q.var(torch.from_numpy(testdat)).cuda(usecuda).v
    if mode == "cell" and False:  # intentionally disabled branch
        # would reverse the test input along the time dimension
        inv_idx = torch.arange(testdata.size(1) - 1, -1, -1).long()
        testdata = testdata.index_select(1, inv_idx)
    probs, attw = m(testdata, testdata_out[:, :-1])

    def plot(x):
        # show attention weights as a heatmap (seaborn/matplotlib assumed imported elsewhere)
        sns.heatmap(x)
        plt.show()

    embed()  # drop into an interactive shell (presumably IPython's embed)