import time

import numpy as np
import torch

# State, BeamState, BasicDecoderState, TransitionModel and BeamTransition are
# assumed to come from the surrounding parseq package (exact module paths are
# not shown in this excerpt); SequenceEncoder is imported from parseq.vocab
# inside the tests below.


def test_gather_states(self):
    x = [State(data=torch.rand(3, 4),
               substate=State(data=np.asarray([0, 1, 2], dtype="int64")))
         for _ in range(5)]
    x[1].substate.data[:] = [3, 4, 5]
    x[2].substate.data[:] = [6, 7, 8]
    x[3].substate.data[:] = [9, 10, 11]
    x[4].substate.data[:] = [12, 13, 14]
    print(len(x))
    for i in range(5):
        print(x[i].substate.data)
    indexes = torch.tensor([[0, 0, 1, 0],
                            [1, 1, 1, 0],
                            [2, 4, 2, 0]])
    bt = BeamTransition(None)
    y = bt.gather_states(x, indexes)
    print(indexes)
    print(y)
    for ye in y:
        print(ye.substate.data)
    yemat = torch.tensor([ye.substate.data for ye in y]).T
    print(yemat)
    # reference: a plain tensor gather over the beam dimension must match
    a = torch.arange(0, 15).reshape(5, 3).T
    b = a.gather(1, indexes)
    print(a)
    print(b)
    self.assertTrue(torch.allclose(b, yemat))
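# --- illustrative sketch (not from the original suite) ----------------------
# A minimal, self-contained demonstration of the torch.gather semantics that
# gather_states mirrors along the beam dimension; the method name is
# hypothetical and only serves illustration.
def sketch_gather_semantics(self):
    a = torch.arange(0, 15).reshape(5, 3).T        # 3 examples x 5 beam slots
    indexes = torch.tensor([[0, 0, 1, 0],
                            [1, 1, 1, 0],
                            [2, 4, 2, 0]])         # per-example source slots
    b = a.gather(1, indexes)                       # 3 examples x 4 gathered slots
    # row r of b holds a[r, indexes[r, j]] for every output slot j
    assert (b[0] == a[0, indexes[0]]).all()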
def test_beam_transition(self):
    texts = [
        "i went to chocolate @END@",
        "awesome is @END@",
        "the meaning of life @END@"
    ]
    from parseq.vocab import SequenceEncoder
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    x = BasicDecoderState(texts, texts, se, se)
    x.start_decoding()

    class Model(TransitionModel):
        def forward(self, x: BasicDecoderState):
            # random scoring model: enough to exercise the beam machinery
            outprobs = torch.randn(len(x), x.query_encoder.vocab.number_of_ids())
            outprobs = torch.nn.functional.log_softmax(outprobs, -1)
            return outprobs, x

    model = Model()
    beamsize = 50
    maxtime = 10
    beam_xs = [x.make_copy(detach=False, deep=True) for _ in range(beamsize)]
    beam_states = BeamState(beam_xs)
    print(len(beam_xs))
    print(len(beam_states))
    bt = BeamTransition(model, beamsize, maxtime=maxtime)
    i = 0
    _, _, y, _ = bt(x, i)
    i += 1
    _, _, y, _ = bt(y, i)
    i += 1
    all_terminated = y.all_terminated()
    while not all_terminated:
        _, predactions, y, all_terminated = bt(y, i)
        i += 1
    print("timesteps done:")
    print(i)
    print(y)
    print(predactions[0])
    for i in range(beamsize):
        print("-")
        # print(y.bstates[0].get(i).followed_actions)
        # print(predactions[0, i, :])
        pa = predactions[0, i, :]
        # print((pa == se.vocab[se.vocab.endtoken]).cumsum(0))
        # zero out the first end token and everything after it
        pa = ((pa == se.vocab[se.vocab.endtoken]).long().cumsum(0) < 1).long() * pa
        yb = y.bstates[0].get(i).followed_actions[0, :]
        yb = yb * (yb != se.vocab[se.vocab.endtoken]).long()
        print(pa)
        print(yb)
        self.assertTrue(torch.allclose(pa, yb))
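# --- illustrative sketch (not from the original suite) ----------------------
# The masking trick used above, in isolation: cumsum over the end-token mask
# zeroes the first end token and everything after it, so two action sequences
# can be compared up to termination. `endid` is a hypothetical end-token id.
def sketch_end_token_masking(self):
    endid = 3
    seq = torch.tensor([7, 5, 3, 9, 3, 2])
    masked = ((seq == endid).long().cumsum(0) < 1).long() * seq
    assert masked.tolist() == [7, 5, 0, 0, 0, 0]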
class BeamReranker(TransitionModel):
    def __init__(self, genmodel, scoremodel, beamsize=5, maxtime=100,
                 copy_deep=False, endid=3, use_beamcache=True, **kw):
        super(BeamReranker, self).__init__(**kw)
        self.genmodel = genmodel
        self.beammodel = BeamTransition(self.genmodel, beamsize=beamsize,
                                        maxtime=maxtime, copy_deep=copy_deep)
        self.scoremodel = scoremodel
        self.endid = endid
        self._beamcache_actions = {}
        self._beamcache_attentions = {}     # TODO: cache attentions too
        self._beamcache_complete = False
        self._use_beamcache = use_beamcache

    def finalize_beamcache(self):
        numex = max(self._beamcache_actions.keys()) + 1
        ex = self._beamcache_actions[0]
        beamsize, seqlen = ex.size(1), ex.size(2)
        beamcache = torch.zeros(numex, beamsize, seqlen,
                                dtype=torch.long, device=ex.device)
        for k in self._beamcache_actions:
            beamcache[k] = self._beamcache_actions[k]
        self._beamcache_actions = beamcache
        self._beamcache_complete = True

    def forward(self, x: State):
        _x = x
        # run genmodel in beam decoder in non-training mode
        if self._beamcache_complete and "eids" in x and self._use_beamcache:
            eids = torch.tensor(x.eids, device=self._beamcache_actions.device)
            predactions = self._beamcache_actions[eids]
        else:
            with torch.no_grad():
                self.beammodel.eval()
                x.start_decoding()
                i = 0
                all_terminated = x.all_terminated()
                while not all_terminated:
                    outprobs, predactions, x, all_terminated = self.beammodel(x, timestep=i)
                    i += 1
            if not self._beamcache_complete and "eids" in _x and self._use_beamcache:
                for j, eid in enumerate(list(_x.eids)):
                    self._beamcache_actions[eid] = predactions[j:j+1]
        # if training, add gold to beam
        if self.training:
            golds = _x.get_gold()
            # align gold dims
            pass    # TODO
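# --- illustrative usage sketch (assumption, not part of the module) ---------
# The intended cache lifecycle: run forward() once over every batch (each
# state carrying its example ids in `eids`), then freeze the per-example cache
# with finalize_beamcache() so later passes look up predactions instead of
# re-running beam search. `warm_beam_cache` and `batches` are hypothetical.
def warm_beam_cache(reranker: BeamReranker, batches):
    reranker.eval()
    with torch.no_grad():
        for batch_state in batches:    # each a State with an `eids` attribute
            reranker(batch_state)      # beam-decodes and caches per example id
    reranker.finalize_beamcache()      # stacks the dict cache into one tensor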
def test_beam_search_vs_greedy(self):
    with torch.no_grad():
        texts = ["a b"] * 10
        from parseq.vocab import SequenceEncoder
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)
        x.start_decoding()

        class Model(TransitionModel):
            # from any token, greedy prefers token 4 (p=.51); only from
            # token 5 does the model strongly prefer token 5 (p=.99)
            transition_tensor = torch.tensor([[0, 0, 0, 0, .51, .49],
                                              [0, 0, 0, 0, .51, .49],
                                              [0, 0, 0, 0, .51, .49],
                                              [0, 0, 0, 0, .51, .49],
                                              [0, 0, 0, 0, .51, .49],
                                              [0, 0, 0, 0, .01, .99]])

            def forward(self, x: BasicDecoderState):
                prev = x.prev_actions
                outprobs = self.transition_tensor[prev]
                outprobs = torch.log(outprobs)
                return outprobs, x

        model = Model()
        beamsize = 50
        maxtime = 10
        beam_xs = [x.make_copy(detach=False, deep=True) for _ in range(beamsize)]
        beam_states = BeamState(beam_xs)
        print(len(beam_xs))
        print(len(beam_states))
        bt = BeamTransition(model, beamsize, maxtime=maxtime)
        i = 0
        _, _, y, _ = bt(x, i)
        i += 1
        _, _, y, _ = bt(y, i)
        i += 1
        all_terminated = y.all_terminated()
        while not all_terminated:
            start_time = time.time()
            _, _, y, all_terminated = bt(y, i)
            i += 1
            # print(i)
            end_time = time.time()
            print(f"{i}: {end_time - start_time}")
        print(y)
        print(y.bstates.get(0).followed_actions)
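# --- illustrative sketch (not from the original suite) ----------------------
# Why beam search should beat greedy on the transition matrix above: greedy
# takes token 4 (p=.51) at every step, while a path that takes token 5 once
# (p=.49) and then stays on 5 (p=.99) overtakes it after two steps, so a wide
# enough beam recovers the higher-probability sequence.
def sketch_greedy_vs_beam_scores(self):
    greedy_two_steps = 2 * torch.log(torch.tensor(.51))                           # 4 -> 4
    beam_two_steps = torch.log(torch.tensor(.49)) + torch.log(torch.tensor(.99))  # 5 -> 5
    assert beam_two_steps > greedy_two_steps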
def test_beam_search_stored_probs(self):
    with torch.no_grad():
        texts = ["a b"] * 2
        from parseq.vocab import SequenceEncoder
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)
        x.start_decoding()

        class Model(TransitionModel):
            transition_tensor = torch.tensor([[0, 0, 0, 0.01, .51, .48],
                                              [0, 0, 0, 0.01, .51, .48],
                                              [0, 0, 0, 0.01, .51, .48],
                                              [0, 0, 0, 0.01, .51, .48],
                                              [0, 0, 0, 0.01, .51, .48],
                                              [0, 0, 0, 0.01, .01, .98]])

            def forward(self, x: BasicDecoderState):
                prev = x.prev_actions
                outprobs = self.transition_tensor[prev]
                outprobs = torch.log(outprobs)
                # small random perturbation to break ties across the beam
                outprobs -= 0.01 * torch.rand_like(outprobs)
                return outprobs, x

        model = Model()
        beamsize = 3
        maxtime = 10
        beam_xs = [x.make_copy(detach=False, deep=True) for _ in range(beamsize)]
        beam_states = BeamState(beam_xs)
        print(len(beam_xs))
        print(len(beam_states))
        bt = BeamTransition(model, beamsize, maxtime=maxtime)
        i = 0
        _, _, y, _ = bt(x, i)
        i += 1
        _, _, y, _ = bt(y, i)
        i += 1
        all_terminated = False
        while not all_terminated:
            start_time = time.time()
            _, _, y, all_terminated = bt(y, i)
            i += 1
            # print(i)
            end_time = time.time()
            print(f"{i}: {end_time - start_time}")
        print(y)
        print(y.bstates.get(0).followed_actions)
        # the beam score of each hypothesis must equal the sum of the stored
        # per-step log-probabilities of the actions it followed
        best_actions = y.bstates.get(0).followed_actions
        best_actionprobs = y.actionprobs.get(0)
        for i in range(len(best_actions)):
            print(i)
            i_prob = 0
            for j in range(len(best_actions[i])):
                action_id = best_actions[i, j]
                action_prob = best_actionprobs[i, j, action_id]
                i_prob += action_prob
            print(i_prob)
            print(y.bscores[i, 0])
            self.assertTrue(torch.allclose(i_prob, y.bscores[i, 0]))
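# --- illustrative sketch (not from the original suite) ----------------------
# The per-step loop above, vectorized: gathering each followed action's stored
# log-prob over the vocabulary axis and summing over time reproduces a total
# sequence score in one expression. All tensors here are synthetic.
def sketch_vectorized_score_check(self):
    actionprobs = torch.log_softmax(torch.randn(2, 4, 6), -1)    # (example, time, vocab)
    actions = torch.randint(0, 6, (2, 4))                        # (example, time)
    scores = actionprobs.gather(2, actions.unsqueeze(-1)).squeeze(-1).sum(-1)
    assert scores.shape == (2,)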