# Example 1
    def __call__(self, trainer):
        """Evaluate the loss on held-out data, report it, and checkpoint on improvement.

        Computes the loss of the current model over ``self.data``, reports it
        under ``self.observation_name``, and — when it beats the best loss seen
        so far — saves the model (and a matching ``.config`` description file)
        to ``self.save_best_model_to``.
        """
        encdec = trainer.updater.get_optimizer("main").target
        log.info("computing %s" % self.observation_name)
        dev_loss = compute_loss_all(
            encdec, self.data, self.eos_idx, self.mb_size,
            gpu=self.gpu,
            reverse_src=self.reverse_src,
            reverse_tgt=self.reverse_tgt,
            use_chainerx=self.use_chainerx)
        log.info("%s: %f (current best: %r)" % (self.observation_name, dev_loss, self.best_loss))
        chainer.reporter.report({self.observation_name: dev_loss})

        # Guard clauses: bail out early unless this is a new best loss.
        if self.best_loss is not None and dev_loss >= self.best_loss:
            return
        log.info("loss (%s) improvement: %r -> %r" % (self.observation_name,
                                                      self.best_loss, dev_loss))
        self.best_loss = dev_loss
        if self.save_best_model_to is None:
            return
        log.info("saving best loss (%s) model to %s" % (self.observation_name, self.save_best_model_to,))
        serializers.save_npz(self.save_best_model_to, encdec)
        if self.config_training is None:
            return
        # Record where/why this model was saved in a sidecar config file.
        config_session = self.config_training.copy(readonly=False)
        config_session.add_section("model_parameters", keep_at_bottom="metadata")
        section = config_session["model_parameters"]
        section["filename"] = self.save_best_model_to
        section["type"] = "model"
        section["description"] = "best_loss"
        section["infos"] = argument_parsing_tools.OrderedNamespace()
        section["infos"]["loss"] = float(dev_loss)
        section["infos"]["iteration"] = trainer.updater.iteration
        config_session.set_metadata_modified_time()
        config_session.save_to(self.save_best_model_to + ".config")
# Example 2
    def __call__(self, trainer):
        """Translate held-out data, report BLEU, and checkpoint on improvement.

        Runs ``translate_to_file`` with the current model, reports the BLEU
        score (and its repr as ``<name>_details``), and — when BLEU beats the
        best seen so far — saves the model plus a ``.config`` sidecar file to
        ``self.save_best_model_to``.
        """
        encdec = trainer.updater.get_optimizer("main").target
        bleu_stats = translate_to_file(
            encdec,
            self.eos_idx,
            self.src_data,
            self.mb_size,
            self.tgt_indexer,
            self.translations_fn,
            test_references=self.references,
            control_src_fn=self.control_src_fn,
            src_indexer=self.src_indexer,
            gpu=self.gpu,
            nb_steps=50,
            reverse_src=self.reverse_src,
            reverse_tgt=self.reverse_tgt,
            s_unk_tag=self.s_unk_tag,
            t_unk_tag=self.t_unk_tag)
        bleu = bleu_stats.bleu()
        chainer.reporter.report({
            self.observation_name: bleu,
            self.observation_name + "_details": repr(bleu)
        })

        # Early return on the no-improvement branch (De Morgan of the
        # original "is None or best < bleu" improvement test).
        if self.best_bleu is not None and self.best_bleu >= bleu:
            log.info("no bleu (%s) improvement: %f >= %f" %
                     (self.observation_name, self.best_bleu, bleu))
            return

        log.info("%s improvement: %r -> %r" %
                 (self.observation_name, self.best_bleu, bleu))
        self.best_bleu = bleu
        if self.save_best_model_to is None:
            return
        log.info("saving best bleu (%s) model to %s" % (
            self.observation_name,
            self.save_best_model_to,
        ))
        serializers.save_npz(self.save_best_model_to, encdec)
        if self.config_training is None:
            return
        # Describe the saved model in a sidecar config file.
        config_session = self.config_training.copy(readonly=False)
        config_session.add_section("model_parameters",
                                   keep_at_bottom="metadata")
        section = config_session["model_parameters"]
        section["filename"] = self.save_best_model_to
        section["type"] = "model"
        section["description"] = "best_bleu"
        section["infos"] = argument_parsing_tools.OrderedNamespace()
        section["infos"]["bleu_stats"] = str(bleu_stats)
        section["infos"]["iteration"] = trainer.updater.iteration
        config_session.set_metadata_modified_time()
        config_session.save_to(self.save_best_model_to + ".config")
# Example 3
def load_config_train(filename, readonly=True, no_error=False):
    """Load a training configuration file, upgrading legacy layouts in place.

    Pre-1.0 files (those without a "metadata" section) are rebuilt from their
    raw command line into the current structured layout; files claiming any
    version other than 1.0 are rejected.  Top-level "data_prefix" /
    "train_prefix" keys left over from intermediate config versions are moved
    under "training_management".

    :param filename: path of the config file to load
    :param readonly: freeze the returned config when True
    :param no_error: forwarded to convert_cell_string for lenient conversion
    :return: the (possibly upgraded) OrderedNamespace configuration
    :raises ValueError: if the file declares an unsupported config version
    """
    config = argument_parsing_tools.OrderedNamespace.load_from(filename)

    if "metadata" not in config:  # older config file
        # Rebuild the structured config from the stored raw command line.
        orderer = get_parse_option_orderer()
        upgraded = orderer.convert_args_to_ordered_dict(
            config["command_line"], args_is_namespace=False)

        convert_cell_string(upgraded, no_error=no_error)

        assert "data" not in upgraded
        upgraded["data"] = argument_parsing_tools.OrderedNamespace()
        upgraded["data"]["data_fn"] = config["data"]
        for vocab_key in ("Vi", "Vo", "voc"):
            upgraded["data"][vocab_key] = config[vocab_key]

        assert "metadata" not in upgraded
        upgraded["metadata"] = argument_parsing_tools.OrderedNamespace()
        upgraded["metadata"]["config_version_num"] = 0.9
        upgraded["metadata"]["command_line"] = None
        upgraded["metadata"]["knmt_version"] = None
        config = upgraded
    elif config["metadata"]["config_version_num"] != 1.0:
        raise ValueError(
            "The config version of %s is not supported by this version of the program"
            % filename)

    # Compatibility with intermediate versions of the config file: relocate
    # stray top-level prefixes under "training_management".
    for legacy_key in ("data_prefix", "train_prefix"):
        if legacy_key in config and legacy_key not in config[
                "training_management"]:
            config["training_management"][legacy_key] = config[legacy_key]
            del config[legacy_key]

    if readonly:
        config.set_readonly()
    return config
# Example 4
    def __call__(self, trainer):
        """Save the full trainer state to ``self.save_to`` plus a ``.config`` sidecar."""
        log.info("Saving current trainer state to file %s" % self.save_to)
        serializers.save_npz(self.save_to, trainer)

        # Describe the snapshot (path, type, iteration) in a sidecar config.
        config_session = self.config_training.copy(readonly=False)
        config_session.add_section("model_parameters", keep_at_bottom="metadata")
        section = config_session["model_parameters"]
        section["filename"] = self.save_to
        section["type"] = "snapshot"
        section["description"] = "checkpoint"
        section["infos"] = argument_parsing_tools.OrderedNamespace()
        section["infos"]["iteration"] = trainer.updater.iteration
        config_session.set_metadata_modified_time()
        config_session.save_to(self.save_to + ".config")
        log.info("Saved trainer snapshot to file %s" % self.save_to)
# Example 5
def load_config_eval(filename, readonly=True):
    """Load an evaluation configuration file, upgrading pre-1.0 files.

    Files without a "metadata" section are rebuilt into the current layout
    and tagged with config_version_num 0.9; files declaring any version other
    than 1.0 raise ValueError.

    :param filename: path of the config file to load
    :param readonly: freeze the returned config when True
    :return: the (possibly upgraded) OrderedNamespace configuration
    :raises ValueError: if the file declares an unsupported config version
    """
    config = argument_parsing_tools.OrderedNamespace.load_from(filename)

    if "metadata" not in config:  # older config file
        orderer = get_parse_option_orderer()
        upgraded = orderer.convert_args_to_ordered_dict(config, args_is_namespace=False)
        assert "metadata" not in upgraded
        upgraded["metadata"] = argument_parsing_tools.OrderedNamespace()
        upgraded["metadata"]["config_version_num"] = 0.9
        upgraded["metadata"]["command_line"] = None
        upgraded["metadata"]["knmt_version"] = None
        config = upgraded
    elif config["metadata"]["config_version_num"] != 1.0:
        raise ValueError("The config version of %s is not supported by this version of the program" % filename)

    if readonly:
        config.set_readonly()
    return config
# Example 6
def train_on_data_chainer(
    encdec,
    optimizer,
    training_data,
    output_files_dict,
    src_indexer,
    tgt_indexer,
    eos_idx,
    config_training,
    stop_trigger=None,
    test_data=None,
    dev_data=None,
    valid_data=None,
):
    """Train ``encdec`` on ``training_data`` with the chainer Trainer machinery.

    Builds a length-based minibatch iterator (dynamic when the config enables
    it), an ``Updater`` whose loss function depends on
    ``encdec.encdec_type()`` ("ff" vs the attention-based path), and a
    ``chainer.training.Trainer`` with dev/test loss and BLEU evaluation,
    periodic sampling, checkpoint saving, and sqlite logging attached as
    extensions.  On any exception during training the current trainer state
    is saved to ``output_files_dict["model_final"]`` (with a ``.config``
    sidecar) before the exception is re-raised.

    NOTE(review): ``valid_data`` is accepted but never used in this body —
    confirm whether it is intentionally reserved.
    NOTE(review): ``randomized``, ``curiculum_training``, ``use_reinf``,
    ``use_memory_optimization`` and ``do_not_save_data_for_resuming`` are
    unpacked below but not referenced afterwards in this function.
    """

    # Unpack every setting this function uses from the config namespace.
    output_dir = config_training.training_management.save_prefix
    mb_size = config_training.training.mb_size
    nb_of_batch_to_sort = config_training.training.nb_batch_to_sort
    gpu = config_training.training_management.gpu
    report_every = config_training.training_management.report_every
    randomized = config_training.training.randomized_data
    reverse_src = config_training.training.reverse_src
    reverse_tgt = config_training.training.reverse_tgt
    do_not_save_data_for_resuming = config_training.training_management.no_resume
    noise_on_prev_word = config_training.training.noise_on_prev_word
    curiculum_training = config_training.training.curiculum_training
    use_previous_prediction = config_training.training.use_previous_prediction
    no_report_or_save = config_training.training_management.no_report_or_save
    use_memory_optimization = config_training.training_management.use_memory_optimization
    sample_every = config_training.training_management.sample_every
    use_reinf = config_training.training.use_reinf
    save_ckpt_every = config_training.training_management.save_ckpt_every
    trainer_snapshot = config_training.training_management.load_trainer_snapshot
    save_initial_model_to = config_training.training_management.save_initial_model_to
    reshuffle_every_epoch = config_training.training_management.reshuffle_every_epoch

    use_soft_prediction_feedback = config_training.training.use_soft_prediction_feedback
    use_gumbel_for_soft_predictions = config_training.training.use_gumbel_for_soft_predictions
    temperature_for_soft_predictions = config_training.training.temperature_for_soft_predictions

    generate_computation_graph = config_training.training_management.generate_computation_graph

    # .get() with defaults keeps older config files (without the dynamic
    # batching keys) loadable.
    dynamic_batching = config_training.training.get("dynamic_batching", False)
    dynamic_batching_max_elems = config_training.training.get(
        "dynamic_batching_max_elems", 10000)
    dynamic_batching_nb_sent_to_sort = config_training.training.get(
        "dynamic_batching_nb_sent_to_sort", 5000)

    # Trainer extension: decode up to 20 sentences of the next minibatch and
    # log them, so training progress can be eyeballed.
    @chainer.training.make_extension()
    def sample_extension(trainer):
        encdec = trainer.updater.get_optimizer("main").target
        iterator = trainer.updater.get_iterator("main")
        # peek() looks at the upcoming minibatch without consuming it.
        mb_raw = iterator.peek()

        def s_unk_tag(num, utag):
            return "S_UNK_%i" % utag

        def t_unk_tag(num, utag):
            return "T_UNK_%i" % utag

        try:
            if encdec.encdec_type() == "ff":
                src_seqs, tgt_seqs = zip(*mb_raw)
                sample_once_ff(encdec,
                               src_seqs,
                               tgt_seqs,
                               src_indexer,
                               tgt_indexer,
                               max_nb=20,
                               s_unk_tag=s_unk_tag,
                               t_unk_tag=t_unk_tag)
            else:

                src_batch, tgt_batch, src_mask = make_batch_src_tgt(
                    mb_raw,
                    eos_idx=eos_idx,
                    padding_idx=0,
                    gpu=gpu,
                    need_arg_sort=False)

                sample_once(encdec,
                            src_batch,
                            tgt_batch,
                            src_mask,
                            src_indexer,
                            tgt_indexer,
                            eos_idx,
                            max_nb=20,
                            s_unk_tag=s_unk_tag,
                            t_unk_tag=t_unk_tag)
        except CudaException:
            # Sampling is best-effort; a CUDA failure here must not kill training.
            log.warn("CUDARuntimeError during sample. Skipping sample")

    # Choose the training iterator: dynamic batching packs minibatches by
    # element count, otherwise fixed-size length-sorted minibatches are used.
    if dynamic_batching:
        log.info("using dynamic matching with %i %i",
                 dynamic_batching_max_elems, dynamic_batching_nb_sent_to_sort)
        iterator_training_data = DynamicLengthBasedSerialIterator(
            training_data,
            max_nb_elements=dynamic_batching_max_elems,
            nb_sent_sort=dynamic_batching_nb_sent_to_sort,
            sort_key=lambda x: len(x[1]),
            repeat=True,
            shuffle=reshuffle_every_epoch)

    else:
        iterator_training_data = LengthBasedSerialIterator(
            training_data,
            mb_size,
            nb_of_batch_to_sort=nb_of_batch_to_sort,
            sort_key=lambda x: len(x[0]),
            repeat=True,
            shuffle=reshuffle_every_epoch)

    # One-element list used as a mutable flag shared with the loss_func
    # closures: dump the computation graph on the first call only.
    generate_loss_computation_graph_on_first_call = [
        generate_computation_graph is not None
    ]

    if encdec.encdec_type() == "ff":

        # Loss function for the feed-forward model: works on raw sequences.
        def loss_func(src_seq, tgt_seq):

            # NOTE(review): time.clock() was removed in Python 3.8 — this
            # code assumes an older interpreter.
            t0 = time.clock()

            loss = encdec.compute_loss(src_seq, tgt_seq, reduce="no")
            total_loss = F.sum(loss)
            # +1 per sequence accounts for the EOS prediction.
            total_nb_predictions = sum(len(seq) + 1 for seq in tgt_seq)

            avg_loss = total_loss / total_nb_predictions

            t1 = time.clock()
            chainer.reporter.report({"forward_time": t1 - t0})

            chainer.reporter.report({"mb_loss": total_loss.data})
            chainer.reporter.report(
                {"mb_nb_predictions": total_nb_predictions})
            chainer.reporter.report({"trg_loss": avg_loss.data})

            log.info("batch infos: %i x [%i | %i]", len(src_seq),
                     max(len(s) for s in src_seq),
                     max(len(s) for s in tgt_seq))

            # Optionally dump the loss computation graph, once.
            if generate_loss_computation_graph_on_first_call[0]:
                log.info("Writing loss computation graph to %s",
                         generate_computation_graph)
                import chainer.computational_graph as c
                g = c.build_computational_graph(
                    [avg_loss]
                )  #, variable_style=None, function_style=None, show_name=False )
                with open(generate_computation_graph, 'w') as o:
                    o.write(g.dump())
                generate_loss_computation_graph_on_first_call[0] = False

            return avg_loss

        # The ff model consumes plain sequence tuples; no padding needed here.
        def convert_mb(mb_raw, device):
            return tuple(zip(*mb_raw))
    else:

        # Loss function for the attention-based model: works on padded batches.
        def loss_func(src_batch, tgt_batch, src_mask):

            t0 = time.clock()
            (total_loss, total_nb_predictions), attn = encdec(
                src_batch,
                tgt_batch,
                src_mask,
                raw_loss_info=True,
                noise_on_prev_word=noise_on_prev_word,
                use_previous_prediction=use_previous_prediction,
                use_soft_prediction_feedback=use_soft_prediction_feedback,
                use_gumbel_for_soft_predictions=use_gumbel_for_soft_predictions,
                temperature_for_soft_predictions=
                temperature_for_soft_predictions)
            avg_loss = total_loss / total_nb_predictions

            t1 = time.clock()
            chainer.reporter.report({"forward_time": t1 - t0})

            chainer.reporter.report({"mb_loss": total_loss.data})
            chainer.reporter.report(
                {"mb_nb_predictions": total_nb_predictions})
            chainer.reporter.report({"trg_loss": avg_loss.data})

            log.info("batch infos: %i x [%i | %i]", src_batch[0].data.shape[0],
                     len(src_batch), len(tgt_batch))

            if generate_loss_computation_graph_on_first_call[0]:
                log.info("Writing loss computation graph to %s",
                         generate_computation_graph)
                import chainer.computational_graph as c
                g = c.build_computational_graph([avg_loss])
                with open(generate_computation_graph, 'w') as o:
                    o.write(g.dump())
                generate_loss_computation_graph_on_first_call[0] = False

            return avg_loss

        # Pad and move the raw minibatch to the target device.
        def convert_mb(mb_raw, device):
            return make_batch_src_tgt(mb_raw,
                                      eos_idx=eos_idx,
                                      padding_idx=0,
                                      gpu=device,
                                      need_arg_sort=False)

    updater = Updater(
        iterator_training_data,
        optimizer,
        converter=convert_mb,
        # iterator_training_data = chainer.iterators.SerialIterator(training_data, mb_size,
        # repeat = True,
        # shuffle = reshuffle_every_epoch)
        device=gpu,
        loss_func=loss_func,
        need_to_convert_to_variables=False)

    trainer = chainer.training.Trainer(updater, stop_trigger, out=output_dir)
    #     trainer.extend(chainer.training.extensions.LogReport(trigger=(10, 'iteration')))
    #     trainer.extend(chainer.training.extensions.PrintReport(['epoch', 'iteration', 'trg_loss', "dev_loss", "dev_bleu"]),
    #                    trigger = (1, "iteration"))

    # Dev-set evaluation: loss and BLEU, each saving its own best model.
    if dev_data is not None and not no_report_or_save:
        dev_loss_extension = ComputeLossExtension(
            dev_data,
            eos_idx,
            mb_size,
            gpu,
            reverse_src,
            reverse_tgt,
            save_best_model_to=output_files_dict["model_best_loss"],
            observation_name="dev_loss",
            config_training=config_training)
        trainer.extend(dev_loss_extension, trigger=(report_every, "iteration"))

        dev_bleu_extension = ComputeBleuExtension(
            dev_data,
            eos_idx,
            src_indexer,
            tgt_indexer,
            output_files_dict["dev_translation_output"],
            output_files_dict["dev_src_output"],
            mb_size,
            gpu,
            reverse_src,
            reverse_tgt,
            save_best_model_to=output_files_dict["model_best"],
            observation_name="dev_bleu",
            config_training=config_training)

        trainer.extend(dev_bleu_extension, trigger=(report_every, "iteration"))

    # Test-set evaluation: report-only (no best-model saving).
    if test_data is not None and not no_report_or_save:
        test_loss_extension = ComputeLossExtension(
            test_data,
            eos_idx,
            mb_size,
            gpu,
            reverse_src,
            reverse_tgt,
            observation_name="test_loss")
        trainer.extend(test_loss_extension,
                       trigger=(report_every, "iteration"))

        test_bleu_extension = ComputeBleuExtension(
            test_data,
            eos_idx,
            src_indexer,
            tgt_indexer,
            output_files_dict["test_translation_output"],
            output_files_dict["test_src_output"],
            mb_size,
            gpu,
            reverse_src,
            reverse_tgt,
            observation_name="test_bleu")

        trainer.extend(test_bleu_extension,
                       trigger=(report_every, "iteration"))

    # Sampling, checkpointing, and sqlite logging (all suppressed by
    # --no_report_or_save).
    if not no_report_or_save:
        trainer.extend(sample_extension, trigger=(sample_every, "iteration"))

        # trainer.extend(chainer.training.extensions.snapshot(), trigger = (save_ckpt_every, "iteration"))

        trainer.extend(CheckpontSavingExtension(
            output_files_dict["model_ckpt"], config_training),
                       trigger=(save_ckpt_every, "iteration"))

        trainer.extend(
            SqliteLogExtension(db_path=output_files_dict["sqlite_db"]))

    trainer.extend(
        TrainingLossSummaryExtension(trigger=(report_every, "iteration")))

    # Resuming: reload the trainer state referenced by the config file.
    if config_training.training_management.resume:
        if "model_parameters" not in config_training:
            log.error("cannot find model parameters in config file")
            raise ValueError(
                "Config file do not contain model_parameters section")
        if config_training.model_parameters.type == "snapshot":
            model_filename = config_training.model_parameters.filename
            log.info("resuming from trainer parameters %s" % model_filename)
            serializers.load_npz(model_filename, trainer)

    # Explicit snapshot file takes effect after (and overrides) config resume.
    if trainer_snapshot is not None:
        log.info("loading trainer parameters from %s" % trainer_snapshot)
        serializers.load_npz(trainer_snapshot, trainer)

    try:
        if save_initial_model_to is not None:
            log.info("Saving initial parameters to %s" % save_initial_model_to)
            encdec = trainer.updater.get_optimizer("main").target
            serializers.save_npz(save_initial_model_to, encdec)

        trainer.run()
    except BaseException:
        # Best-effort rescue: persist the trainer state (plus a .config
        # sidecar) before propagating the exception (incl. KeyboardInterrupt).
        if not no_report_or_save:
            final_snapshot_fn = output_files_dict["model_final"]
            log.info(
                "Exception met. Trying to save current trainer state to file %s"
                % final_snapshot_fn)
            serializers.save_npz(final_snapshot_fn, trainer)
            #             chainer.training.extensions.snapshot(filename = final_snapshot_fn)(trainer)
            config_session = config_training.copy(readonly=False)
            config_session.add_section("model_parameters",
                                       keep_at_bottom="metadata")
            config_session["model_parameters"]["filename"] = final_snapshot_fn
            config_session["model_parameters"]["type"] = "snapshot"
            config_session["model_parameters"]["description"] = "final"
            config_session["model_parameters"][
                "infos"] = argument_parsing_tools.OrderedNamespace()
            config_session["model_parameters"]["infos"][
                "iteration"] = trainer.updater.iteration
            config_session.set_metadata_modified_time()
            config_session.save_to(final_snapshot_fn + ".config")
            # json.dump(config_session, open(final_snapshot_fn + ".config", "w"),
            # indent=2, separators=(',', ': '))
            log.info("Saved trainer snapshot to file %s" % final_snapshot_fn)
        raise
# Example 7
def train_on_data_chainer(
    encdec,
    optimizer,
    training_data,
    output_files_dict,
    src_indexer,
    tgt_indexer,
    eos_idx,
    config_training,
    stop_trigger=None,
    test_data=None,
    dev_data=None,
    valid_data=None,
):
    """Train ``encdec`` on ``training_data`` with the chainer Trainer machinery.

    Builds a length-based minibatch iterator, an ``Updater`` with a single
    attention-model loss function, and a ``chainer.training.Trainer`` with
    dev/test loss and BLEU evaluation, periodic sampling, checkpoint saving,
    and sqlite logging attached as extensions.  On any exception during
    training the current trainer state is saved to
    ``output_files_dict["model_final"]`` (with a ``.config`` sidecar) before
    the exception is re-raised.

    NOTE(review): this variant passes ``volatile=`` to make_batch_src_tgt,
    which suggests it targets the chainer v1 API — confirm the expected
    chainer version.
    NOTE(review): ``valid_data`` is accepted but never used in this body;
    ``randomized``, ``curiculum_training``, ``use_reinf``,
    ``use_memory_optimization`` and ``do_not_save_data_for_resuming`` are
    unpacked below but not referenced afterwards.
    """

    # Unpack every setting this function uses from the config namespace.
    output_dir = config_training.training_management.save_prefix
    mb_size = config_training.training.mb_size
    nb_of_batch_to_sort = config_training.training.nb_batch_to_sort
    gpu = config_training.training_management.gpu
    report_every = config_training.training_management.report_every
    randomized = config_training.training.randomized_data
    reverse_src = config_training.training.reverse_src
    reverse_tgt = config_training.training.reverse_tgt
    do_not_save_data_for_resuming = config_training.training_management.no_resume
    noise_on_prev_word = config_training.training.noise_on_prev_word
    curiculum_training = config_training.training.curiculum_training
    use_previous_prediction = config_training.training.use_previous_prediction
    no_report_or_save = config_training.training_management.no_report_or_save
    use_memory_optimization = config_training.training_management.use_memory_optimization
    sample_every = config_training.training_management.sample_every
    use_reinf = config_training.training.use_reinf
    save_ckpt_every = config_training.training_management.save_ckpt_every
    trainer_snapshot = config_training.training_management.load_trainer_snapshot
    save_initial_model_to = config_training.training_management.save_initial_model_to
    reshuffle_every_epoch = config_training.training_management.reshuffle_every_epoch

    use_soft_prediction_feedback = config_training.training.use_soft_prediction_feedback
    use_gumbel_for_soft_predictions = config_training.training.use_gumbel_for_soft_predictions
    temperature_for_soft_predictions = config_training.training.temperature_for_soft_predictions

    # Trainer extension: decode up to 20 sentences of the next minibatch and
    # log them, so training progress can be eyeballed.
    @chainer.training.make_extension()
    def sample_extension(trainer):
        encdec = trainer.updater.get_optimizer("main").target
        iterator = trainer.updater.get_iterator("main")
        # peek() looks at the upcoming minibatch without consuming it.
        mb_raw = iterator.peek()

        src_batch, tgt_batch, src_mask = make_batch_src_tgt(
            mb_raw,
            eos_idx=eos_idx,
            padding_idx=0,
            gpu=gpu,
            volatile="on",
            need_arg_sort=False)

        def s_unk_tag(num, utag):
            return "S_UNK_%i" % utag

        def t_unk_tag(num, utag):
            return "T_UNK_%i" % utag

        sample_once(encdec,
                    src_batch,
                    tgt_batch,
                    src_mask,
                    src_indexer,
                    tgt_indexer,
                    eos_idx,
                    max_nb=20,
                    s_unk_tag=s_unk_tag,
                    t_unk_tag=t_unk_tag)

    # Fixed-size, length-sorted training minibatches.
    iterator_training_data = LengthBasedSerialIterator(
        training_data,
        mb_size,
        nb_of_batch_to_sort=nb_of_batch_to_sort,
        sort_key=lambda x: len(x[0]),
        repeat=True,
        shuffle=reshuffle_every_epoch)

    # Per-minibatch loss: forward pass, then report timing and loss metrics.
    def loss_func(src_batch, tgt_batch, src_mask):

        # NOTE(review): time.clock() was removed in Python 3.8 — this code
        # assumes an older interpreter.
        t0 = time.clock()
        (total_loss, total_nb_predictions), attn = encdec(
            src_batch,
            tgt_batch,
            src_mask,
            raw_loss_info=True,
            noise_on_prev_word=noise_on_prev_word,
            use_previous_prediction=use_previous_prediction,
            mode="train",
            use_soft_prediction_feedback=use_soft_prediction_feedback,
            use_gumbel_for_soft_predictions=use_gumbel_for_soft_predictions,
            temperature_for_soft_predictions=temperature_for_soft_predictions)
        avg_loss = total_loss / total_nb_predictions

        t1 = time.clock()
        chainer.reporter.report({"forward_time": t1 - t0})

        chainer.reporter.report({"mb_loss": total_loss.data})
        chainer.reporter.report({"mb_nb_predictions": total_nb_predictions})
        chainer.reporter.report({"trg_loss": avg_loss.data})
        return avg_loss

    # Pad and move the raw minibatch to the target device for training.
    def convert_mb(mb_raw, device):
        return make_batch_src_tgt(mb_raw,
                                  eos_idx=eos_idx,
                                  padding_idx=0,
                                  gpu=device,
                                  volatile="off",
                                  need_arg_sort=False)

    updater = Updater(
        iterator_training_data,
        optimizer,
        converter=convert_mb,
        # iterator_training_data = chainer.iterators.SerialIterator(training_data, mb_size,
        # repeat = True,
        # shuffle = reshuffle_every_epoch)
        device=gpu,
        loss_func=loss_func,
        need_to_convert_to_variables=False)

    trainer = chainer.training.Trainer(updater, stop_trigger, out=output_dir)
    #     trainer.extend(chainer.training.extensions.LogReport(trigger=(10, 'iteration')))
    #     trainer.extend(chainer.training.extensions.PrintReport(['epoch', 'iteration', 'trg_loss', "dev_loss", "dev_bleu"]),
    #                    trigger = (1, "iteration"))

    # Dev-set evaluation: loss and BLEU, each saving its own best model.
    if dev_data is not None and not no_report_or_save:
        dev_loss_extension = ComputeLossExtension(
            dev_data,
            eos_idx,
            mb_size,
            gpu,
            reverse_src,
            reverse_tgt,
            save_best_model_to=output_files_dict["model_best_loss"],
            observation_name="dev_loss",
            config_training=config_training)
        trainer.extend(dev_loss_extension, trigger=(report_every, "iteration"))

        dev_bleu_extension = ComputeBleuExtension(
            dev_data,
            eos_idx,
            src_indexer,
            tgt_indexer,
            output_files_dict["dev_translation_output"],
            output_files_dict["dev_src_output"],
            mb_size,
            gpu,
            reverse_src,
            reverse_tgt,
            save_best_model_to=output_files_dict["model_best"],
            observation_name="dev_bleu",
            config_training=config_training)

        trainer.extend(dev_bleu_extension, trigger=(report_every, "iteration"))

    # Test-set evaluation: report-only (no best-model saving).
    if test_data is not None and not no_report_or_save:
        test_loss_extension = ComputeLossExtension(
            test_data,
            eos_idx,
            mb_size,
            gpu,
            reverse_src,
            reverse_tgt,
            observation_name="test_loss")
        trainer.extend(test_loss_extension,
                       trigger=(report_every, "iteration"))

        test_bleu_extension = ComputeBleuExtension(
            test_data,
            eos_idx,
            src_indexer,
            tgt_indexer,
            output_files_dict["test_translation_output"],
            output_files_dict["test_src_output"],
            mb_size,
            gpu,
            reverse_src,
            reverse_tgt,
            observation_name="test_bleu")

        trainer.extend(test_bleu_extension,
                       trigger=(report_every, "iteration"))

    # Sampling, checkpointing, and sqlite logging (all suppressed by
    # --no_report_or_save).
    if not no_report_or_save:
        trainer.extend(sample_extension, trigger=(sample_every, "iteration"))

        # trainer.extend(chainer.training.extensions.snapshot(), trigger = (save_ckpt_every, "iteration"))

        trainer.extend(CheckpontSavingExtension(
            output_files_dict["model_ckpt"], config_training),
                       trigger=(save_ckpt_every, "iteration"))

        trainer.extend(
            SqliteLogExtension(db_path=output_files_dict["sqlite_db"]))

    trainer.extend(
        TrainingLossSummaryExtension(trigger=(report_every, "iteration")))

    # Resuming: reload the trainer state referenced by the config file.
    if config_training.training_management.resume:
        if "model_parameters" not in config_training:
            log.error("cannot find model parameters in config file")
            raise ValueError(
                "Config file do not contain model_parameters section")
        if config_training.model_parameters.type == "snapshot":
            model_filename = config_training.model_parameters.filename
            log.info("resuming from trainer parameters %s" % model_filename)
            serializers.load_npz(model_filename, trainer)

    # Explicit snapshot file takes effect after (and overrides) config resume.
    if trainer_snapshot is not None:
        log.info("loading trainer parameters from %s" % trainer_snapshot)
        serializers.load_npz(trainer_snapshot, trainer)

    try:
        if save_initial_model_to is not None:
            log.info("Saving initial parameters to %s" % save_initial_model_to)
            encdec = trainer.updater.get_optimizer("main").target
            serializers.save_npz(save_initial_model_to, encdec)

        trainer.run()
    except BaseException:
        # Best-effort rescue: persist the trainer state (plus a .config
        # sidecar) before propagating the exception (incl. KeyboardInterrupt).
        if not no_report_or_save:
            final_snapshot_fn = output_files_dict["model_final"]
            log.info(
                "Exception met. Trying to save current trainer state to file %s"
                % final_snapshot_fn)
            serializers.save_npz(final_snapshot_fn, trainer)
            #             chainer.training.extensions.snapshot(filename = final_snapshot_fn)(trainer)
            config_session = config_training.copy(readonly=False)
            config_session.add_section("model_parameters",
                                       keep_at_bottom="metadata")
            config_session["model_parameters"]["filename"] = final_snapshot_fn
            config_session["model_parameters"]["type"] = "snapshot"
            config_session["model_parameters"]["description"] = "final"
            config_session["model_parameters"][
                "infos"] = argument_parsing_tools.OrderedNamespace()
            config_session["model_parameters"]["infos"][
                "iteration"] = trainer.updater.iteration
            config_session.set_metadata_modified_time()
            config_session.save_to(final_snapshot_fn + ".config")
            # json.dump(config_session, open(final_snapshot_fn + ".config", "w"),
            # indent=2, separators=(',', ': '))
            log.info("Saved trainer snapshot to file %s" % final_snapshot_fn)
        raise