Example #1
def beam_search_pick(prime, width):
    """Returns the beam search pick."""
    # `vocab`, `sess`, `self`, `num`, and `beam_search_predict` come from the
    # enclosing scope: this is a nested helper inside a model method.
    if not prime.strip():
        prime = random.choice(list(vocab.keys()))  # fall back to a random prime word
    prime_labels = [vocab.get(word, 0) for word in prime.split()]
    bs = BeamSearch(beam_search_predict,
                    sess.run(self.cell.zero_state(1, tf.float32)),
                    prime_labels)
    samples, scores = bs.search(None, None, k=width, maxsample=num)
    # lower score = higher likelihood, so argmin picks the best beam
    return samples[np.argmin(scores)]
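The `beam_search_predict` callback passed above is not shown in this example. A minimal sketch of the shape it is expected to have, assuming a char-rnn style model whose `input_data`, `probs`, `final_state`, and `initial_state` tensors come from the same enclosing scope (all names here are illustrative, not taken from this snippet):

def beam_search_predict(sample, state):
    """Feed the last sampled label through the RNN for one step and
    return the next-token probabilities with the updated state."""
    x = np.zeros((1, 1))
    x[0, 0] = sample[-1]  # only the newest label is fed; `state` carries the history
    feed = {self.input_data: x, self.initial_state: state}
    probs, final_state = sess.run([self.probs, self.final_state], feed)
    return probs[0], final_state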
Example #2
def beam_search_pick(state, prime, rhyme_word, width):
    """Returns the beam search pick and its final state."""
    prime_labels = [vocab.get(word, 0) for word in prime.split()]
    if state is None:
        state = sess.run(self.cell.zero_state(1, tf.float32))
    bs = BeamSearch(words, beam_search_predict,
                    state,
                    prime_labels)
    samples, scores, states = bs.search(None, None, rhyme_word, k=width, maxsample=num)
    best = np.argmin(scores)  # lowest score = most likely beam
    return samples[best], states[best]
Example #3
    def __init__(self, loader, beam_search=False, beam_width=4, batch_size=64):
        self.batch_size = batch_size
        self.loader = loader

        # Cache the target-side vocabulary maps and the <s> (start-of-sentence) id
        self.word2idx = loader.corpus.trg_params['word2idx']
        self.idx2word = loader.corpus.trg_params['idx2word']
        self.sos = self.word2idx['<s>']

        self.beam_search = None
        if beam_search:
            self.beam_search = BeamSearch(self.word2idx, beam_width=beam_width)
Example #4
    def test_multiple_beams(self):
        bs = BeamSearch(naive_predict, self.initial_state, self.prime_labels)
        samples, scores = bs.search(None, None, k=4, maxsample=5)

        self.assertIn([0, 1, 4, 4, 4], samples)

        # All permutations of this form must be in the results.
        self.assertIn([0, 1, 4, 4, 3], samples)
        self.assertIn([0, 1, 4, 3, 4], samples)
        self.assertIn([0, 1, 3, 4, 4], samples)

        # Make sure that the best beam has the lowest score.
        self.assertEqual(samples[np.argmin(scores)], [0, 1, 4, 4, 4])
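The `naive_predict` fixture is not shown here. A hypothetical stateless predictor consistent with the assertions above, assuming the same `predict(sample, state) -> (probs, state)` contract as in Example #1: over a 5-token vocabulary, token 4 is always the most likely continuation and token 3 the runner-up, so the best length-5 beam primed with [0, 1] is [0, 1, 4, 4, 4].

import numpy as np

def naive_predict(sample, state):
    # Fixed distribution: token 4 most likely, token 3 second best.
    probs = np.array([0.05, 0.05, 0.10, 0.30, 0.50])
    return probs, state  # stateless: the state passes through unchanged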
Example #5
        def beam_search_pick(prime_labels,
                             width,
                             initial_state,
                             tokens=False,
                             attention_key_words=None,
                             keywords_count=None):
            """Returns the beam search pick."""

            bs = BeamSearch(beam_search_predict, initial_state, prime_labels,
                            attention_key_words, keywords_count)
            eos = vocab.get('</s>', 0) if tokens else None
            oov = vocab.get('<unk>', None)
            samples, scores = bs.search(oov, eos, k=width, maxsample=num)
            # return the lowest-scoring (most likely) sequence
            return samples[np.argmin(scores)]
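An illustrative call, with all concrete values hypothetical (`vocab`, `sess`, `self.cell`, and `num` are assumed from the surrounding model, as in the earlier examples):

initial_state = sess.run(self.cell.zero_state(1, tf.float32))
prime_labels = [vocab.get(w, 0) for w in '<s> the'.split()]
best = beam_search_pick(prime_labels, width=5,
                        initial_state=initial_state,
                        tokens=True)  # tokens=True makes beams stop at </s>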
Example #6
def beam_search_pick(weights):
    probs[0] = weights
    samples, scores = BeamSearch(probs).beamsearch(
        None, vocab.get(prime), None, 2, len(weights), False)
    # Weighted random choice over the best-scoring sample: cumulative sum
    # plus searchsorted picks an index in proportion to its weight.
    sampleweights = samples[np.argmax(scores)]
    t = np.cumsum(sampleweights)
    s = np.sum(sampleweights)
    return int(np.searchsorted(t, np.random.rand(1) * s))
Example #7
def create_bs(prime):
    """Builds a primed BeamSearch and returns it with the prime labels."""
    if not prime.strip():
        prime = random.choice(list(vocab.keys()))  # fall back to a random prime word
    prime_labels = [vocab.get(word, 0) for word in prime.split()]  # tokenize the prime
    bs = BeamSearch(beam_search_predict,
                    sess.run(self.cell.zero_state(1, tf.float32)),  # start from the zero state
                    prime_labels)  # seed the beams with the prime labels
    return bs, prime_labels
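Unlike the earlier variants, this helper returns the configured BeamSearch together with the labels instead of running the search itself. A hypothetical follow-up (`num` assumed from the enclosing scope, as before):

bs, prime_labels = create_bs('the quick')
samples, scores = bs.search(None, None, k=4, maxsample=num)
best = samples[np.argmin(scores)]  # lowest score = most likely sequence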
Example #8
class MetricEvaluator(object):
    def __init__(self, loader, beam_search=False, beam_width=4, batch_size=64):
        self.batch_size = batch_size
        self.loader = loader

        # Cache the target-side vocabulary maps and the <s> (start-of-sentence) id
        self.word2idx = loader.corpus.trg_params['word2idx']
        self.idx2word = loader.corpus.trg_params['idx2word']
        self.sos = self.word2idx['<s>']

        self.beam_search = None
        if beam_search:
            self.beam_search = BeamSearch(self.word2idx, beam_width=beam_width)

    def compute_scores(self, model, split, compute_ppl=False):
        itr = self.loader.create_epoch_iterator(split, self.batch_size)
        model.eval()

        refs = []
        hyps = []
        costs = []
        for i, (src, src_lengths, trg) in tqdm(enumerate(itr)):
            if compute_ppl:
                loss = model.score(src, src_lengths, trg)
                costs.append(loss.data[0])  # legacy pre-0.4 PyTorch; use loss.item() on newer versions

            if self.beam_search is None:
                out = model.inference(src, src_lengths, sos=self.sos)
                out = out.cpu().data.tolist()
            else:
                src = model.encoder(src, src_lengths)
                out = self.beam_search.search(model.decoder, src)

            trg = trg.cpu().data.tolist()

            for ref, hyp in zip(trg, out):
                refs.append(self.loader.corpus.idx2sent(self.idx2word, ref))
                hyps.append(self.loader.corpus.idx2sent(self.idx2word, hyp))

        score = compute_bleu(refs, hyps)
        return score, costs
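A hypothetical usage of the evaluator above; `loader` and `model` are assumptions standing in for a corpus loader and a trained seq2seq model:

evaluator = MetricEvaluator(loader, beam_search=True, beam_width=4)
bleu, costs = evaluator.compute_scores(model, 'valid', compute_ppl=True)
print('BLEU: %.2f, mean loss: %.4f' % (bleu, sum(costs) / len(costs)))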
Example #9
    def _test(self, batch_size):

        self.load_model()
        assert isinstance(self._model, torch.nn.Module), (
            "Before calling _test, you must supply a PyTorch model using the"
            " `Trainer._set_model` method")
        test_iterator = self._datamaker.get_iterator("test",
                                                     batch_size,
                                                     device=self._device)

        generations = []
        targets = []
        sources = []
        words = []

        ppl = 0
        kld = 0
        for i, batch in enumerate(
                tqdm.tqdm(test_iterator,
                          desc=f"Testing (Epoch {self._validation_steps}): ")):
            try:
                self._model.zero_grad()
                self._model.eval()

                example, example_lens = batch.example
                definition, definition_lens = batch.definition
                word, word_lens = batch.word
                if self._model.variational or self._model.defbert:
                    definition_ae, definition_ae_lens = batch.definition_ae
                else:
                    definition_ae, definition_ae_lens = None, None

                sentence_mask = bert_dual_sequence_mask(
                    example,
                    self._datamaker.vocab.example.encode("</s>")[1:-1])
                current_batch_size = word.shape[0]

                decode_strategy = BeamSearch(
                    self._beam_size,
                    current_batch_size,
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    n_best=1 if self._n_best is None else self._n_best,
                    global_scorer=self._model.global_scorer,
                    min_length=self._min_length,
                    max_length=self._max_length,
                    return_attention=False,
                    block_ngram_repeat=3,
                    exclusion_tokens=self._exclusion_idxs,
                    stepwise_penalty=None,
                    ratio=self._ratio if self._ratio is not None else 0,
                )
                with torch.no_grad():
                    model_out = self._forward(
                        "test",
                        input=example,
                        seq_lens=example_lens,
                        span_token_ids=word,
                        target=definition,
                        target_lens=definition_lens,
                        decode_strategy=decode_strategy,
                        definition=definition_ae,
                        definition_lens=definition_ae_lens,
                        sentence_mask=sentence_mask,
                    )
                torch.cuda.empty_cache()

                generations.extend([
                    self._datamaker.decode(gen[0], "definition", batch=False)
                    for gen in model_out.predictions
                ])
                targets.extend(
                    self._datamaker.decode(definition,
                                           "definition",
                                           batch=True))
                sources.extend(
                    self._datamaker.decode(example, "example", batch=True))
                words.extend(self._datamaker.decode(word, "word", batch=True))

                ppl += model_out.perplexity.item()
                if model_out.kl is not None:
                    kld += model_out.kl.item()
                    self._TB_validation_log.add_scalar(
                        "kl", model_out.kl.item(), self._validation_counter)

                current_bleu = batch_bleu(
                    targets[-current_batch_size:],
                    generations[-current_batch_size:],
                    reduction="average",
                )
            except RuntimeError as e:
                # catch out of memory exceptions during fwd/bck (skip batch)
                if "out of memory" in str(e):
                    logging.warning(
                        "| WARNING: ran out of memory, skipping batch. "
                        "if this happens frequently, decrease batch_size or "
                        "truncate the inputs to the model.")
                    torch.cuda.empty_cache()
                    continue
                else:
                    raise e

        torch.cuda.empty_cache()

        bleu = batch_bleu(targets, generations, reduction="average")

        ppl = ppl / len(test_iterator)
        if self._model.variational:
            kld = kld / len(test_iterator)
            kld_best, kld_patience = self._update_metric_history(
                self._validation_steps,
                "KL Divergence",
                kld,
                self._test_metric_infos,
                metric_decreases=False,
            )

        metric_dict = {"bleu": bleu, "perplexity": ppl, "kl": kld}

        try:
            P, R, F1 = bert_score(generations, targets)
        except Exception:
            # fall back to zero scores if BERTScore fails
            P, R, F1 = (torch.Tensor([0]), torch.Tensor([0]), torch.Tensor([0]))

        bleu_best, bleu_patience = self._update_metric_history(
            self._validation_steps,
            "bleu",
            bleu,
            self._test_metric_infos,
            metric_decreases=False,
        )

        ppl_best, ppl_patience = self._update_metric_history(
            self._validation_steps,
            "perplexity",
            ppl,
            self._test_metric_infos,
            metric_decreases=True,
        )

        bert_score_p_best, bert_score_p_patience = self._update_metric_history(
            self._validation_steps,
            "bert_score_p",
            P.mean().item(),
            self._test_metric_infos,
            metric_decreases=False,
        )

        bert_score_r_best, bert_score_r_patience = self._update_metric_history(
            self._validation_steps,
            "bert_score_r",
            R.mean().item(),
            self._test_metric_infos,
            metric_decreases=False,
        )
        bert_score_f1_best, bert_score_f1_patience = self._update_metric_history(
            self._validation_steps,
            "bert_score_f1",
            F1.mean().item(),
            self._test_metric_infos,
            metric_decreases=False,
        )
        # kld_best, kld_patience = self._update_metric_history(
        #     self._epoch_steps,
        #     "KL Divergence",
        #     kld,
        #     self._test_metric_infos,
        #     metric_decreases=False,
        # )

        self._test_write_metric_info()

        with open(
                os.path.join(self._validation_log_dir,
                             f"test_iter_{self._epoch_steps}.json"),
                "w",
        ) as f:
            f.write("\n".join([
                json.dumps({
                    "src": sources[i],
                    "tgt": targets[i],
                    "gen": generations[i],
                    "word": words[i],
                }) for i in range(len(generations))
            ]))

        return DotMap({"src": sources, "tgt": targets, "gen": generations})
Example #10
    def _validate(self, batch_size):

        assert isinstance(self._model, torch.nn.Module), (
            "Before calling _validate, you must supply a PyTorch model using the"
            " `Trainer._set_model` method")
        valid_iterator = self._datamaker.get_iterator("valid",
                                                      batch_size,
                                                      device=self._device)

        generations = []
        targets = []
        sources = []
        words = []

        self._validation_steps += 1
        ppl = 0
        kld = 0
        for i, batch in enumerate(
                tqdm.tqdm(
                    valid_iterator,
                    desc=f"Validating (Epoch {self._validation_steps}): ")):
            try:
                self._validation_counter += 1
                self._model.zero_grad()
                self._model.eval()

                example, example_lens = batch.example
                definition, definition_lens = batch.definition
                word, word_lens = batch.word
                if self._model.variational:
                    definition_ae, definition_ae_lens = batch.definition_ae
                else:
                    definition_ae, definition_ae_lens = None, None

                sentence_mask = bert_dual_sequence_mask(
                    example,
                    self._datamaker.vocab.example.encode("</s>")[1:-1])
                current_batch_size = word.shape[0]

                decode_strategy = BeamSearch(
                    self._beam_size,
                    current_batch_size,
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    n_best=1 if self._n_best is None else self._n_best,
                    global_scorer=self._model.global_scorer,
                    min_length=self._min_length,
                    max_length=self._max_length,
                    return_attention=False,
                    block_ngram_repeat=3,
                    exclusion_tokens=self._exclusion_idxs,
                    stepwise_penalty=None,
                    ratio=self._ratio if self._ratio is not None else 0,
                )
                with torch.no_grad():
                    model_out = self._forward(
                        "valid",
                        input=example,
                        seq_lens=example_lens,
                        span_token_ids=word,
                        target=definition,
                        target_lens=definition_lens,
                        decode_strategy=decode_strategy,
                        definition=definition_ae,
                        definition_lens=definition_ae_lens,
                        sentence_mask=sentence_mask,
                    )
                torch.cuda.empty_cache()

                generations.extend([
                    self._datamaker.decode(gen[0], "definition", batch=False)
                    for gen in model_out.predictions
                ])
                targets.extend(
                    self._datamaker.decode(definition,
                                           "definition",
                                           batch=True))
                sources.extend(
                    self._datamaker.decode(example, "example", batch=True))
                words.extend(self._datamaker.decode(word, "word", batch=True))

                if torch.isnan(model_out.perplexity):
                    tqdm.tqdm.write(
                        "NaN perplexity encountered; skipping batch. If this"
                        " happens often, investigate the model outputs.")
                    continue
                ppl += model_out.perplexity.item()

                self._TB_validation_log.add_scalar(
                    "batch_perplexity",
                    model_out.perplexity.item(),
                    self._validation_counter,
                )

                current_bleu = batch_bleu(
                    targets[-current_batch_size:],
                    generations[-current_batch_size:],
                    reduction="average",
                )
                self._TB_validation_log.add_scalar("batch_BLEU", current_bleu,
                                                   self._validation_counter)

                if model_out.kl is not None:
                    kld += model_out.kl.item()
                    self._TB_validation_log.add_scalar(
                        "kl", model_out.kl.item(), self._validation_counter)
                if self._val_data_limit:
                    if i * batch_size > self._val_data_limit:
                        break
            except RuntimeError as e:
                # catch out of memory exceptions during fwd/bck (skip batch)
                if "out of memory" in str(e):
                    logging.warning(
                        "| WARNING: ran out of memory, skipping batch. "
                        "if this happens frequently, decrease batch_size or "
                        "truncate the inputs to the model.")
                    continue
                else:
                    raise e

        torch.cuda.empty_cache()

        bleu = batch_bleu(targets, generations, reduction="average")
        self._TB_validation_log.add_scalar("BLEU", bleu,
                                           self._validation_steps)
        try:
            P, R, F1 = bert_score(generations, targets)
        except Exception:
            # fall back to zero scores if BERTScore fails
            P, R, F1 = (torch.Tensor([0]), torch.Tensor([0]), torch.Tensor([0]))

        self._TB_validation_log.add_scalar("bert-score-p",
                                           P.mean().item(),
                                           self._validation_counter)
        self._TB_validation_log.add_scalar("bert-score-r",
                                           R.mean().item(),
                                           self._validation_counter)
        self._TB_validation_log.add_scalar("bert-score-f1",
                                           F1.mean().item(),
                                           self._validation_counter)
        ppl = ppl / len(valid_iterator)
        # kld = kld / len(valid_iterator)
        self._TB_validation_log.add_scalar("kl", kld, self._validation_counter)

        # Had to do this for memory issues

        self._TB_validation_log.add_scalar("Perplexity", ppl,
                                           self._validation_steps)

        metric_dict = {"bleu": bleu, "perplexity": ppl, "kl": kld}

        bleu_best, bleu_patience = self._update_metric_history(
            self._validation_steps,
            "bleu",
            bleu,
            self._metric_infos,
            metric_decreases=False,
        )

        ppl_best, ppl_patience = self._update_metric_history(
            self._validation_steps,
            "perplexity",
            ppl,
            self._metric_infos,
            metric_decreases=True,
        )

        bert_score_p_best, bert_score_p_patience = self._update_metric_history(
            self._validation_steps,
            "bert_score_p",
            P.mean().item(),
            self._metric_infos,
            metric_decreases=False,
        )

        bert_score_r_best, bert_score_r_patience = self._update_metric_history(
            self._validation_steps,
            "bert_score_r",
            R.mean().item(),
            self._metric_infos,
            metric_decreases=False,
        )
        bert_score_f1_best, bert_score_f1_patience = self._update_metric_history(
            self._validation_steps,
            "bert_score_f1",
            F1.mean().item(),
            self._metric_infos,
            metric_decreases=False,
        )
        # kld_best, kld_patience = self._update_metric_history(
        #    self._epoch_steps,
        #    "KL Divergence",
        #    kld,
        #    self._metric_infos,
        #    metric_decreases=True,
        # )

        if self._keep_all_checkpoints:
            torch.save(
                self._model.state_dict(),
                os.path.join(
                    self._serialization_dir,
                    "model",
                    f"iter_{self._validation_steps}.pth",
                ),
            )
            torch.save(
                self._optimizer.state_dict(),
                os.path.join(
                    self._serialization_dir,
                    "optimizer",
                    f"iter_{self._validation_steps}.pth",
                ),
            )
        if bleu_best:
            torch.save(
                self._model.state_dict(),
                os.path.join(
                    self._serialization_dir,
                    "model",
                    f"bleu_best.pth",
                ),
            )
            torch.save(
                self._optimizer.state_dict(),
                os.path.join(
                    self._serialization_dir,
                    "optimizer",
                    f"bleu_best.pth",
                ),
            )
            self._bad_epochs = 0

        self._write_metric_info()

        if ppl_best:
            torch.save(
                self._model.state_dict(),
                os.path.join(
                    self._serialization_dir,
                    "model",
                    f"ppl_best.pth",
                ),
            )
            torch.save(
                self._optimizer.state_dict(),
                os.path.join(
                    self._serialization_dir,
                    "optimizer",
                    f"ppl_best.pth",
                ),
            )
            self._bad_epochs = 0
        if bert_score_f1_best:
            torch.save(
                self._model.state_dict(),
                os.path.join(
                    self._serialization_dir,
                    "model",
                    f"bert_score_best.pth",
                ),
            )
            torch.save(
                self._optimizer.state_dict(),
                os.path.join(
                    self._serialization_dir,
                    "optimizer",
                    f"bert_score_best.pth",
                ),
            )
            self._bad_epochs = 0
        if not bleu_best and not ppl_best and not bert_score_f1_best:
            self._bad_epochs += 1
        if bleu_patience and ppl_patience and bert_score_f1_patience:
            logging.info(
                "Ran out of patience for BLEU, BERTScore, and perplexity."
                " Stopping training")

            self._patience_exceeded = True
        with open(
                os.path.join(self._validation_log_dir,
                             f"iter_{self._epoch_steps}.json"),
                "w",
        ) as f:
            f.write("\n".join([
                json.dumps({
                    "src": sources[i],
                    "tgt": targets[i],
                    "gen": generations[i],
                    "word": words[i],
                }) for i in range(len(generations))
            ]))

        if self._bad_epochs == 3:
            tqdm.tqdm.write(
                "3 bad epochs in a row, reverting optimizer and model to the"
                " last bert_score_best.pth")
            self._model.load_state_dict(
                torch.load(
                    os.path.join(self._serialization_dir, "model",
                                 "bert_score_best.pth")))
            self._optimizer.load_state_dict(
                torch.load(
                    os.path.join(self._serialization_dir, "optimizer",
                                 "bert_score_best.pth")))
            for param_group in self._optimizer.param_groups:
                param_group["lr"] *= 0.5

        return DotMap({"src": sources, "tgt": targets, "gen": generations})
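Both `_test` and `_validate` wrap each batch in the same CUDA out-of-memory recovery pattern. The pattern in isolation, as a minimal sketch (`step` is a hypothetical per-batch function):

import logging
import torch

def run_batches(batches, step):
    for batch in batches:
        try:
            step(batch)
        except RuntimeError as e:
            # Skip, rather than crash on, batches that exhaust GPU memory;
            # frequent hits suggest a smaller batch_size or truncated inputs.
            if 'out of memory' not in str(e):
                raise
            logging.warning('CUDA OOM, skipping batch')
            torch.cuda.empty_cache()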
Example #11
def test_single_beam(self):
    bs = BeamSearch(naive_predict, self.initial_state, self.prime_labels)
    samples, scores = bs.search(None, None, k=1, maxsample=5)
    self.assertEqual(samples, [[0, 1, 4, 4, 4]])
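The fixtures `self.initial_state` and `self.prime_labels` are not shown in these tests. Hypothetical values consistent with the expected output, given the stateless `naive_predict` sketched after Example #4:

import unittest

class BeamSearchTest(unittest.TestCase):
    def setUp(self):
        self.prime_labels = [0, 1]   # every expected sample starts with 0, 1
        self.initial_state = None    # naive_predict ignores the state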