def forward(self, x: State):
    """Run one decoder step for the batch carried in ``x``.

    On the first step, encodes ``x.inp_tensor`` with XLM-R and caches the
    encoding and its padding mask on ``x.mstate``; later steps reuse the
    cache. Returns ``(outs, x)`` where ``outs`` is the output score tensor
    for this step and ``x`` carries the updated mutable decoding state.
    """
    if "mstate" not in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input (first decoding step only; cached afterwards)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # token id 0 is padding
        xlmrstates = self.xlmr.extract_features(inptensor)
        inpenc = xlmrstates
        # first-token vector (presumably the <s>/CLS position — confirm)
        final_enc = xlmrstates[:, 0, :]
        for i in range(len(self.enc_to_dec)):  # iter over layers
            # project the sentence vector once per decoder layer
            _fenc = self.enc_to_dec[i](final_enc)
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # only seed h when encoder/decoder layer counts line up
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # NOTE(review): final_enc only exists on the first step, which is
        # also the only step where prev_summ is missing — branches co-occur.
        mstate.prev_summ = final_enc
    _emb = emb
    if self.feedatt:
        # feed the previous attention summary back into the decoder input
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(
                alphas.size(0), 0, alphas.size(1), device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)
    return outs[0], x
def forward(self, x: State):
    """Run one decoder step; may train against a sampled alternative gold order.

    On the first step, encodes ``x.inp_tensor`` and caches it on ``x.mstate``.
    During training, with probability ``1 - beta`` the canonical
    ``x.gold_tensor`` is replaced (per example) by one of the alternative
    gold linearizations in ``x._gold_tensors``. Returns ``(outs, x)``.
    """
    if "mstate" not in x:
        x.mstate = State()
        # per-example counter of completed decoding steps
        x.mstate.decoding_step = torch.zeros(
            x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input (first decoding step only; cached afterwards)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # token id 0 is padding
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iter over layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
        if self.training and q.v(self.beta) < 1:
            # sample one of the orders — done once per sequence, on the
            # first step, so the chosen gold stays fixed for all steps
            golds = x._gold_tensors  # assumes (batch, num_orders, seqlen) — TODO confirm
            goldsmask = (golds != 0).any(-1).float()  # which orders are non-empty
            numgolds = goldsmask.sum(-1)
            # uniform distribution over the non-empty orders of each example
            gold_select_prob = torch.ones_like(goldsmask) * goldsmask / numgolds[:, None]
            selector = gold_select_prob.multinomial(1)[:, 0]
            gold = golds.gather(
                1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]
            # interpolate with original gold: keep the canonical order with prob beta
            original_gold = x.gold_tensor
            beta_selector = (torch.rand_like(numgolds) <= q.v(self.beta)).long()
            gold_ = original_gold * beta_selector[:, None] \
                    + gold * (1 - beta_selector[:, None])
            x.gold_tensor = gold_
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # only seed h when encoder/decoder layer counts line up
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # last encoder layer's final state (presumably the h of an (h, c) pair — confirm)
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt:
        # feed the previous attention summary back into the decoder input
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        # copy/pointer-style output layer also sees input tokens and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(
                alphas.size(0), 0, alphas.size(1), device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)
    mstate.decoding_step = mstate.decoding_step + 1
    return outs[0], x
def forward(self, x: State):
    """Run one decoder step, attending over the input AND past decoder states.

    On the first step, encodes ``x.inp_tensor`` and caches it on ``x.mstate``.
    Each step's decoder state is appended to ``mstate.prevstates``, so later
    steps attend over the input encoding concatenated with all previous
    decoder states (the current state is never in its own context).
    Returns ``(outs, x)``.
    """
    if "mstate" not in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input (first decoding step only; cached afterwards)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # token id 0 is padding
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iter over layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # only seed h when encoder/decoder layer counts line up
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # last encoder layer's final state (presumably the h of an (h, c) pair — confirm)
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt:
        # feed the previous attention summary back into the decoder input
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    if "prevstates" not in mstate:
        # first step: attend over the input encoding only
        _ctx = ctx
        _ctx_mask = ctx_mask
        mstate.prevstates = enc[:, None, :]
    else:
        # later steps: input encoding + every previous decoder state,
        # with an all-ones mask for the appended decoder-state positions
        _ctx = torch.cat([ctx, mstate.prevstates], 1)
        _ctx_mask = torch.cat([
            ctx_mask,
            torch.ones(mstate.prevstates.size(0), mstate.prevstates.size(1),
                       dtype=ctx_mask.dtype, device=ctx_mask.device)
        ], 1)
        mstate.prevstates = torch.cat([mstate.prevstates, enc[:, None, :]], 1)
    alphas, summ, scores = self.att(enc, _ctx, _ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        # copy/pointer-style output layer also sees input tokens and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(
                alphas.size(0), 0, alphas.size(1), device=alphas.device)
        # the attention context grows each step, so pad along the key axis
        # before concatenating with earlier stored attentions
        atts = q.pad_tensors(
            [x.stored_attentions, alphas.detach()[:, None, :]], 2, 0)
        x.stored_attentions = torch.cat(atts, 1)
    return outs[0], x
def forward(self, x: State):
    """Run one decoder step of a VAE-style decoder cell.

    On the first step, encodes ``x.inp_tensor`` and caches it on ``x.mstate``.
    During training, the gold output sequence is encoded into a latent
    vector via the reparameterization trick and the (free-bits-clamped)
    KL term is stored in ``mstate.kld``; at inference the latent is sampled
    from the standard-normal prior. The latent is concatenated to each
    step's input embedding. Returns ``(outs, x)``.
    """
    if "mstate" not in x:
        x.mstate = State()
        # per-example counter of completed decoding steps
        x.mstate.decoding_step = torch.zeros(
            x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input (first decoding step only; cached afterwards)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # token id 0 is padding
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iter over layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    if "outenc" not in mstate:
        if self.training:
            # encode the gold output into a latent vector (once per sequence)
            outtensor = x.gold_tensor
            omask = outtensor != 0
            outembs = self.out_emb_vae(outtensor)
            finalenc, _ = self.out_enc(outembs, omask)
            # masked max-pool over time: log(0) = -inf removes padding positions
            finalenc, _ = (finalenc + torch.log(omask.float()[:, :, None])).max(1)
            # reparameterization trick: z = mu + eps * std
            mu = self.out_mu(finalenc)
            logvar = self.out_logvar(finalenc)
            std = torch.exp(.5 * logvar)
            eps = torch.randn_like(std)
            outenc = mu + eps * std
            mstate.outenc = outenc
            # KL divergence to the standard-normal prior, per dimension,
            # clamped from below by self.minkl (free bits) before summing
            kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
            kld = torch.sum(kld.clamp_min(self.minkl), -1)
            mstate.kld = kld
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # only seed h when encoder/decoder layer counts line up
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # last encoder layer's final state (presumably the h of an (h, c) pair — confirm)
        mstate.prev_summ = final_encs[-1][0]
    if self.training:
        outenc = mstate.outenc
        # outenc = outenc.gather(1, mstate.decoding_step[:, None, None].repeat(1, 1, outenc.size(2)))[:, 0]
    else:
        # no gold available at inference: sample the latent from the prior
        outenc = torch.randn(emb.size(0), self.zdim, device=emb.device)
    # condition every step on the latent vector
    _emb = torch.cat([emb, outenc], 1)
    if self.feedatt:
        # feed the previous attention summary back into the decoder input
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        # copy/pointer-style output layer also sees input tokens and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(
                alphas.size(0), 0, alphas.size(1), device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)
    mstate.decoding_step = mstate.decoding_step + 1
    return outs[0], x