Ejemplo n.º 1
0
 def translate_valid():
     """Run ``translate_to_file`` on the validation data and return its result.

     Every input (model, data, indexers, flags, output paths) is taken from the
     enclosing scope; this function takes no arguments of its own.
     """
     # Destination paths come from the shared output-file table. An earlier note
     # suggested these correspond to save_prefix + ".test.out" and
     # save_prefix + ".test.src.out" -- unverified from here.
     hyp_path = output_files_dict["valid_translation_output"]
     src_echo_path = output_files_dict["valid_src_output"]

     return translate_to_file(
         encdec,
         eos_idx,
         valid_src_data,
         mb_size,
         tgt_indexer,
         hyp_path,
         test_references=valid_references,
         control_src_fn=src_echo_path,
         src_indexer=src_indexer,
         gpu=gpu,
         nb_steps=50,
         reverse_src=reverse_src,
         reverse_tgt=reverse_tgt,
         s_unk_tag=s_unk_tag,
         t_unk_tag=t_unk_tag,
     )
Ejemplo n.º 2
0
    def __call__(self, trainer):
        """Score the current model with BLEU and checkpoint it on improvement.

        The model is taken from the trainer's "main" optimizer and passed to
        ``translate_to_file`` together with the data and settings stored on
        this object. The resulting BLEU is reported through
        ``chainer.reporter.report`` under ``self.observation_name``, along
        with a ``"_details"`` entry holding its ``repr``.

        When the score beats ``self.best_bleu`` (or no score has been recorded
        yet) the new best is remembered. If ``self.save_best_model_to`` is set,
        the model is written there with ``serializers.save_npz``; if
        ``self.config_training`` is also set, a writable copy of it extended
        with a ``"model_parameters"`` section is saved to
        ``self.save_best_model_to + ".config"``.

        Args:
            trainer: object exposing ``updater.get_optimizer("main").target``
                and ``updater.iteration`` (presumably a chainer Trainer --
                confirm against the caller).

        Returns:
            None.
        """
        model = trainer.updater.get_optimizer("main").target

        stats = translate_to_file(
            model,
            self.eos_idx,
            self.src_data,
            self.mb_size,
            self.tgt_indexer,
            self.translations_fn,
            test_references=self.references,
            control_src_fn=self.control_src_fn,
            src_indexer=self.src_indexer,
            gpu=self.gpu,
            nb_steps=50,
            reverse_src=self.reverse_src,
            reverse_tgt=self.reverse_tgt,
            s_unk_tag=self.s_unk_tag,
            t_unk_tag=self.t_unk_tag,
        )
        score = stats.bleu()

        observation = {self.observation_name: score}
        observation[self.observation_name + "_details"] = repr(score)
        chainer.reporter.report(observation)

        # Guard: nothing to do beyond logging when the score did not improve.
        previous_best = self.best_bleu
        if previous_best is not None and not (previous_best < score):
            log.info("no bleu (%s) improvement: %f >= %f" %
                     (self.observation_name, previous_best, score))
            return

        log.info("%s improvement: %r -> %r" %
                 (self.observation_name, previous_best, score))
        self.best_bleu = score

        model_path = self.save_best_model_to
        if model_path is None:
            return

        log.info("saving best bleu (%s) model to %s" %
                 (self.observation_name, model_path))
        serializers.save_npz(model_path, model)

        if self.config_training is None:
            return

        # Write a companion config describing the checkpoint just saved.
        session = self.config_training.copy(readonly=False)
        session.add_section("model_parameters", keep_at_bottom="metadata")

        params = session["model_parameters"]
        params["filename"] = model_path
        params["type"] = "model"
        params["description"] = "best_bleu"
        params["infos"] = argument_parsing_tools.OrderedNamespace()

        # Re-read the stored namespace rather than keeping the object we just
        # assigned, in case the container wraps values on assignment.
        infos = params["infos"]
        infos["bleu_stats"] = str(stats)
        infos["iteration"] = trainer.updater.iteration

        session.set_metadata_modified_time()
        session.save_to(model_path + ".config")