Beispiel #1
0
    def eval_step(self, batch, decoding_strategy='score', dump=False):
        """Run one evaluation step on ``batch`` with the requested decoding strategy.

        The model is put into evaluation mode (``self.eval_mode()``) and the
        source side is encoded once; what happens next depends on
        ``decoding_strategy``:

        * ``'score'``  -- call ``self.compute_loss`` against the batch labels
          and return ``None``.
        * ``'greedy'`` -- run ``self.decode_greedy`` and return one
          ``(prediction, score)`` pair per batch element, where ``score`` is
          the sum of the per-step scores divided by
          ``Beam.get_length_penalty(prediction_length)``.
        * any string containing ``'beam'`` -- run ``self.decode_beam`` with the
          beam size parsed from the text after the last ``':'``
          (e.g. ``'beam:5'``) and return its result unchanged.

        Args:
            batch: object exposing ``text_vecs``, ``label_vecs``,
                ``use_packed`` and ``text_lens``.
            decoding_strategy: one of the strategies described above.
            dump: greedy mode only -- if truthy, also return a one-element
                list holding the attention log produced by
                ``self.decode_greedy``.

        Returns:
            ``None`` for ``'score'``; for ``'greedy'`` a tuple of
            ``(prediction, score)`` pairs, or ``(pairs, [attn_w_log])`` when
            ``dump`` is truthy; the result of ``self.decode_beam`` for beam
            strategies.

        Raises:
            ValueError: if ``'score'`` is requested but ``batch.label_vecs``
                is ``None``, or if ``decoding_strategy`` is not recognised.
        """
        xs, ys, use_packed = batch.text_vecs, batch.label_vecs, batch.use_packed
        xs_lens = batch.text_lens
        # Single source of truth for the batch size, shared by both decoders.
        batch_size = xs.size(0)

        self.eval_mode()
        encoder_states = self.encoder(xs, xs_lens, use_packed=use_packed)

        if decoding_strategy == 'score':
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            if ys is None:
                raise ValueError(
                    "decoding_strategy='score' requires batch.label_vecs")
            # Called for its side effects only; the loss value is not returned.
            self.compute_loss(encoder_states, xs_lens, ys)
            return None

        if decoding_strategy == 'greedy':
            step_scores, step_preds, attn_w_log = self.decode_greedy(
                encoder_states, batch_size)
            # decode_greedy returns per-step sequences; stacking on dim 1 puts
            # the step axis second, one row per batch element.
            preds = torch.stack(step_preds, dim=1)
            scores = torch.stack(step_scores, dim=1)
            # Prediction length = number of steps with a negative score.
            # NOTE(review): presumably steps after end-of-sequence are scored
            # 0 by decode_greedy -- confirm.
            pred_lengths = (scores < 0).sum(dim=1)
            length_penalties = torch.tensor(
                [Beam.get_length_penalty(n) for n in pred_lengths.tolist()],
                dtype=torch.get_default_dtype(),
                device=scores.device)
            scores_length_penalized = scores.sum(dim=1) / length_penalties
            pred_scores = tuple(zip(preds, scores_length_penalized))
            if dump:
                return pred_scores, [attn_w_log]
            return pred_scores

        if 'beam' in decoding_strategy:
            beam_size = int(decoding_strategy.split(':')[-1])
            return self.decode_beam(beam_size, batch_size, encoder_states)

        # Previously an unknown strategy silently returned None, which is
        # indistinguishable from the 'score' path; fail loudly instead.
        raise ValueError(
            'Unknown decoding_strategy: {!r}'.format(decoding_strategy))