def test_bypass_stack(self):
    data = q.var(np.random.random((3, 5)).astype(dtype="float32")).v
    stack = q.Stack(
        q.Forward(5, 5),
        q.argsave.spec(a=0),
        q.Forward(5, 5),
        q.Forward(5, 5),
        q.argmap.spec(0, ["a"]),
        q.Lambda(lambda x, y: torch.cat([x, y], 1)),
        q.Forward(10, 7),
    )
    out = stack(data)
    print(out)
    self.assertEqual(out.size(), (3, 7))
def test_shapes(self):
    batsize, seqlen, inpdim = 5, 7, 8
    vocsize, embdim, encdim = 20, 9, 10
    ctxtoinitff = q.Forward(inpdim, encdim)
    coreff = q.Forward(encdim, encdim)
    # initial decoder state is computed from the last context timestep
    initstategen = q.Lambda(lambda *x, **kw: coreff(ctxtoinitff(x[1][:, -1, :])),
                            register_modules=coreff)
    decoder_cell = q.AttentionDecoderCell(
        attention=q.Attention().forward_gen(inpdim, encdim + embdim, encdim),
        embedder=nn.Embedding(vocsize, embdim),
        core=q.RecStack(
            q.GRUCell(embdim + inpdim, encdim),
            q.GRUCell(encdim, encdim),
            coreff,
        ),
        smo=q.Stack(
            q.Forward(encdim + inpdim, encdim),
            q.Forward(encdim, vocsize),
            q.Softmax(),
        ),
        init_state_gen=initstategen,
        ctx_to_decinp=True,
        ctx_to_smo=True,
        state_to_smo=True,
        decinp_to_att=True,
    )
    decoder = decoder_cell.to_decoder()
    ctx = np.random.random((batsize, seqlen, inpdim))
    ctx = Variable(torch.FloatTensor(ctx))
    # mask out the last two context steps (last three for the first two examples)
    ctxmask = np.ones((batsize, seqlen))
    ctxmask[:, -2:] = 0
    ctxmask[[0, 1], -3:] = 0
    ctxmask = Variable(torch.FloatTensor(ctxmask))
    inp = np.random.randint(0, vocsize, (batsize, seqlen))
    inp = Variable(torch.LongTensor(inp))
    decoded = decoder(inp, ctx, ctxmask)
    self.assertEqual((batsize, seqlen, vocsize), decoded.size())
    # softmax output: each step's distribution over the vocabulary sums to one
    self.assertTrue(np.allclose(
        np.sum(decoded.data.numpy(), axis=-1),
        np.ones_like(np.sum(decoded.data.numpy(), axis=-1))))
    print(decoded.size())
def make_encoder(src_emb, embdim=100, dim=100, **kw):
    """ make encoder
    # concatenating bypass encoder:
    #   embedding --> top GRU
    #             --> 1st BiGRU
    #   1st BiGRU --> top GRU
    #             --> 2nd BiGRU
    #   2nd BiGRU --> top GRU
    """
    encoder = q.RecurrentStack(
        src_emb,        # embs, masks
        q.argsave.spec(emb=0, mask=1),
        q.argmap.spec(0, mask=["mask"]),
        q.BidirGRULayer(embdim, dim),
        q.argsave.spec(bypass=0),
        q.argmap.spec(0, mask=["mask"]),
        q.BidirGRULayer(dim * 2, dim),
        q.argmap.spec(0, ["bypass"], ["emb"]),
        q.Lambda(lambda x, y, z: torch.cat([x, y, z], 1)),
        q.argmap.spec(0, mask=["mask"]),
        q.GRULayer(dim * 4 + embdim, dim).return_final(True),
        q.argmap.spec(1, ["mask"], 0),
    )   # returns (all_states, mask, final_state)
    return encoder
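# Hedged usage sketch for the encoder above (not part of the test suite).
# Assumptions not stated in this file: the library is imported as `q`
# (e.g. `import qelos as q`), `src_emb` is a source embedder that returns
# (embeddings, mask) for a LongTensor of token ids (as the "embs, masks"
# comment implies), and the stack returns (all_states, mask, final_state)
# as the comment above notes. The shapes below follow from dim=100.
def _encoder_usage_sketch(src_emb, vocsize=1000, batsize=4, seqlen=6):
    encoder = make_encoder(src_emb, embdim=100, dim=100)
    src = Variable(torch.LongTensor(np.random.randint(0, vocsize, (batsize, seqlen))))
    all_states, mask, final_state = encoder(src)
    # expected under the assumptions above:
    #   all_states:  (batsize, seqlen, 100) -- per-step states of the top GRU
    #   mask:        (batsize, seqlen)      -- the mask produced by src_emb
    #   final_state: (batsize, 100)         -- via .return_final(True)
    return all_states, mask, final_state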
def test_dynamic_bypass_stack(self):
    data = q.var(np.random.random((3, 5)).astype(dtype="float32")).v
    stack = q.Stack()
    nlayers = 5
    for i in range(nlayers):
        # each block saves its input as "a", applies two Forwards,
        # then adds the saved input back in (residual-style bypass)
        stack.add(q.argsave.spec(a=0),
                  q.Forward(5, 5),
                  q.Forward(5, 5),
                  q.argmap.spec(0, ["a"]),
                  q.Lambda(lambda x, y: x + y))
    out = stack(data)
    print(out)
    self.assertEqual(out.size(), (3, 5))
    out.sum().backward()
    # every Forward in the stack must have received a nonzero gradient
    forwards = []
    for layer in stack.layers:
        if isinstance(layer, q.Forward):
            self.assertTrue(layer.lin.weight.grad is not None)
            self.assertTrue(layer.lin.bias.grad is not None)
            print(layer.lin.weight.grad.norm(2))
            self.assertTrue(layer.lin.weight.grad.norm(2).data[0] > 0)
            self.assertTrue(layer.lin.bias.grad.norm(2).data[0] > 0)
            forwards.append(layer)
    self.assertEqual(len(forwards), nlayers * 2)