    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0    # non-padding mask (padding id is 0)
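            # contextualize the input with the pretrained XLM-R encoder; the
            # vector at the first position (<s>) serves as the sequence summary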
            xlmrstates = self.xlmr.extract_features(inptensor)
            inpenc = xlmrstates
            final_enc = xlmrstates[:, 0, :]
            for enc_to_dec_layer in self.enc_to_dec:    # one projection per decoder layer
                _fenc = enc_to_dec_layer(final_enc)
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_enc    # final_enc exists here: "ctx" and "prev_summ" are only missing together, on the first step
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)    # attend over the encoded input; the raw scores are reused below by the copying output layer
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None
        else:
            out_mask = x.get_out_mask(device=enc.device)

        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        return outs[0], x
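All of these forward() variants implement a single decoding step: they consume a State carrying the input, the previously chosen actions and the decoder memory (mstate), and return the scores for the next action together with the updated state. As a rough illustration of how such a step function is typically driven, here is a hypothetical greedy decoding loop; start_id, end_id and the exact State fields are assumptions, not a confirmed API:

import torch

def greedy_decode(model, x, start_id=1, end_id=2, maxsteps=100):
    # hypothetical driver: assumes model(x) returns (scores, x) as above
    device = x.inp_tensor.device
    x.prev_actions = torch.full((x.inp_tensor.size(0),), start_id,
                                dtype=torch.long, device=device)
    preds = []
    for _ in range(maxsteps):
        scores, x = model(x)                # one decoding step
        step_preds = scores.max(-1)[1]      # greedy argmax over the output vocabulary
        preds.append(step_preds)
        x.prev_actions = step_preds         # feed the prediction back in
        if (step_preds == end_id).all():
            break
    return torch.stack(preds, 1)            # (batsize, declen)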
Example #2
    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_enc = self.inp_enc(inpembs, mask)
            final_enc = final_enc.view(final_enc.size(0), -1).contiguous()
            final_enc = self.enc_to_dec(final_enc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy:
            outs = self.out_lin(enc)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0),
                                                  0,
                                                  alphas.size(1),
                                                  device=alphas.device)
            x.stored_attentions = torch.cat(
                [x.stored_attentions,
                 alphas.detach()[:, None, :]], 1)

        return outs[0], x
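The out_lin(enc, x.inp_tensor, scores) call above suggests a pointer/copy-style output layer that mixes generating from the output vocabulary with copying tokens from the input, weighted by the attention scores. A minimal sketch of that mixing idea; the names and the gating are illustrative assumptions, not the actual out_lin implementation:

import torch

def mix_copy_and_generate(gen_probs, inp_tensor, att_weights, copy_gate):
    # gen_probs:   (batsize, vocsize) generation distribution
    # inp_tensor:  (batsize, seqlen)  input token ids
    # att_weights: (batsize, seqlen)  attention weights over input positions
    # copy_gate:   (batsize, 1)       probability of copying rather than generating
    copy_probs = torch.zeros_like(gen_probs)
    copy_probs.scatter_add_(1, inp_tensor, att_weights)   # route attention mass onto input token ids
    return (1 - copy_gate) * gen_probs + copy_gate * copy_probs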
Example #3
    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            init_states = []
            for enc_to_dec_layer, final_enc in zip(self.enc_to_dec, final_encs):
                init_states.append(enc_to_dec_layer(final_enc[0]))
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state
        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy:
            outs = self.out_lin(enc)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)
        return outs[0], x
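Example #3 initializes the decoder's multi-layer hidden state from the per-layer final encoder states. The stacking convention, in isolation (a sketch; shapes inferred from the code above):

import torch

batsize, numlayers, dim = 2, 3, 4
init_states = [torch.randn(batsize, dim) for _ in range(numlayers)]  # one state per layer
h = torch.stack(init_states, 1).contiguous()   # (batsize, numlayers, dim)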
Example #4

    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        # ONR stack bookkeeping: assumes LISP-style queries in which parentheses
        # are separate tokens and only parentheses open and close clauses
        stack_actions = ((x.prev_actions == self.open_id).long()
                         - (x.prev_actions == self.close_id).long())

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt:
            _ctx = mstate.prev_summ
        else:
            _ctx = torch.zeros(_emb.size(0), 0, device=_emb.device)

        enc, new_rnnstate = self.out_rnn(_emb, _ctx, stack_actions, mstate.rnnstate)

        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.nocopy:
            outs = self.out_lin(enc)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        return outs[0], x
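The stack_actions tensor above marks each previous action with +1 for an opening parenthesis, -1 for a closing one and 0 otherwise, so the recurrent cell can track nesting depth. A small worked illustration with hypothetical token ids:

import torch

open_id, close_id = 3, 4                         # hypothetical ids for "(" and ")"
prev_actions = torch.tensor([[3, 7, 3, 8, 4]])   # e.g. "( and ( foo )"
stack_actions = ((prev_actions == open_id).long()
                 - (prev_actions == close_id).long())
print(stack_actions)   # tensor([[ 1,  0,  1,  0, -1]])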
Example #5
    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
            x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

            if self.training and q.v(self.beta) < 1:    # sample one of the orders
                golds = x._gold_tensors
                goldsmask = (golds != 0).any(-1).float()
                numgolds = goldsmask.sum(-1)
                gold_select_prob = goldsmask / numgolds[:, None]    # uniform over the non-empty gold alternatives
                selector = gold_select_prob.multinomial(1)[:, 0]
                gold = golds.gather(1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]
                # interpolate with original gold
                original_gold = x.gold_tensor
                beta_selector = (torch.rand_like(numgolds) <= q.v(self.beta)).long()
                gold_ = original_gold * beta_selector[:, None] + gold * (1 - beta_selector[:, None])
                x.gold_tensor = gold_


        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]

        _emb = emb

        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy:
            outs = self.out_lin(enc, out_mask)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        mstate.decoding_step = mstate.decoding_step + 1

        return outs[0], x
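Example #5 perturbs the training signal: with probability beta the original gold sequence is kept, otherwise one of the alternative linearizations stored in x._gold_tensors is sampled uniformly among the non-empty ones. The sampling step in isolation (a sketch; the tensor layout is inferred from the code above):

import torch

# (batsize=2, numalternatives=3, seqlen=4); all-zero rows are empty slots
golds = torch.tensor([[[5, 6, 7, 0], [6, 5, 7, 0], [0, 0, 0, 0]],
                      [[8, 9, 0, 0], [9, 8, 0, 0], [8, 9, 9, 0]]])
goldsmask = (golds != 0).any(-1).float()                         # non-empty alternatives
gold_select_prob = goldsmask / goldsmask.sum(-1, keepdim=True)   # uniform over them
selector = gold_select_prob.multinomial(1)[:, 0]                 # one alternative per example
gold = golds.gather(1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]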
Example #6

    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):  # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(
                emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]
        _emb = emb
        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        if "prevstates" not in mstate:
            _ctx = ctx
            _ctx_mask = ctx_mask
            mstate.prevstates = enc[:, None, :]
        else:
            _ctx = torch.cat([ctx, mstate.prevstates], 1)
            _ctx_mask = torch.cat([
                ctx_mask,
                torch.ones(mstate.prevstates.size(0),
                           mstate.prevstates.size(1),
                           dtype=ctx_mask.dtype,
                           device=ctx_mask.device)
            ], 1)
            mstate.prevstates = torch.cat([mstate.prevstates, enc[:, None, :]],
                                          1)

        alphas, summ, scores = self.att(enc, _ctx, _ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy:
            outs = self.out_lin(enc, out_mask)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs, ) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0),
                                                  0,
                                                  alphas.size(1),
                                                  device=alphas.device)
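            # the attention vectors grow together with prevstates, so pad the stored attentions to a common length before concatenating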
            atts = q.pad_tensors(
                [x.stored_attentions,
                 alphas.detach()[:, None, :]], 2, 0)
            x.stored_attentions = torch.cat(atts, 1)

        return outs[0], x
Example #7
    def forward(self, x: State):
        if not "mstate" in x:
            x.mstate = State()
            x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
        mstate = x.mstate
        init_states = []
        if not "ctx" in mstate:
            # encode input
            inptensor = x.inp_tensor
            mask = inptensor != 0
            inpembs = self.inp_emb(inptensor)
            # inpembs = self.dropout(inpembs)
            inpenc, final_encs = self.inp_enc(inpembs, mask)
            for i, final_enc in enumerate(final_encs):    # iter over layers
                _fenc = self.enc_to_dec[i](final_enc[0])
                init_states.append(_fenc)
            mstate.ctx = inpenc
            mstate.ctx_mask = mask

        if not "outenc" in mstate:
            if self.training:
                outtensor = x.gold_tensor
                omask = outtensor != 0
                outembs = self.out_emb_vae(outtensor)
                finalenc, _ = self.out_enc(outembs, omask)
                finalenc, _ = (finalenc + torch.log(omask.float()[:, :, None])).max(1)    # masked max-pool over time: log(0) = -inf removes padded positions
                # reparam
                mu = self.out_mu(finalenc)
                logvar = self.out_logvar(finalenc)
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                outenc = mu + eps * std
                mstate.outenc = outenc
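                # KL( N(mu, sigma^2) || N(0, I) ), computed per latent dimension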
                kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
                kld = torch.sum(kld.clamp_min(self.minkl), -1)
                mstate.kld = kld

        ctx = mstate.ctx
        ctx_mask = mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if not "rnnstate" in mstate:
            init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # uncomment next line to initialize decoder state with last state of encoder
            # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
            if len(init_states) == init_rnn_state.h.size(1):
                init_rnn_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = init_rnn_state

        if "prev_summ" not in mstate:
            # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
            mstate.prev_summ = final_encs[-1][0]

        # during training, condition on the sampled posterior code; at test time, draw the code from the prior
        if self.training:
            outenc = mstate.outenc
            # outenc = outenc.gather(1, mstate.decoding_step[:, None, None].repeat(1, 1, outenc.size(2)))[:, 0]
        else:
            outenc = torch.randn(emb.size(0), self.zdim, device=emb.device)
        _emb = torch.cat([emb, outenc], 1)

        if self.feedatt:
            _emb = torch.cat([_emb, mstate.prev_summ], 1)
        enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
        mstate.rnnstate = new_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        if self.training:
            out_mask = None
        else:
            out_mask = x.get_out_mask(device=enc.device)

        if self.nocopy:
            outs = self.out_lin(enc, out_mask)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        outs = (outs,) if not q.issequence(outs) else outs
        # _, preds = outs.max(-1)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        mstate.decoding_step = mstate.decoding_step + 1

        return outs[0], x
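Example #7 conditions the decoder on a latent code z obtained with the usual VAE reparameterization trick, drawing z from the standard normal prior at test time. The two ingredients in isolation (a self-contained sketch of the formulas used above):

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps the sampling step differentiable
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

def kl_to_standard_normal(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ), summed over latent dimensions
    return torch.sum(-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()), -1)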