Example 1
    def test_gather_states(self):
        # five candidate states; after the in-place assignments below, state k
        # carries substate data [3k, 3k+1, 3k+2]
        x = [
            State(data=torch.rand(3, 4),
                  substate=State(data=np.asarray([0, 1, 2], dtype="int64")))
            for _ in range(5)
        ]
        x[1].substate.data[:] = [3, 4, 5]
        x[2].substate.data[:] = [6, 7, 8]
        x[3].substate.data[:] = [9, 10, 11]
        x[4].substate.data[:] = [12, 13, 14]
        print(len(x))
        for i in range(5):
            print(x[i].substate.data)

        # indexes[i, j] selects which of the 5 input states supplies row i of
        # output state j
        indexes = torch.tensor([[0, 0, 1, 0], [1, 1, 1, 0], [2, 4, 2, 0]])
        bt = BeamTransition(None)
        y = bt.gather_states(x, indexes)
        print(indexes)
        print(y)

        for ye in y:
            print(ye.substate.data)

        yemat = torch.tensor([ye.substate.data for ye in y]).T
        print(yemat)

        # reference computation: a[i, k] == x[k].substate.data[i], so gathering
        # along dim 1 reproduces what gather_states should produce
        a = torch.arange(0, 15).reshape(5, 3).T
        b = a.gather(1, indexes)
        print(a)
        print(b)

        self.assertTrue(torch.allclose(b, yemat))
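
The assertion above pins down the contract: gather_states takes a list of candidate states and a (batch_size, out_width) index tensor, and fills row i of output state j from input state indexes[i, j]. A minimal sketch of the same operation over plain tensors (the gather_rows helper below is hypothetical, not part of the library):

import torch

def gather_rows(candidates, indexes):
    # candidates: list of K tensors, each of shape (batch_size,)
    # indexes: (batch_size, out_width) long tensor; indexes[i, j] names the
    # candidate whose row i feeds output slot j
    stacked = torch.stack(candidates, dim=1)   # (batch_size, K)
    return stacked.gather(1, indexes)          # (batch_size, out_width)

This is exactly the a.gather(1, indexes) reference computation the test compares against.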
Example 2
    def test_beam_transition(self):
        texts = [
            "i went to chocolate @END@", "awesome is @END@",
            "the meaning of life @END@"
        ]
        from parseq.vocab import SequenceEncoder
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)
        x.start_decoding()

        class Model(TransitionModel):
            # dummy model: emits random log-probabilities over the output vocab
            def forward(self, x: BasicDecoderState):
                outprobs = torch.randn(len(x),
                                       x.query_encoder.vocab.number_of_ids())
                outprobs = torch.nn.functional.log_softmax(outprobs, -1)
                return outprobs, x

        model = Model()

        beamsize = 50
        maxtime = 10
        beam_xs = [
            x.make_copy(detach=False, deep=True) for _ in range(beamsize)
        ]
        beam_states = BeamState(beam_xs)

        print(len(beam_xs))
        print(len(beam_states))

        bt = BeamTransition(model, beamsize, maxtime=maxtime)
        i = 0
        _, _, y, _ = bt(x, i)
        i += 1
        _, _, y, _ = bt(y, i)
        i += 1

        all_terminated = y.all_terminated()
        while not all_terminated:
            _, predactions, y, all_terminated = bt(y, i)
            i += 1

        print("timesteps done:")
        print(i)
        print(y)
        print(predactions[0])
        for i in range(beamsize):
            print("-")
            # print(y.bstates[0].get(i).followed_actions)
            # print(predactions[0, i, :])
            pa = predactions[0, i, :]
            # print((pa == se.vocab[se.vocab.endtoken]).cumsum(0))
            # blank everything from the first @END@ onward
            pa = ((pa == se.vocab[se.vocab.endtoken]).long().cumsum(0) <
                  1).long() * pa
            yb = y.bstates[0].get(i).followed_actions[0, :]
            # blank the @END@ tokens themselves so both sides zero-pad alike
            yb = yb * (yb != se.vocab[se.vocab.endtoken]).long()
            print(pa)
            print(yb)
            self.assertTrue(torch.allclose(pa, yb))
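
The cumsum trick used above is a compact way to blank everything from the first end token onward. A tiny standalone illustration (token id 3 standing in for @END@ is an assumption, the values are made up):

import torch

pa = torch.tensor([7, 9, 3, 8, 3])        # pretend 3 is the @END@ id
keep = (pa == 3).long().cumsum(0) < 1     # True only strictly before the first 3
print(keep.long() * pa)                   # tensor([7, 9, 0, 0, 0])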
Example 3
class BeamReranker(TransitionModel):
    def __init__(self, genmodel, scoremodel, beamsize=5, maxtime=100,
                 copy_deep=False, endid=3, use_beamcache=True, **kw):
        super(BeamReranker, self).__init__(**kw)
        self.genmodel = genmodel
        self.beammodel = BeamTransition(self.genmodel, beamsize=beamsize, maxtime=maxtime, copy_deep=copy_deep)
        self.scoremodel = scoremodel
        self.endid = endid
        self._beamcache_actions = {}
        self._beamcache_attentions = {}         # TODO: cache attentions too
        self._beamcache_complete = False
        self._use_beamcache = use_beamcache

    def finalize_beamcache(self):
        numex = max(self._beamcache_actions.keys()) + 1
        ex = self._beamcache_actions[0]
        beamsize, seqlen = ex.size(1), ex.size(2)
        beamcache = torch.zeros(numex, beamsize, seqlen, dtype=torch.long, device=ex.device)
        for k in self._beamcache_actions:
            beamcache[k] = self._beamcache_actions[k]
        self._beamcache_actions = beamcache
        self._beamcache_complete = True

    def forward(self, x: State):
        _x = x
        # run genmodel in beam decoder in non-training mode
        if self._beamcache_complete and "eids" in x and self._use_beamcache:
            eids = torch.tensor(x.eids, device=self._beamcache_actions.device)
            predactions = self._beamcache_actions[eids]
        else:
            with torch.no_grad():
                self.beammodel.eval()
                x.start_decoding()
                i = 0
                all_terminated = x.all_terminated()
                while not all_terminated:
                    outprobs, predactions, x, all_terminated = self.beammodel(x, timestep=i)
                    i += 1
                if not self._beamcache_complete and "eids" in _x and self._use_beamcache:
                    for j, eid in enumerate(list(_x.eids)):
                        self._beamcache_actions[eid] = predactions[j:j+1]

        # if training, add gold to beam
        if self.training:
            golds = _x.get_gold()
            # align gold dims
            pass    # TODO
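
A rough usage sketch of how the cache is meant to round-trip; the dataloader loop and batch states are hypothetical, only BeamReranker and finalize_beamcache come from the code above:

reranker = BeamReranker(genmodel, scoremodel, beamsize=5, maxtime=100)
# first pass: forward() falls through to the beam decoder and fills
# _beamcache_actions[eid] for every example id it sees
for batch in dataloader:
    reranker(batch)
# pack the per-example dict into one (numex, beamsize, seqlen) tensor;
# later calls that carry "eids" then just index into it
reranker.finalize_beamcache()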
Example 4
    def test_beam_search_vs_greedy(self):
        with torch.no_grad():
            texts = ["a b"] * 10
            from parseq.vocab import SequenceEncoder
            se = SequenceEncoder(tokenizer=lambda x: x.split())
            for t in texts:
                se.inc_build_vocab(t)
            se.finalize_vocab()
            x = BasicDecoderState(texts, texts, se, se)
            x.start_decoding()

            class Model(TransitionModel):
                # from any of states 0-4, token 4 has p=.51 and token 5 p=.49;
                # from state 5, staying on 5 has p=.99, so the per-step greedy
                # choice (always 4) is globally suboptimal
                transition_tensor = torch.tensor([[0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .01, .99]])

                def forward(self, x: BasicDecoderState):
                    prev = x.prev_actions
                    outprobs = self.transition_tensor[prev]
                    outprobs = torch.log(outprobs)
                    return outprobs, x

            model = Model()

            beamsize = 50
            maxtime = 10
            beam_xs = [
                x.make_copy(detach=False, deep=True) for _ in range(beamsize)
            ]
            beam_states = BeamState(beam_xs)

            print(len(beam_xs))
            print(len(beam_states))

            bt = BeamTransition(model, beamsize, maxtime=maxtime)
            i = 0
            _, _, y, _ = bt(x, i)
            i += 1
            _, _, y, _ = bt(y, i)
            i += 1

            all_terminated = y.all_terminated()
            while not all_terminated:
                start_time = time.time()
                _, _, y, all_terminated = bt(y, i)
                i += 1
                # print(i)
                end_time = time.time()
                print(f"{i}: {end_time - start_time}")

            print(y)
            print(y.bstates.get(0).followed_actions)
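
Why beam search should beat greedy here: greedy always takes token 4 (p=.51 per step), while the path that switches to token 5 and stays there overtakes it from the second step on. A quick check of the arithmetic implied by the transition table:

import math

greedy_2step = 2 * math.log(0.51)               # 4 -> 4: log prob ~ -1.35
beam_2step = math.log(0.49) + math.log(0.99)    # -> 5, then 5 -> 5: ~ -0.72
# the 5...5 path scores higher, so a wide enough beam prefers it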
Example 5
    def test_beam_search_stored_probs(self):
        with torch.no_grad():
            texts = ["a b"] * 2
            from parseq.vocab import SequenceEncoder
            se = SequenceEncoder(tokenizer=lambda x: x.split())
            for t in texts:
                se.inc_build_vocab(t)
            se.finalize_vocab()
            x = BasicDecoderState(texts, texts, se, se)
            x.start_decoding()

            class Model(TransitionModel):
                transition_tensor = torch.tensor([[0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .01, .98]])

                def forward(self, x: BasicDecoderState):
                    prev = x.prev_actions
                    outprobs = self.transition_tensor[prev]
                    outprobs = torch.log(outprobs)
                    # small random perturbation makes beam scores distinct (tie-breaking)
                    outprobs -= 0.01 * torch.rand_like(outprobs)
                    return outprobs, x

            model = Model()

            beamsize = 3
            maxtime = 10
            beam_xs = [
                x.make_copy(detach=False, deep=True) for _ in range(beamsize)
            ]
            beam_states = BeamState(beam_xs)

            print(len(beam_xs))
            print(len(beam_states))

            bt = BeamTransition(model, beamsize, maxtime=maxtime)
            i = 0
            _, _, y, _ = bt(x, i)
            i += 1
            _, _, y, _ = bt(y, i)
            i += 1

            all_terminated = y.all_terminated()
            while not all_terminated:
                start_time = time.time()
                _, _, y, all_terminated = bt(y, i)
                i += 1
                # print(i)
                end_time = time.time()
                print(f"{i}: {end_time - start_time}")

            print(y)
            print(y.bstates.get(0).followed_actions)

            best_actions = y.bstates.get(0).followed_actions
            best_actionprobs = y.actionprobs.get(0)
            # the summed per-step log-probs along a beam should equal its stored score
            for i in range(len(best_actions)):
                print(i)
                i_prob = 0
                for j in range(len(best_actions[i])):
                    action_id = best_actions[i, j]
                    action_prob = best_actionprobs[i, j, action_id]
                    i_prob += action_prob
                print(i_prob)
                print(y.bscores[i, 0])
                self.assertTrue(torch.allclose(i_prob, y.bscores[i, 0]))
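
The per-example loop can be collapsed into a single gather. An equivalent vectorized form, assuming the same shapes as in the test ((num_examples, seqlen) actions against (num_examples, seqlen, vocab) log-probs):

seq_scores = best_actionprobs.gather(
    -1, best_actions.unsqueeze(-1)).squeeze(-1).sum(-1)
# seq_scores[i] should match y.bscores[i, 0] up to float tolerance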