def get_init_state(self, batsize, device=torch.device("cpu")):
    """Create a fresh, zeroed recurrent state.

    :param batsize: batch size.
    :param device: device on which to allocate the state tensors.
    :return: ``State`` with ``.h`` and ``.c`` of shape
             ``(batsize, numlayers, hdim)`` and ``.levels`` of shape
             ``(batsize,)``, all zeros (default float dtype).
    """
    state = State()
    # Allocate zeros directly instead of building a ones tensor just to
    # pass it through zeros_like (the original made one extra allocation).
    state.h = torch.zeros(batsize, self.numlayers, self.hdim, device=device)
    state.c = torch.zeros(batsize, self.numlayers, self.hdim, device=device)
    # One scalar level counter per batch element (same shape/dtype as the
    # original's zeros_like(x[:, 0, 0])).
    state.levels = torch.zeros(batsize, device=device)
    return state
def test_gather_states(self):
    """Check BeamTransition.gather_states against torch.Tensor.gather.

    Builds 5 beam-candidate States whose integer substates hold the
    values 0..14, gathers them with a (3, 4) index matrix, and verifies
    the gathered substates match a plain tensor gather on the same data.
    """
    # 5 candidate states; each substate carries 3 int64 ids.
    x = [
        State(data=torch.rand(3, 4),
              substate=State(data=np.asarray([0, 1, 2], dtype="int64")))
        for _ in range(5)
    ]
    # Overwrite substates so candidate k holds [3k, 3k+1, 3k+2] — i.e. the
    # flattened values 0..14 spread over the 5 candidates.
    x[1].substate.data[:] = [3, 4, 5]
    x[2].substate.data[:] = [6, 7, 8]
    x[3].substate.data[:] = [9, 10, 11]
    x[4].substate.data[:] = [12, 13, 14]
    print(len(x))
    for i in range(5):
        print(x[i].substate.data)
    # indexes[b, j] = which of the 5 candidates to pick for slot j of batch row b.
    indexes = torch.tensor([[0, 0, 1, 0], [1, 1, 1, 0], [2, 4, 2, 0]])
    bt = BeamTransition(None)
    y = bt.gather_states(x, indexes)
    print(indexes)
    print(y)
    for ye in y:
        print(ye.substate.data)
    # Re-assemble the gathered substates into a (3, 4) matrix for comparison.
    yemat = torch.tensor([ye.substate.data for ye in y]).T
    print(yemat)
    # Oracle: the same gather done directly on a (3, 5) tensor of 0..14.
    a = torch.arange(0, 15).reshape(5, 3).T
    b = a.gather(1, indexes)
    print(a)
    print(b)
    self.assertTrue(torch.allclose(b, yemat))
def forward(self, x: State): _x = x # run genmodel in beam decoder in non-training mode if self._beamcache_complete and "eids" in x and self._use_beamcache: eids = torch.tensor(x.eids, device=self._beamcache_actions.device) predactions = self._beamcache_actions[eids] else: with torch.no_grad(): self.beammodel.eval() x.start_decoding() i = 0 all_terminated = x.all_terminated() while not all_terminated: outprobs, predactions, x, all_terminated = self.beammodel( x, timestep=i) i += 1 if not self._beamcache_complete and "eids" in _x and self._use_beamcache: for j, eid in enumerate(list(_x.eids)): self._beamcache_actions[eid] = predactions[j:j + 1] # if training, add gold to beam if self.training: golds = _x.get_gold() # align gold dims pass # TODO
def forward(self, x: State):
    """One decoding step of an XLM-R-encoded attention decoder.

    On the first call (per batch) encodes the input with XLM-R, caches the
    context and initializes the decoder RNN state; every call then embeds
    the previous action, runs one RNN step with attention over the cached
    context, and produces output scores (with copy via ``out_lin``).
    Returns ``(scores, x)``.
    """
    if not "mstate" in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if not "ctx" in mstate:
        # encode input (first decoding step only)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # nonzero token ids are real input positions
        xlmrstates = self.xlmr.extract_features(inptensor)
        inpenc = xlmrstates
        # Use the first position's features as the sequence summary.
        # NOTE(review): presumably the <s>/BOS position — confirm tokenizer.
        final_enc = xlmrstates[:, 0, :]
        for i in range(len(self.enc_to_dec)):  # iter over layers
            _fenc = self.enc_to_dec[i](final_enc)
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if not "rnnstate" in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # Only seed h when one projected state per decoder layer was built.
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # final_enc is only defined on the encoding call, which coincides
        # with the first step where prev_summ is missing.
        mstate.prev_summ = final_enc
    _emb = emb
    if self.feedatt == True:
        # feed previous attention summary together with the action embedding
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    # Output masking only at inference time; training scores are unmasked.
    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        # Accumulate per-step attention weights as (batch, step, srclen).
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
        x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)
    return outs[0], x
def forward(self, x: State):
    """One decoding step of a basic RNN-encoder attention decoder.

    First call encodes the input sequence and caches context/mask; each
    call embeds the previous action, runs one decoder RNN step, attends
    over the context, and scores outputs (copy mechanism unless
    ``self.nocopy``). Returns ``(scores, x)``.
    """
    if not "mstate" in x:
        x.mstate = State()
    mstate = x.mstate
    if not "ctx" in mstate:
        # encode input (first decoding step only)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # nonzero ids mark real tokens
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_enc = self.inp_enc(inpembs, mask)
        # Flatten the final encoder state and project it to decoder dims.
        final_enc = final_enc.view(final_enc.size(0), -1).contiguous()
        final_enc = self.enc_to_dec(final_enc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if not "rnnstate" in mstate:
        init_rnn_state = self.out_rnn.get_init_state(
            emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # No attention summary yet on the first step: start from zeros.
        mstate.prev_summ = torch.zeros_like(ctx[:, 0])
    _emb = emb
    if self.feedatt == True:
        # feed previous attention summary together with the action embedding
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    if self.nocopy is True:
        outs = self.out_lin(enc)
    else:
        # copy-enabled output layer gets source ids and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores)
    outs = (outs, ) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        # Accumulate per-step attention weights as (batch, step, srclen).
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
        x.stored_attentions = torch.cat(
            [x.stored_attentions, alphas.detach()[:, None, :]], 1)
    return outs[0], x
def forward(self, inp:torch.Tensor, state:State):
    """
    Run one LSTM step with input dropout and (optional) recurrent dropout.

    :param inp: (batsize, indim)
    :param state: State with .h, .c of shape (numlayers, batsize, hdim)
    :return: (output of shape (batsize, hdim), updated state)
    """
    dropped_inp = self.dropout(inp)
    # Recurrent dropout: scale previous hidden/cell by the fixed per-sequence
    # masks stored on the state, but only when recurrent dropout is active.
    if self.dropout_rec.p > 0:
        prev_h = state.h * state.h_dropout
        prev_c = state.c * state.c_dropout
    else:
        prev_h = state.h
        prev_c = state.c
    # nn.LSTM wants (numlayers, batsize, hdim); state stores batch-first.
    prev_h = prev_h.transpose(0, 1).contiguous()
    prev_c = prev_c.transpose(0, 1).contiguous()
    step_out, (next_h, next_c) = self.cell(
        dropped_inp[:, None, :], (prev_h, prev_c))
    # Store back in batch-first layout for State slicing.
    state.h = next_h.transpose(0, 1)
    state.c = next_c.transpose(0, 1)
    return step_out[:, 0, :], state
def forward(self, x: State):
    """One decoding step with per-layer encoder-to-decoder state init.

    First call encodes the input and projects each encoder layer's final
    state to initialize the matching decoder layer; each call then runs
    one attention-decoder RNN step and scores outputs (copy unless
    ``self.nocopy``). Returns ``(scores, x)``.
    """
    if not "mstate" in x:
        x.mstate = State()
    mstate = x.mstate
    if not "ctx" in mstate:
        # encode input (first decoding step only)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # nonzero ids mark real tokens
        inpembs = self.inp_emb(inptensor)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        init_states = []
        # Project every encoder layer's final state ([0] picks h over c —
        # NOTE(review): presumably (h, c) tuples per layer; confirm inp_enc).
        for i in range(len(final_encs)):
            init_states.append(self.enc_to_dec[i](final_encs[i][0]))
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if not "rnnstate" in mstate:
        init_rnn_state = self.out_rnn.get_init_state(
            emb.size(0), emb.device)
        # Seed decoder h with the stacked per-layer projections; this runs
        # on the same first call that defined init_states.
        init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # Initialize attention feedback with the top encoder layer's state.
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt == True:
        # feed previous attention summary together with the action embedding
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    if self.nocopy is True:
        outs = self.out_lin(enc)
    else:
        # copy-enabled output layer gets source ids and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores)
    outs = (outs, ) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    return outs[0], x
def test_slicing_by_indexes(self):
    """Slicing a State with list, numpy and torch indexes should all work."""
    state = State(k=torch.rand(6, 3), s=np.asarray("a b c d e f".split()))
    print(state.k)
    # Same rows selected through each supported index type.
    for idx in ([1, 2], np.asarray([1, 2]), torch.tensor([1, 2])):
        print(state[idx].k)
        print(state[idx].s)
def get_init_state(self, batsize, device=torch.device("cpu")):
    """Create a zeroed recurrent state plus recurrent-dropout masks.

    :param batsize: batch size.
    :param device: device on which to allocate the state tensors.
    :return: ``State`` with ``.h``/``.c`` zeros of shape
             ``(batsize, numlayers, hdim)`` and independent dropout masks
             ``.h_dropout``/``.c_dropout`` of the same shape.
    """
    state = State()
    shape = (batsize, self.numlayers, self.hdim)
    # Allocate zeros directly instead of building a ones tensor just to
    # pass it through zeros_like (the original made one extra allocation).
    state.h = torch.zeros(*shape, device=device)
    state.c = torch.zeros(*shape, device=device)
    # Two independent dropout draws; clamp(0, 1) turns train-time
    # inverted-dropout scaling (kept units are 1/(1-p)) into a 0/1 mask.
    state.h_dropout = self.dropout_rec(torch.ones(*shape, device=device)).clamp(0, 1)
    state.c_dropout = self.dropout_rec(torch.ones(*shape, device=device)).clamp(0, 1)
    return state
def get_init_state(self, batsize, device=torch.device("cpu")):
    """Build the initial stack-LSTM state.

    Copies the main LSTM's initial h/c and gives every batch element its
    own stack, seeded with that element's (main, reduce) state slices.
    """
    main_init = self.main_lstm.get_init_state(batsize, device)
    reduce_init = self.reduce_lstm.get_init_state(batsize, device)
    state = State()
    state.h = main_init.h
    state.c = main_init.c
    # object-dtype array so each slot can hold an independent Python list
    state.stack = np.array(range(batsize), dtype="object")
    for k in range(batsize):
        state.stack[k] = [(main_init[k:k + 1], reduce_init[k:k + 1])]
    return state
def test_seting_by_indexes(self):
    """Setting State rows via list, numpy and torch indexes should all work."""
    # Repeat the same write with each supported index type, on a fresh
    # pair of States every time.
    for idx in ([1, 3], np.asarray([1, 3]), torch.tensor([1, 3])):
        target = State(k=torch.rand(6, 3), s=np.asarray("a b c d e f".split()))
        source = State(k=torch.zeros(2, 3), s=np.asarray("o o".split()))
        target[idx] = source
        print(target.k)
        print(target.s)
def test_state_getitem(self):
    """Integer and slice indexing of a State, its nested State and numpy field."""
    x = State()
    x.set(k=torch.randn(5, 4))
    xsub = State()
    xsub.set(v=torch.rand(5, 2))
    x.set(s=xsub)
    x.set(l=np.asarray(["a", "b", "c", "d", "e"]))
    # (removed unused local `y = x[2]` — its value was never used)
    print(x[2].k)
    print(x[2].l)
    print(x[2].s.v)
    print(x[:2].k)
    print(x[:2].l)
    print(x[:2].s.v)
def forward(self, x: State):
    """One decoding step that also attends over previous decoder states.

    Like the basic attention-decoder step, but the attention context is the
    encoder context concatenated with all previously produced decoder
    states (``mstate.prevstates``), so the decoder can attend to its own
    history. Returns ``(scores, x)``.
    """
    if not "mstate" in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if not "ctx" in mstate:
        # encode input (first decoding step only)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # nonzero ids mark real tokens
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iter over layers
            # [0] picks the hidden part of each layer's final state
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if not "rnnstate" in mstate:
        init_rnn_state = self.out_rnn.get_init_state(
            emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # Only seed h when one projected state per decoder layer was built.
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # Initialize attention feedback with the top encoder layer's state.
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt == True:
        # feed previous attention summary together with the action embedding
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    # Extend the attention context with previously produced decoder states;
    # their mask positions are always 1 (past states are never padding).
    if "prevstates" not in mstate:
        _ctx = ctx
        _ctx_mask = ctx_mask
        mstate.prevstates = enc[:, None, :]
    else:
        _ctx = torch.cat([ctx, mstate.prevstates], 1)
        _ctx_mask = torch.cat([
            ctx_mask,
            torch.ones(mstate.prevstates.size(0), mstate.prevstates.size(1),
                       dtype=ctx_mask.dtype, device=ctx_mask.device)
        ], 1)
        mstate.prevstates = torch.cat([mstate.prevstates, enc[:, None, :]], 1)
    alphas, summ, scores = self.att(enc, _ctx, _ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    # Output masking only at inference time; training scores are unmasked.
    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        # copy-enabled output layer gets source ids and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs, ) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        # Context length grows each step, so pad stored attentions to match
        # before concatenating along the step dimension.
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
        atts = q.pad_tensors(
            [x.stored_attentions, alphas.detach()[:, None, :]], 2, 0)
        x.stored_attentions = torch.cat(atts, 1)
    return outs[0], x
def test_state_create(self):
    """Compose a State holding a tensor, a nested State and a numpy array."""
    root = State()
    root.set(k=torch.randn(3, 4))
    nested = State()
    nested.set(v=torch.rand(3, 2))
    root.set(s=nested)
    root.set(l=np.asarray(["sqdf", "qdsf", "qdsf"]))
    # Inspect the composed state and its registered fields.
    print(root)
    print(root._schema_keys)
    print(root.k)
    print(root.s.v)
    print(root.l)
def test_copy_non_detached(self):
    """A non-detached copy keeps the autograd graph: grads reach the source."""
    original = State(k=torch.nn.Parameter(torch.rand(5, 3)))
    copied = original.make_copy(detach=False)
    loss = copied.k.sum()
    loss.backward()
    # Should print a gradient tensor, not None.
    print(original.k.grad)
def test_copy_detached(self):
    """A default (detached) copy owns its data: writes don't touch the source."""
    source = State(k=torch.nn.Parameter(torch.rand(5, 3)))
    detached = source.make_copy()
    detached.k[:] = 0
    # source.k should keep its random values; detached.k is all zeros.
    print(source.k)
    print(detached.k)
def test_copy_deep(self):
    """Show numpy-field sharing when copying with deep=False.

    NOTE(review): despite the method name, this calls make_copy(deep=False),
    so the numpy payload is shared and the write below presumably shows up
    in the source too — confirm this demonstration is intended.
    """
    source = State(k=np.asarray(["a", "b", "c"]))
    shallow = source.make_copy(deep=False)
    shallow.k[:] = "q"
    print(source.k)
    print(shallow.k)
def test_state_merge(self):
    """Merging States concatenates tensors, nested States and numpy fields."""
    first = State()
    first.set(k=torch.randn(3, 4))
    first_sub = State()
    first_sub.set(v=torch.rand(3, 2))
    first.set(s=first_sub)
    first.set(l=np.asarray(["sqdf", "qdsf", "qdsf"]))
    second = State()
    second.set(k=torch.randn(2, 4))
    second_sub = State()
    second_sub.set(v=torch.rand(2, 2))
    second.set(s=second_sub)
    second.set(l=np.asarray(["b", "a"]))
    merged = State.merge([first, second])
    # Merged fields should stack the two inputs; originals stay untouched.
    print(merged._schema_keys)
    print(merged.k)
    print(first.k)
    print(second.k)
    print(merged.s.v)
    print(merged.l)
def forward(self, x:State):
    """One decoding step with stochastic gold-order selection.

    Like the basic attention-decoder step, but during training (when the
    interpolation weight ``beta`` < 1) it samples one of several valid gold
    linearizations per example and mixes it with the original gold sequence.
    Also maintains a per-example ``decoding_step`` counter. Returns
    ``(scores, x)``.
    """
    if not "mstate" in x:
        x.mstate = State()
        # per-example step counter, incremented at the end of every call
        x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
    mstate = x.mstate
    init_states = []
    if not "ctx" in mstate:
        # encode input (first decoding step only)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # nonzero ids mark real tokens
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iter over layers
            # [0] picks the hidden part of each layer's final state
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
        if self.training and q.v(self.beta) < 1:
            # sample one of the orders
            golds = x._gold_tensors
            # a candidate gold counts as valid if it has any nonzero token
            goldsmask = (golds != 0).any(-1).float()
            numgolds = goldsmask.sum(-1)
            # uniform distribution over the valid candidates of each example
            gold_select_prob = torch.ones_like(goldsmask) * goldsmask / numgolds[:, None]
            selector = gold_select_prob.multinomial(1)[:, 0]
            gold = golds.gather(1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]
            # interpolate with original gold: with probability beta keep the
            # original, otherwise use the sampled linearization
            original_gold = x.gold_tensor
            beta_selector = (torch.rand_like(numgolds) <= q.v(self.beta)).long()
            gold_ = original_gold * beta_selector[:, None] + gold * (1 - beta_selector[:, None])
            x.gold_tensor = gold_
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if not "rnnstate" in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # Only seed h when one projected state per decoder layer was built.
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # Initialize attention feedback with the top encoder layer's state.
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt == True:
        # feed previous attention summary together with the action embedding
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    # Output masking only at inference time; training scores are unmasked.
    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        # copy-enabled output layer gets source ids and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        # Accumulate per-step attention weights as (batch, step, srclen).
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
        x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)
    mstate.decoding_step = mstate.decoding_step + 1
    return outs[0], x
def test_state_setitem(self):
    """Assigning a 2-row State into rows 1:3 of a 5-row State."""
    target = State()
    target.set(k=torch.randn(5, 4))
    target_sub = State()
    target_sub.set(v=torch.rand(5, 2))
    target.set(s=target_sub)
    target.set(l=np.asarray(["sqdf", "qdsf", "qdsf", "a", "b"]))
    source = State()
    source.set(k=torch.ones(2, 4))
    source_sub = State()
    source_sub.set(v=torch.ones(2, 2))
    source.set(s=source_sub)
    source.set(l=np.asarray(["o", "o"]))
    target[1:3] = source
    # Rows 1 and 2 of every field (including the nested State) should now
    # carry the source's values.
    print(target.k)
    print(target.s.v)
    print(target.l)
def forward(self, x:State):
    """One decoding step of a stack-aware (ONR) attention decoder.

    Derives a push/pop signal from parenthesis tokens in the previous
    actions and feeds it, together with the embedding and (optionally) the
    previous attention summary, to a stack-structured decoder RNN.
    Returns ``(scores, x)``.
    """
    if not "mstate" in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if not "ctx" in mstate:
        # encode input (first decoding step only)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # nonzero ids mark real tokens
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iter over layers
            # [0] picks the hidden part of each layer's final state
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if not "rnnstate" in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # Only seed h when one projected state per decoder layer was built.
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    # ONR stuff: !!! assumes LISP style queries with parentheses as separate
    # tokens and only parentheses opening and closing clauses
    # +1 = push (open paren), -1 = pop (close paren), 0 = no stack action.
    stack_actions = torch.zeros_like(x.prev_actions)
    stack_actions += (x.prev_actions == self.open_id).long() * +1
    stack_actions += (x.prev_actions == self.close_id).long() * -1
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # Initialize attention feedback with the top encoder layer's state.
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    # Unlike the plain decoder, the attention summary is passed to the RNN
    # as a separate context argument (empty when feedatt is off).
    if self.feedatt == True:
        _ctx = mstate.prev_summ
    else:
        _ctx = torch.zeros(_emb.size(0), 0, device=_emb.device)
    enc, new_rnnstate = self.out_rnn(_emb, _ctx, stack_actions, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    if self.nocopy is True:
        outs = self.out_lin(enc)
    else:
        # copy-enabled output layer gets source ids and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        # Accumulate per-step attention weights as (batch, step, srclen).
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1),
                                              device=alphas.device)
        x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)
    return outs[0], x
def forward(self, x:State):
    """One decoding step of a VAE-regularized attention decoder.

    During training, encodes the gold output sequence into a latent code
    via the reparameterization trick (storing the KL term on the state);
    at inference the latent is sampled from a standard normal. The latent
    is concatenated to the action embedding at every step. Returns
    ``(scores, x)``.
    """
    if not "mstate" in x:
        x.mstate = State()
        # per-example step counter, incremented at the end of every call
        x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
    mstate = x.mstate
    init_states = []
    if not "ctx" in mstate:
        # encode input (first decoding step only)
        inptensor = x.inp_tensor
        mask = inptensor != 0  # nonzero ids mark real tokens
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iter over layers
            # [0] picks the hidden part of each layer's final state
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask
    if not "outenc" in mstate:
        if self.training:
            # Encode the gold output into a latent code (once per batch).
            outtensor = x.gold_tensor
            omask = outtensor != 0
            outembs = self.out_emb_vae(outtensor)
            finalenc, _ = self.out_enc(outembs, omask)
            # log(0) = -inf at padded positions so max-pooling ignores them
            finalenc, _ = (finalenc + torch.log(omask.float()[:, :, None])).max(1)  # max pool
            # reparam
            mu = self.out_mu(finalenc)
            logvar = self.out_logvar(finalenc)
            std = torch.exp(.5 * logvar)
            eps = torch.randn_like(std)
            outenc = mu + eps * std
            mstate.outenc = outenc
            # KL divergence to the standard normal prior, clamped from
            # below per dimension (free-bits style) before summing.
            kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
            kld = torch.sum(kld.clamp_min(self.minkl), -1)
            mstate.kld = kld
    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask
    emb = self.out_emb(x.prev_actions)
    if not "rnnstate" in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        # Only seed h when one projected state per decoder layer was built.
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state
    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        # Initialize attention feedback with the top encoder layer's state.
        mstate.prev_summ = final_encs[-1][0]
    if self.training:
        outenc = mstate.outenc
        # outenc = outenc.gather(1, mstate.decoding_step[:, None, None].repeat(1, 1, outenc.size(2)))[:, 0]
    else:
        # Inference: sample the latent from the standard normal prior.
        outenc = torch.randn(emb.size(0), self.zdim, device=emb.device)
    # Latent code is fed to the decoder at every step.
    _emb = torch.cat([emb, outenc], 1)
    if self.feedatt == True:
        # feed previous attention summary together with the action embedding
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate
    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)
    # Output masking only at inference time; training scores are unmasked.
    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)
    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        # copy-enabled output layer gets source ids and attention scores
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)
    if self.store_attn:
        # Accumulate per-step attention weights as (batch, step, srclen).
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
        x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)
    mstate.decoding_step = mstate.decoding_step + 1
    return outs[0], x