Exemplo n.º 1
0
 def test_bypass_stack(self):
     """Run a (3, 5) float32 batch through a bypass stack and expect a (3, 7) output.

     The stack saves an intermediate value under the name "a"
     (``q.argsave.spec``), re-supplies it later (``q.argmap.spec``) and
     concatenates both along dim 1 before the final 10 -> 7 layer.
     NOTE(review): exact argsave/argmap semantics live in ``q`` -- confirm there.
     """
     inputs = q.var(np.random.random((3, 5)).astype(dtype="float32")).v

     # Layers are created in the same order they are applied.
     layers = [
         q.Forward(5, 5),
         q.argsave.spec(a=0),
         q.Forward(5, 5),
         q.Forward(5, 5),
         q.argmap.spec(0, ["a"]),
         q.Lambda(lambda x, y: torch.cat([x, y], 1)),
         q.Forward(10, 7),
     ]
     bypass_stack = q.Stack(*layers)

     result = bypass_stack(inputs)
     print(result)
     self.assertEqual(result.size(), (3, 7))
Exemplo n.º 2
0
    def test_shapes(self):
        """Decode a random batch with an attention decoder and check the output.

        Asserts that the decoded tensor has size (batsize, seqlen, vocsize)
        and that its values sum to one along the last axis.
        """
        batsize, seqlen, inpdim = 5, 7, 8
        vocsize, embdim, encdim = 20, 9, 10

        # Initial-state generator: takes the last position of the second
        # positional argument and passes it through two feed-forward layers.
        ctx_to_init = q.Forward(inpdim, encdim)
        core_ff = q.Forward(encdim, encdim)
        init_state_gen = q.Lambda(
            lambda *x, **kw: core_ff(ctx_to_init(x[1][:, -1, :])),
            register_modules=core_ff)

        # Decoder parts, created in the order the cell receives them.
        attention = q.Attention().forward_gen(inpdim, encdim + embdim, encdim)
        embedder = nn.Embedding(vocsize, embdim)
        core = q.RecStack(
            q.GRUCell(embdim + inpdim, encdim),
            q.GRUCell(encdim, encdim),
            core_ff,
        )
        smo = q.Stack(
            q.Forward(encdim + inpdim, encdim),
            q.Forward(encdim, vocsize),
            q.Softmax(),
        )
        cell = q.AttentionDecoderCell(
            attention=attention,
            embedder=embedder,
            core=core,
            smo=smo,
            init_state_gen=init_state_gen,
            ctx_to_decinp=True,
            ctx_to_smo=True,
            state_to_smo=True,
            decinp_to_att=True,
        )
        decoder = cell.to_decoder()

        # Random context, a mask with trailing zeros, and random token ids.
        ctx = Variable(torch.FloatTensor(
            np.random.random((batsize, seqlen, inpdim))))
        mask_values = np.ones((batsize, seqlen))
        mask_values[:, -2:] = 0          # last two positions masked for every row
        mask_values[[0, 1], -3:] = 0     # rows 0 and 1 lose one more position
        ctxmask = Variable(torch.FloatTensor(mask_values))
        inp = Variable(torch.LongTensor(
            np.random.randint(0, vocsize, (batsize, seqlen))))

        decoded = decoder(inp, ctx, ctxmask)

        self.assertEqual((batsize, seqlen, vocsize), decoded.size())
        row_totals = np.sum(decoded.data.numpy(), axis=-1)
        self.assertTrue(np.allclose(row_totals, np.ones_like(row_totals)))
        print(decoded.size())
Exemplo n.º 3
0
def make_encoder(src_emb, embdim=100, dim=100, **kw):
    """Build the concatenating-bypass encoder as a ``q.RecurrentStack``.

    Data flow:
        embedding  --> 1st BiGRU and top GRU
        1st BiGRU  --> 2nd BiGRU and top GRU
        2nd BiGRU  --> top GRU

    :param src_emb: first layer of the stack; expected to produce
                    (embeddings, mask) -- TODO confirm against callers
    :param embdim:  embedding size, used for the first BiGRU input and as
                    part of the top GRU input size
    :param dim:     size passed to each GRU layer
    :param kw:      accepted but not used in this function
    :return:        the assembled stack; it is meant to return
                    (all_states, mask, final_state)
    """
    layers = [
        src_emb,                                    # embs, masks
        q.argsave.spec(emb=0, mask=1),
        # first bidirectional layer
        q.argmap.spec(0, mask=["mask"]),
        q.BidirGRULayer(embdim, dim),
        q.argsave.spec(bypass=0),
        # second bidirectional layer
        q.argmap.spec(0, mask=["mask"]),
        q.BidirGRULayer(dim * 2, dim),
        # join current output with the saved "bypass" and "emb" values
        q.argmap.spec(0, ["bypass"], ["emb"]),
        q.Lambda(lambda x, y, z: torch.cat([x, y, z], 1)),
        # top layer over the concatenation
        q.argmap.spec(0, mask=["mask"]),
        q.GRULayer(dim * 4 + embdim, dim).return_final(True),
        q.argmap.spec(1, ["mask"], 0),
    ]
    return q.RecurrentStack(*layers)
Exemplo n.º 4
0
    def test_dynamic_bypass_stack(self):
        """Grow a stack of additive bypass blocks and check output size and gradients.

        Five blocks are appended with ``stack.add``; each block contains two
        ``q.Forward`` layers and a lambda that adds two inputs.  After a
        backward pass every ``q.Forward`` layer must have non-zero weight and
        bias gradients, and there must be ``nlayers * 2`` such layers.
        """
        nlayers = 5
        data = q.var(np.random.random((3, 5)).astype(dtype="float32")).v

        stack = q.Stack()
        for _ in range(nlayers):
            block = (
                q.argsave.spec(a=0),
                q.Forward(5, 5),
                q.Forward(5, 5),
                q.argmap.spec(0, ["a"]),
                q.Lambda(lambda x, y: x + y),
            )
            stack.add(*block)

        out = stack(data)
        print(out)
        self.assertEqual(out.size(), (3, 5))

        out.sum().backward()

        forwards = [layer for layer in stack.layers
                    if isinstance(layer, q.Forward)]
        for fwd in forwards:
            self.assertTrue(fwd.lin.weight.grad is not None)
            self.assertTrue(fwd.lin.bias.grad is not None)
            weight_grad = fwd.lin.weight.grad
            print(weight_grad.norm(2))
            self.assertTrue(weight_grad.norm(2).data[0] > 0)
            self.assertTrue(fwd.lin.bias.grad.norm(2).data[0] > 0)

        self.assertEqual(len(forwards), nlayers * 2)