예제 #1
0
 def test_bypass_stack(self):
     """Run a (3, 5) float32 batch through a bypassing stack; expect (3, 7).

     The stack saves an intermediate value under the name "a" via
     ``q.argsave`` and feeds it back in via ``q.argmap`` so that the Lambda
     can concatenate two tensors along dim 1 before the final 10 -> 7 layer.
     NOTE(review): the exact argsave/argmap semantics live in qelos and are
     not visible here -- confirm against the library.
     """
     batch = q.var(np.random.random((3, 5)).astype(dtype="float32")).v

     layers = [
         q.Forward(5, 5),
         q.argsave.spec(a=0),
         q.Forward(5, 5),
         q.Forward(5, 5),
         q.argmap.spec(0, ["a"]),
         # two 5-wide inputs are joined into one 10-wide tensor
         q.Lambda(lambda x, y: torch.cat([x, y], 1)),
         q.Forward(10, 7),
     ]
     net = q.Stack(*layers)

     result = net(batch)
     print(result)
     self.assertEqual(result.size(), (3, 7))
예제 #2
0
    def test_shapes(self):
        """Decode with an attention decoder cell and check the output.

        Asserts that the decoder output has size (batsize, seqlen, vocsize)
        and that its values sum to one along the last axis (the output
        stack ends in ``q.Softmax``).
        """
        batsize, seqlen, inpdim = 5, 7, 8
        vocsize, embdim, encdim = 20, 9, 10

        # Initial-state generator: takes the last time step of its second
        # positional argument and pushes it through two feed-forward layers.
        # `shared_ff` is also reused as the last layer of the recurrent core.
        ctx_proj = q.Forward(inpdim, encdim)
        shared_ff = q.Forward(encdim, encdim)
        init_gen = q.Lambda(
            lambda *x, **kw: shared_ff(ctx_proj(x[1][:, -1, :])),
            register_modules=shared_ff)

        # Sub-modules are created in the same order as they appear in the
        # cell's argument list.
        attention = q.Attention().forward_gen(inpdim, encdim+embdim, encdim)
        embedder = nn.Embedding(vocsize, embdim)
        core = q.RecStack(
            q.GRUCell(embdim + inpdim, encdim),
            q.GRUCell(encdim, encdim),
            shared_ff
        )
        smo = q.Stack(
            q.Forward(encdim+inpdim, encdim),
            q.Forward(encdim, vocsize),
            q.Softmax()
        )
        cell = q.AttentionDecoderCell(
            attention=attention,
            embedder=embedder,
            core=core,
            smo=smo,
            init_state_gen=init_gen,
            ctx_to_decinp=True,
            ctx_to_smo=True,
            state_to_smo=True,
            decinp_to_att=True
        )
        decoder = cell.to_decoder()

        # Random context of size (batsize, seqlen, inpdim).
        ctx = Variable(torch.FloatTensor(
            np.random.random((batsize, seqlen, inpdim))))

        # Mask of ones with the last two positions zeroed for every row and
        # the last three zeroed for rows 0 and 1.
        mask = np.ones((batsize, seqlen))
        mask[:, -2:] = 0
        mask[[0, 1], -3:] = 0
        ctxmask = Variable(torch.FloatTensor(mask))

        # Random token ids in [0, vocsize).
        inp = Variable(torch.LongTensor(
            np.random.randint(0, vocsize, (batsize, seqlen))))

        decoded = decoder(inp, ctx, ctxmask)

        self.assertEqual((batsize, seqlen, vocsize), decoded.size())
        rowsums = np.sum(decoded.data.numpy(), axis=-1)
        self.assertTrue(np.allclose(rowsums, np.ones_like(rowsums)))
        print(decoded.size())
예제 #3
0
파일: seq2seq.py 프로젝트: nilesh-c/qelos
def make_decoder(emb, lin, ctxdim=100, embdim=100, dim=100,
                 attmode="bilin", decsplit=False, **kw):
    """Build an attention decoder around the given embedder and output layer.

    Original author's note: attention cell decoder that accepts VNT.

    Args:
        emb: module used as the decoder cell's ``embedder``.
        lin: module placed in the output stack before ``q.LogSoftmax``.
        ctxdim: context dimension; halved (integer division) when
            ``decsplit`` is true.
        embdim: embedding dimension produced by ``emb``.
        dim: hidden dimension of the two GRU cells in the core.
        attmode: ``"bilin"`` for ``bilinear_gen`` attention or ``"fwd"`` for
            ``forward_gen`` attention.
        decsplit: forwarded to the cell as ``state_split``.
        **kw: accepted but not used by this function.

    Returns:
        The result of ``to_decoder()`` on the constructed
        ``q.AttentionDecoderCell``.

    Raises:
        q.SumTingWongException: if ``attmode`` is neither ``"bilin"`` nor
            ``"fwd"``.
    """
    if decsplit:
        ctxdim = ctxdim // 2
    # The cell is built with ctx_to_decinp=True, so the core input is the
    # embedding concatenated with the context.
    coreindim = embdim + ctxdim

    # NOTE(review): `coretocritdim` is computed but never used; `critdim`
    # below uses the full `dim` even when `decsplit` is set. It looks like
    # `critdim` may have been meant to use it -- confirm against
    # AttentionDecoderCell's state_split handling before changing.
    coretocritdim = dim // 2 if decsplit else dim
    # The cell is built with decinp_to_att=True, hence the extra `embdim`.
    critdim = dim + embdim

    if attmode not in ("bilin", "fwd"):
        raise q.SumTingWongException()
    att = q.Attention()
    att_gen = att.bilinear_gen if attmode == "bilin" else att.forward_gen
    attention = att_gen(ctxdim, critdim)

    core = q.RecStack(
        q.GRUCell(coreindim, dim),
        q.GRUCell(dim, dim),
    )
    # Output stack: the "mask" argument is saved before `lin` and mapped
    # back in before the log-softmax; the final argmap keeps argument 0.
    smo = q.Stack(
        q.argsave.spec(mask={"mask"}),
        lin,
        q.argmap.spec(0, mask=["mask"]),
        q.LogSoftmax(),
        q.argmap.spec(0),
    )

    cell = q.AttentionDecoderCell(
        attention=attention,
        embedder=emb,
        core=core,
        smo=smo,
        ctx_to_decinp=True,
        ctx_to_smo=True,
        state_to_smo=True,
        decinp_to_att=True,
        state_split=decsplit,
    )
    return cell.to_decoder()
예제 #4
0
    def test_dynamic_bypass_stack(self):
        """Build a stack incrementally and check shape and gradient flow.

        Five groups of layers are appended with ``stack.add``; each group
        holds two ``q.Forward(5, 5)`` layers between an argsave/argmap pair
        and an additive Lambda. After a backward pass every Forward layer
        must have non-None, non-zero-norm weight and bias gradients, and
        there must be ``2 * nlayers`` Forward layers in total.
        NOTE(review): the argsave/argmap pair presumably forms a skip
        connection around the two Forward layers -- confirm against qelos.
        """
        nlayers = 5
        data = q.var(np.random.random((3, 5)).astype(dtype="float32")).v

        stack = q.Stack()
        for _ in range(nlayers):
            stack.add(
                q.argsave.spec(a=0),
                q.Forward(5, 5),
                q.Forward(5, 5),
                q.argmap.spec(0, ["a"]),
                q.Lambda(lambda x, y: x + y),
            )

        out = stack(data)
        print(out)
        self.assertEqual(out.size(), (3, 5))

        out.sum().backward()

        forwards = [layer for layer in stack.layers
                    if isinstance(layer, q.Forward)]
        for fwd in forwards:
            weight_grad = fwd.lin.weight.grad
            bias_grad = fwd.lin.bias.grad
            self.assertIsNotNone(weight_grad)
            self.assertIsNotNone(bias_grad)
            print(weight_grad.norm(2))
            self.assertTrue(weight_grad.norm(2).data[0] > 0)
            self.assertTrue(bias_grad.norm(2).data[0] > 0)

        self.assertEqual(len(forwards), nlayers * 2)
예제 #5
0
def main(
    lr=0.5,
    epochs=30,
    batsize=32,
    embdim=90,
    encdim=90,
    mode="cell",  # "fast" or "cell"
    wreg=0.0001,
    cuda=False,
    gpu=1,
):
    """Train an attention encoder-decoder to reproduce random sequences.

    Random integer sequences are generated and used as both input and
    target, an ``EncDec`` model is trained on them, the model is then run
    on a fresh random batch, and finally ``embed()`` is called.

    Args:
        lr: learning rate passed to ``torch.optim.Adadelta``.
        epochs: number of epochs passed to the trainer's ``train``.
        batsize: batch size for both data loaders; also scales the dataset
            (``batsize * 100`` sequences, the first ``batsize * 80`` for
            training and the rest for validation).
        embdim: embedding dimension of the shared ``nn.Embedding``.
        encdim: dimension passed to the SRU/GRU cells and attention.
        mode: ``"fast"`` builds a ``q.AttentionDecoder``; any other value
            builds a ``q.AttentionDecoderCell`` converted with
            ``to_decoder()``. Also forwarded to ``EncDec``.
        wreg: ``weight_decay`` passed to the optimizer.
        cuda: if true, selects device ``gpu`` and is forwarded to the
            ``.cuda(...)`` calls.
        gpu: device index passed to ``torch.cuda.set_device``.
    """
    if cuda:
        torch.cuda.set_device(gpu)
    usecuda = cuda
    vocsize = 50
    # create datasets tensor: batsize * 100 random sequences of length 16
    tt.tick("loading data")
    sequences = np.random.randint(0, vocsize, (batsize * 100, 16))
    # wrap in dataset; input and target are the same slice of `sequences`
    dataset = q.TensorDataset(sequences[:batsize * 80],
                              sequences[:batsize * 80])
    validdataset = q.TensorDataset(sequences[batsize * 80:],
                                   sequences[batsize * 80:])
    dataloader = DataLoader(dataset=dataset, batch_size=batsize, shuffle=True)
    validdataloader = DataLoader(dataset=validdataset,
                                 batch_size=batsize,
                                 shuffle=False)
    tt.tock("data loaded")
    # model
    tt.tick("building model")
    # the same embedder instance is shared by the encoder and the decoder
    embedder = nn.Embedding(vocsize, embdim)

    encoder = q.RecurrentStack(
        embedder,
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer(),
        q.SRUCell(encdim).to_layer().return_final(),
    )
    if mode == "fast":
        decoder = q.AttentionDecoder(
            attention=q.Attention().forward_gen(encdim, encdim, encdim),
            embedder=embedder,
            core=q.RecurrentStack(q.GRULayer(embdim, encdim)),
            smo=q.Stack(nn.Linear(encdim + encdim, vocsize), q.LogSoftmax()),
            return_att=True)
    else:
        # NOTE: any mode other than "fast" ends up here, not only "cell"
        decoder = q.AttentionDecoderCell(
            attention=q.Attention().forward_gen(encdim, encdim + embdim,
                                                encdim),
            embedder=embedder,
            core=q.RecStack(
                q.GRUCell(embdim + encdim,
                          encdim,
                          use_cudnn_cell=False,
                          rec_batch_norm=None,
                          activation="crelu")),
            smo=q.Stack(nn.Linear(encdim + encdim, vocsize), q.LogSoftmax()),
            att_after_update=False,
            ctx_to_decinp=True,
            decinp_to_att=True,
            return_att=True,
        ).to_decoder()

    m = EncDec(encoder, decoder, mode=mode)

    # separate loss arrays with the same three metrics for training and
    # validation
    losses = q.lossarray(q.SeqNLLLoss(ignore_index=None),
                         q.SeqAccuracy(ignore_index=None),
                         q.SeqElemAccuracy(ignore_index=None))
    validlosses = q.lossarray(q.SeqNLLLoss(ignore_index=None),
                              q.SeqAccuracy(ignore_index=None),
                              q.SeqElemAccuracy(ignore_index=None))

    optimizer = torch.optim.Adadelta(m.parameters(), lr=lr, weight_decay=wreg)
    tt.tock("model built")

    # The batch transformer turns (x, y) into (x, y without its last step,
    # y without its first step) -- presumably encoder input, decoder input
    # and gold output; confirm against q.train.
    q.train(m).cuda(usecuda).train_on(dataloader, losses)\
        .set_batch_transformer(lambda x, y: (x, y[:, :-1], y[:, 1:]))\
        .valid_on(validdataloader, validlosses)\
        .optimizer(optimizer).clip_grad_norm(2.)\
        .train(epochs)

    # run the trained model on one fresh random batch of length-20 sequences
    testdat = np.random.randint(0, vocsize, (batsize, 20))
    testdata = q.var(torch.from_numpy(testdat)).cuda(usecuda).v
    testdata_out = q.var(torch.from_numpy(testdat)).cuda(usecuda).v
    # Disabled by the trailing `and False`; if enabled it would reverse
    # `testdata` along dim 1.
    if mode == "cell" and False:
        inv_idx = torch.arange(testdata.size(1) - 1, -1, -1).long()
        testdata = testdata.index_select(1, inv_idx)
    probs, attw = m(testdata, testdata_out[:, :-1])

    # Not called in this function; presumably meant to be used from the
    # interactive session started by embed() below (e.g. on `attw`).
    def plot(x):
        sns.heatmap(x)
        plt.show()

    # NOTE(review): presumably IPython's embed(), which exposes the locals
    # above (m, probs, attw, plot, ...) -- confirm the import.
    embed()