Exemple #1
0
 def get_init_state(self, batsize, device=torch.device("cpu")):
     """Create a fresh decoder state with zeroed recurrent tensors.

     :param batsize: batch size
     :param device:  device to allocate the state tensors on
     :return: State with .h and .c of shape (batsize, numlayers, hdim)
              and a per-example .levels counter of shape (batsize,)
     """
     state = State()
     # allocate zeros directly instead of building a throwaway ones
     # tensor and zeroing copies of it (original used zeros_like)
     state.h = torch.zeros(batsize, self.numlayers, self.hdim, device=device)
     state.c = torch.zeros(batsize, self.numlayers, self.hdim, device=device)
     state.levels = torch.zeros(batsize, device=device)
     return state
Exemple #2
0
    def test_gather_states(self):
        """Check that gather_states reorders beam substates like tensor.gather."""
        states = [
            State(data=torch.rand(3, 4),
                  substate=State(data=np.asarray([0, 1, 2], dtype="int64")))
            for _ in range(5)
        ]
        # give every beam element a distinct payload: element k holds 3k..3k+2
        for k in range(1, 5):
            states[k].substate.data[:] = [3 * k, 3 * k + 1, 3 * k + 2]
        print(len(states))
        for st in states:
            print(st.substate.data)

        indexes = torch.tensor([[0, 0, 1, 0], [1, 1, 1, 0], [2, 4, 2, 0]])
        transition = BeamTransition(None)
        gathered = transition.gather_states(states, indexes)
        print(indexes)
        print(gathered)

        for g in gathered:
            print(g.substate.data)

        gathered_mat = torch.tensor([g.substate.data for g in gathered]).T
        print(gathered_mat)

        # reference: the same gather done directly on a dense tensor
        reference_src = torch.arange(0, 15).reshape(5, 3).T
        reference = reference_src.gather(1, indexes)
        print(reference_src)
        print(reference)

        self.assertTrue(torch.allclose(reference, gathered_mat))
Exemple #3
0
    def forward(self, x: State):
        """Decode `x` with the beam model, consulting a per-example action cache.

        If the cache is complete and `x` carries example ids ("eids"), the
        cached predicted actions are looked up directly; otherwise the beam
        model is run to termination (gradient-free, eval mode) and, while the
        cache is still being filled, its predictions are written back per
        example id.

        NOTE(review): nothing is returned yet and the gold-alignment branch is
        an unfinished TODO. `predactions` would also be unbound if the state
        starts out fully terminated — confirm upstream guarantees that at
        least one decoding step runs.
        """
        _x = x  # keep a handle on the input state; the loop below rebinds `x`
        # run genmodel in beam decoder in non-training mode
        if self._beamcache_complete and "eids" in x and self._use_beamcache:
            eids = torch.tensor(x.eids, device=self._beamcache_actions.device)
            predactions = self._beamcache_actions[eids]
        else:
            with torch.no_grad():
                self.beammodel.eval()
                x.start_decoding()
                i = 0
                all_terminated = x.all_terminated()
                while not all_terminated:
                    outprobs, predactions, x, all_terminated = self.beammodel(
                        x, timestep=i)
                    i += 1
                # write predictions into the cache while it is incomplete
                if not self._beamcache_complete and "eids" in _x and self._use_beamcache:
                    for j, eid in enumerate(list(_x.eids)):
                        self._beamcache_actions[eid] = predactions[j:j + 1]

        # if training, add gold to beam
        if self.training:
            golds = _x.get_gold()
            # align gold dims
            pass  # TODO
    def forward(self, x:State):
        """One decoding step of an XLM-R-encoded attention decoder.

        The first call for a batch encodes `x.inp_tensor` with XLM-R and
        caches the context (and mask) on `x.mstate`; every call embeds
        `x.prev_actions`, advances the decoder RNN with attention over the
        cached context, and returns (output scores, updated state).
        """
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0  # nonzero token ids are real input (0 presumably = padding)
            xlmrstates = self.xlmr.extract_features(inptensor)
            inpenc = xlmrstates
            final_enc = xlmrstates[:, 0, :]  # first-token feature, presumably the <s> summary — confirm
            for i in range(len(self.enc_to_dec)):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc)
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            # only seed h when one projection was produced per decoder layer
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            # NOTE(review): `final_enc` only exists when the ctx branch above ran
            # in this same call; if ctx were cached without prev_summ this raises
            # NameError — confirm both are always set together.
            mstate.prev_summ = final_enc
        _emb = emb
        if self.feedatt == True:
            # input feeding: concatenate previous attention summary to the embedding
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None  # no output masking during training
        else:
            out_mask = x.get_out_mask(device=enc.device)

        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            # accumulate per-step attention weights along a time dimension
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        return outs[0], x
Exemple #5
0
    def forward(self, x: State):
        """One decoding step of a basic attention-based seq2seq decoder.

        The first call encodes `x.inp_tensor` and caches the encoder context
        on `x.mstate`; every call embeds `x.prev_actions`, advances the
        decoder RNN (optionally with input feeding), attends over the context,
        and returns (output scores, updated state).
        """
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0  # nonzero token ids are real input (0 presumably = padding)
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_enc = self.inp_enc(inpembs, mask)
            # flatten the final encoder state per example and project it into
            # the decoder's space (final_enc itself is only used here)
            final_enc = final_enc.view(final_enc.size(0), -1).contiguous()
            final_enc = self.enc_to_dec(final_enc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # first step: no previous attention summary yet, start from zeros
            mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        _emb = emb
        if self.feedatt == True:
            # input feeding: concatenate previous attention summary to the embedding
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy is True:
            outs = self.out_lin(enc)
        else:
            # copy mechanism: output layer also sees input tokens and attention scores
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            # accumulate per-step attention weights along a time dimension
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0),
                                                  0,
                                                  alphas.size(1),
                                                  device=alphas.device)
            x.stored_attentions = torch.cat(
                [x.stored_attentions,
                 alphas.detach()[:, None, :]], 1)

        return outs[0], x
Exemple #6
0
 def forward(self, inp:torch.Tensor, state:State):
     """
     Run one step of the wrapped LSTM cell with (optional) recurrent dropout.

     :param inp:     (batsize, indim)
     :param state:   State with .h, .c of shape (numlayers, batsize, hdim)
     :return:        (output of shape (batsize, hdim), updated state)
     """
     emb = self.dropout(inp)
     # apply the pre-sampled recurrent dropout masks only when dropout is active
     if self.dropout_rec.p > 0:
         h_prev = (state.h * state.h_dropout).transpose(0, 1)
         c_prev = (state.c * state.c_dropout).transpose(0, 1)
     else:
         h_prev = state.h.transpose(0, 1)
         c_prev = state.c.transpose(0, 1)
     # the cell expects a time dimension; feed a single step and strip it again
     y, (h_new, c_new) = self.cell(emb[:, None, :], (h_prev.contiguous(), c_prev.contiguous()))
     state.h = h_new.transpose(0, 1)
     state.c = c_new.transpose(0, 1)
     return y[:, 0, :], state
Exemple #7
0
    def forward(self, x: State):
        """One decoding step with per-layer encoder-to-decoder state init.

        The first call encodes `x.inp_tensor`, caches the context on
        `x.mstate`, and projects each layer's final encoder state into a
        decoder initial state; every call embeds `x.prev_actions`, advances
        the decoder RNN with attention, and returns (scores, updated state).
        """
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0  # nonzero token ids are real input (0 presumably = padding)
            inpembs = self.inp_emb(inptensor)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            init_states = []
            for i in range(len(final_encs)):
                # project each layer's final hidden state into decoder space
                init_states.append(self.enc_to_dec[i](final_encs[i][0]))
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            # NOTE(review): `init_states` is only defined when the ctx branch
            # above ran in this same call; if ctx were cached without rnnstate
            # this raises NameError — confirm both are always set together.
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state
        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            # seed the attention summary with the top layer's final encoder state
            # (same caveat as above: `final_encs` must exist in this call)
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt == True:
            # input feeding: concatenate previous attention summary to the embedding
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy is True:
            outs = self.out_lin(enc)
        else:
            # copy mechanism: output layer also sees input tokens and attention scores
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)
        return outs[0], x
Exemple #8
0
    def test_slicing_by_indexes(self):
        """State indexing should accept lists, numpy arrays and torch tensors alike."""
        x = State(k=torch.rand(6, 3), s=np.asarray("a b c d e f".split()))
        print(x.k)
        # exercise the same slice through each supported index type
        for idxs in ([1, 2], np.asarray([1, 2]), torch.tensor([1, 2])):
            print(x[idxs].k)
            print(x[idxs].s)
Exemple #9
0
 def get_init_state(self, batsize, device=torch.device("cpu")):
     """Create a zeroed recurrent state plus fixed recurrent-dropout masks.

     :param batsize: batch size
     :param device:  device to allocate the state tensors on
     :return: State with .h/.c zeros of shape (batsize, numlayers, hdim) and
              .h_dropout/.c_dropout binary masks of the same shape, sampled
              once here and stored on the state for later steps to reuse
     """
     state = State()
     shape = (batsize, self.numlayers, self.hdim)
     # allocate zeros directly instead of zeroing copies of a ones tensor
     state.h = torch.zeros(shape, device=device)
     state.c = torch.zeros(shape, device=device)
     ones = torch.ones(shape, device=device)
     # dropout on ones yields 0 or 1/(1-p); clamp turns it into a 0/1 mask
     state.h_dropout = self.dropout_rec(ones).clamp(0, 1)
     state.c_dropout = self.dropout_rec(ones).clamp(0, 1)
     return state
 def get_init_state(self, batsize, device=torch.device("cpu")):
     """Create the combined initial state for the main and reduce LSTMs.

     .h/.c come from the main LSTM; .stack holds, per example, a Python list
     seeded with that example's (main, reduce) state slices.
     """
     main_state = self.main_lstm.get_init_state(batsize, device)
     reduce_state = self.reduce_lstm.get_init_state(batsize, device)
     state = State()
     state.h = main_state.h
     state.c = main_state.c
     # object-dtype array so each slot can hold an arbitrary Python list
     stacks = np.array(range(batsize), dtype="object")
     for b in range(batsize):
         stacks[b] = [(main_state[b:b + 1], reduce_state[b:b + 1])]
     state.stack = stacks
     return state
Exemple #11
0
    def test_seting_by_indexes(self):
        """Index-assignment into a State should accept lists, numpy arrays and tensors."""
        # repeat the same assignment with each supported index type,
        # rebuilding fresh states every round
        for make_index in (lambda: [1, 3],
                           lambda: np.asarray([1, 3]),
                           lambda: torch.tensor([1, 3])):
            x = State(k=torch.rand(6, 3), s=np.asarray("a b c d e f".split()))
            y = State(k=torch.zeros(2, 3), s=np.asarray("o o".split()))
            x[make_index()] = y
            print(x.k)
            print(x.s)
Exemple #12
0
 def test_state_getitem(self):
     """Integer and slice indexing should propagate through nested states."""
     root = State()
     root.set(k=torch.randn(5, 4))
     child = State()
     child.set(v=torch.rand(5, 2))
     root.set(s=child)
     root.set(l=np.asarray(["a", "b", "c", "d", "e"]))
     _ = root[2]
     # exercise both a scalar index and a slice
     for sel in (2, slice(None, 2)):
         print(root[sel].k)
         print(root[sel].l)
         print(root[sel].s.v)
    def forward(self, x: State):
        """One decoding step that also attends over previous decoder states.

        Besides the usual cached encoder context, `mstate.prevstates`
        accumulates the decoder RNN outputs so far; from the second step on,
        the attention context is the encoder context concatenated with those
        past decoder states (all unmasked). Returns (scores, updated state).
        """
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0  # nonzero token ids are real input (0 presumably = padding)
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):  # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            # only seed h when one projection was produced per decoder layer
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            # NOTE(review): `final_encs` only exists when the ctx branch above
            # ran in this same call — confirm ctx and prev_summ are always
            # initialized together.
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt == True:
            # input feeding: concatenate previous attention summary to the embedding
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        if "prevstates" not in mstate:
            # first step: attend over the encoder context only
            _ctx = ctx
            _ctx_mask = ctx_mask
            mstate.prevstates = enc[:, None, :]
        else:
            # later steps: extend the context with all previous decoder states,
            # which are always valid (mask of ones)
            _ctx = torch.cat([ctx, mstate.prevstates], 1)
            _ctx_mask = torch.cat([
                ctx_mask,
                torch.ones(mstate.prevstates.size(0),
                           mstate.prevstates.size(1),
                           dtype=ctx_mask.dtype,
                           device=ctx_mask.device)
            ], 1)
            mstate.prevstates = torch.cat([mstate.prevstates, enc[:, None, :]],
                                          1)

        alphas, summ, scores = self.att(enc, _ctx, _ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None  # no output masking during training
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy is True:
            outs = self.out_lin(enc, out_mask)
        else:
            # copy mechanism: output layer also sees input tokens and attention scores
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0),
                                                  0,
                                                  alphas.size(1),
                                                  device=alphas.device)
            # attention width grows with the step; pad before concatenating
            atts = q.pad_tensors(
                [x.stored_attentions,
                 alphas.detach()[:, None, :]], 2, 0)
            x.stored_attentions = torch.cat(atts, 1)

        return outs[0], x
Exemple #14
0
 def test_state_create(self):
     """Build a State with tensor, nested-state and numpy attributes and dump it."""
     root = State()
     root.set(k=torch.randn(3, 4))
     child = State()
     child.set(v=torch.rand(3, 2))
     root.set(s=child)
     root.set(l=np.asarray(["sqdf", "qdsf", "qdsf"]))
     # inspect the state itself, its schema and every stored attribute
     for item in (root, root._schema_keys, root.k, root.s.v, root.l):
         print(item)
Exemple #15
0
 def test_copy_non_detached(self):
     """A non-detached copy must keep the autograd link to the original parameter."""
     original = State(k=torch.nn.Parameter(torch.rand(5, 3)))
     clone = original.make_copy(detach=False)
     # backprop through the copy should populate the original's gradient
     clone.k.sum().backward()
     print(original.k.grad)
Exemple #16
0
 def test_copy_detached(self):
     """Mutating a (default, detached) copy must leave the original untouched."""
     original = State(k=torch.nn.Parameter(torch.rand(5, 3)))
     clone = original.make_copy()
     clone.k[:] = 0
     print(original.k)
     print(clone.k)
Exemple #17
0
 def test_copy_deep(self):
     """Inspect copy semantics for numpy payloads.

     NOTE(review): despite the name, this calls make_copy(deep=False) — the
     prints show whether the shallow copy shares the underlying array.
     """
     original = State(k=np.asarray(["a", "b", "c"]))
     shallow = original.make_copy(deep=False)
     shallow.k[:] = "q"
     print(original.k)
     print(shallow.k)
Exemple #18
0
 def test_state_merge(self):
     """State.merge should concatenate tensors, nested states and numpy arrays."""
     def build(n, labels):
         # one state with an (n, 4) tensor, an (n, 2) substate and labels
         st = State()
         st.set(k=torch.randn(n, 4))
         sub = State()
         sub.set(v=torch.rand(n, 2))
         st.set(s=sub)
         st.set(l=np.asarray(labels))
         return st

     x = build(3, ["sqdf", "qdsf", "qdsf"])
     y = build(2, ["b", "a"])
     z = State.merge([x, y])
     print(z._schema_keys)
     print(z.k)
     print(x.k)
     print(y.k)
     print(z.s.v)
     print(z.l)
Exemple #19
0
    def forward(self, x:State):
        """One decoding step with optional gold-order resampling.

        On the first call the input is encoded and cached on `x.mstate`; in
        training, when beta < 1, one of the alternative gold sequences in
        `x._gold_tensors` is sampled and, per example with probability
        (1 - beta), substituted for the original gold. Each call then runs
        the usual attention-decoder step and returns (scores, updated state).
        """
        if not "mstate" in x:
            x.mstate = State()
            # per-example counter of how many decoding steps have run
            x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0  # nonzero token ids are real input (0 presumably = padding)
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

            if self.training and q.v(self.beta) < 1:    # sample one of the orders
                golds = x._gold_tensors
                # a gold candidate is valid if any of its tokens is nonzero
                goldsmask = (golds != 0).any(-1).float()
                numgolds = goldsmask.sum(-1)
                # uniform distribution over the valid gold candidates
                gold_select_prob = torch.ones_like(goldsmask) * goldsmask / numgolds[:, None]
                selector = gold_select_prob.multinomial(1)[:, 0]
                gold = golds.gather(1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]
                # interpolate with original gold
                original_gold = x.gold_tensor
                # keep the original gold with probability beta, per example
                beta_selector = (torch.rand_like(numgolds) <= q.v(self.beta)).long()
                gold_ = original_gold * beta_selector[:, None] + gold * (1 - beta_selector[:, None])
                x.gold_tensor = gold_


        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            # only seed h when one projection was produced per decoder layer
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            # NOTE(review): `final_encs` only exists when the ctx branch above
            # ran in this same call — confirm ctx and prev_summ are always
            # initialized together.
            mstate.prev_summ = final_encs[-1][0]

        _emb = emb

        if self.feedatt == True:
            # input feeding: concatenate previous attention summary to the embedding
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None  # no output masking during training
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy is True:
            outs = self.out_lin(enc, out_mask)
        else:
            # copy mechanism: output layer also sees input tokens and attention scores
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            # accumulate per-step attention weights along a time dimension
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        mstate.decoding_step = mstate.decoding_step + 1

        return outs[0], x
Exemple #20
0
    def test_state_setitem(self):
        """Slice-assignment of one State into another should hit every attribute."""
        big = State()
        big.set(k=torch.randn(5, 4))
        big_sub = State()
        big_sub.set(v=torch.rand(5, 2))
        big.set(s=big_sub)
        big.set(l=np.asarray(["sqdf", "qdsf", "qdsf", "a", "b"]))
        patch = State()
        patch.set(k=torch.ones(2, 4))
        patch_sub = State()
        patch_sub.set(v=torch.ones(2, 2))
        patch.set(s=patch_sub)
        patch.set(l=np.asarray(["o", "o"]))

        # overwrite rows 1 and 2 of every attribute with the patch state
        big[1:3] = patch
        print(big.k)
        print(big.s.v)
        print(big.l)
    def forward(self, x:State):
        """One decoding step of a stack-aware decoder.

        Derives a per-token stack action (+1 / -1 / 0) from the previous
        actions' parenthesis ids and passes it to the decoder RNN alongside
        the (optional) previous attention summary. Returns
        (output scores, updated state).
        """
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0  # nonzero token ids are real input (0 presumably = padding)
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            # only seed h when one projection was produced per decoder layer
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        # ONR stuff: !!! assumes LISP style queries with parentheses as separate tokens and only parentheses opening and closing clauses
        # +1 for an opening paren, -1 for a closing one, 0 otherwise
        stack_actions = torch.zeros_like(x.prev_actions)
        stack_actions += (x.prev_actions == self.open_id).long() * +1
        stack_actions += (x.prev_actions == self.close_id).long() * -1

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            # NOTE(review): `final_encs` only exists when the ctx branch above
            # ran in this same call — confirm ctx and prev_summ are always
            # initialized together.
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt == True:
            _ctx = mstate.prev_summ
        else:
            # zero-width tensor: the RNN receives no attention feedback
            _ctx = torch.zeros(_emb.size(0), 0, device=_emb.device)

        enc, new_rnnstate = self.out_rnn(_emb, _ctx, stack_actions, mstate.rnnstate)

        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy is True:
            outs = self.out_lin(enc)
        else:
            # copy mechanism: output layer also sees input tokens and attention scores
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            # accumulate per-step attention weights along a time dimension
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        return outs[0], x
Exemple #22
0
    def forward(self, x:State):
        """One decoding step of a VAE-augmented attention decoder.

        In training, the gold output sequence is encoded once, max-pooled and
        reparameterized into a latent `outenc` (with its KL divergence stored
        as `mstate.kld`); at inference the latent is sampled from a standard
        normal instead. The latent is concatenated to every step's input
        embedding. Returns (output scores, updated state).
        """
        if not "mstate" in x:
            x.mstate = State()
            # per-example counter of how many decoding steps have run
            x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0  # nonzero token ids are real input (0 presumably = padding)
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        if not "outenc" in mstate:
            if self.training:
                # encode the gold output sequence into a latent code (VAE)
                outtensor = x.gold_tensor
                omask = outtensor != 0
                outembs = self.out_emb_vae(outtensor)
                finalenc, _ = self.out_enc(outembs, omask)
                # log(0) = -inf pushes padded positions out of the max pool
                finalenc, _ = (finalenc + torch.log(omask.float()[:, :, None])).max(1)        # max pool
                # reparam
                mu = self.out_mu(finalenc)
                logvar = self.out_logvar(finalenc)
                std = torch.exp(.5*logvar)
                eps = torch.randn_like(std)
                outenc = mu + eps * std
                mstate.outenc = outenc
                # KL to the standard normal prior, floored at self.minkl (free bits)
                kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
                kld = torch.sum(kld.clamp_min(self.minkl), -1)
                mstate.kld = kld

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            # only seed h when one projection was produced per decoder layer
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            # NOTE(review): `final_encs` only exists when the ctx branch above
            # ran in this same call — confirm ctx and prev_summ are always
            # initialized together.
            mstate.prev_summ = final_encs[-1][0]

        if self.training:
            outenc = mstate.outenc
            # outenc = outenc.gather(1, mstate.decoding_step[:, None, None].repeat(1, 1, outenc.size(2)))[:, 0]
        else:
            # inference: sample the latent from the standard normal prior
            outenc = torch.randn(emb.size(0), self.zdim, device=emb.device)
        _emb = torch.cat([emb, outenc], 1)

        if self.feedatt == True:
            # input feeding: concatenate previous attention summary as well
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None  # no output masking during training
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy is True:
            outs = self.out_lin(enc, out_mask)
        else:
            # copy mechanism: output layer also sees input tokens and attention scores
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            # accumulate per-step attention weights along a time dimension
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        mstate.decoding_step = mstate.decoding_step + 1

        return outs[0], x