Пример #1
0
def test_context_memory(config_filename, search_engine, src_sentence):
    """Build a context memory for *src_sentence* from a trained model.

    Loads the training configuration from *config_filename*, reconstructs
    the encoder/decoder together with its vocabulary indexers (also loading
    the saved model parameters), then builds the context memory through
    *search_engine*.
    """
    loaded_config = train_config.load_config_train(config_filename)
    components, model_infos = train.create_encdec_and_indexers_from_config_dict(
        loaded_config, load_config_model="yes", return_model_infos=True)
    encdec, eos_idx, src_indexer, tgt_indexer = components

    ctxt_mem = create_context_memory(
        encdec, (src_indexer, tgt_indexer), search_engine, src_sentence)
Пример #2
0
def create_and_load_encdec_from_files(config_training_fn, trained_model):
    """Instantiate an encoder/decoder from a config file and load its weights.

    Returns a tuple ``(encdec, eos_idx, src_indexer, tgt_indexer)`` where the
    parameters stored in the npz file *trained_model* have been loaded into
    ``encdec``.
    """
    log.info("loading model config from %s" % config_training_fn)
    loaded_config = train_config.load_config_train(config_training_fn)
    components = train.create_encdec_and_indexers_from_config_dict(loaded_config)
    encdec, eos_idx, src_indexer, tgt_indexer = components

    log.info("loading model from %s" % trained_model)
    # Restore the trained parameters into the freshly built model.
    serializers.load_npz(trained_model, encdec)

    return encdec, eos_idx, src_indexer, tgt_indexer
Пример #3
0
Файл: eval.py Проект: Tzawa/knmt
def create_encdec(config_eval):
    """Build the list of encoder/decoder models described by *config_eval*.

    Models may come from three places: the top-level ``training_config`` /
    ``trained_model`` pair, ``process.load_model_config`` entries (each of
    which may list extra model files for parameter averaging after a comma),
    and ``process.additional_training_config`` / ``additional_trained_model``
    pairs.  All models must share compatible vocabularies; the shared
    ``eos_idx`` and indexers come from the first model loaded.

    Returns:
        (encdec_list, eos_idx, src_indexer, tgt_indexer, reverse_encdec,
         model_infos_list) — ``reverse_encdec`` is ``None`` unless
        ``process.reverse_training_config`` is set.

    Raises:
        Exception: if a reverse model's vocabulary sizes or eos index do not
            match the forward models'.
    """
    encdec_list = []
    eos_idx, src_indexer, tgt_indexer = None, None, None
    model_infos_list = []

    if config_eval.training_config is not None:
        assert config_eval.trained_model is not None
        encdec, eos_idx, src_indexer, tgt_indexer = create_and_load_encdec_from_files(
            config_eval.training_config, config_eval.trained_model)
        model_infos_list.append(
            create_filename_infos(config_eval.trained_model))
        encdec_list.append(encdec)

    if 'load_model_config' in config_eval.process and config_eval.process.load_model_config is not None:
        for config_filename_and_others in config_eval.process.load_model_config:
            # Entry format: "config[,model1,model2,...]" — the optional extra
            # model files are averaged into the loaded parameters.
            other_models_for_averaging = None
            if "," in config_filename_and_others:
                config_filename_and_others_splitted = config_filename_and_others.split(
                    ",")
                config_filename = config_filename_and_others_splitted[0]
                other_models_for_averaging = config_filename_and_others_splitted[
                    1:]
            else:
                config_filename = config_filename_and_others
            log.info("loading model and parameters from config %s" %
                     config_filename)
            config_training = train_config.load_config_train(config_filename)
            (
                encdec, this_eos_idx, this_src_indexer, this_tgt_indexer
            ), model_infos = train.create_encdec_and_indexers_from_config_dict(
                config_training,
                load_config_model="yes",
                return_model_infos=True,
                additional_models_parameters_for_averaging=
                other_models_for_averaging)
            model_infos_list.append(model_infos)
            if eos_idx is None:
                # First model loaded overall: it defines the shared vocabulary.
                assert len(encdec_list) == 0
                assert src_indexer is None
                assert tgt_indexer is None
                eos_idx, src_indexer, tgt_indexer = this_eos_idx, this_src_indexer, this_tgt_indexer
            else:
                check_if_vocabulary_info_compatible(this_eos_idx,
                                                    this_src_indexer,
                                                    this_tgt_indexer, eos_idx,
                                                    src_indexer, tgt_indexer)

            encdec_list.append(encdec)

    assert len(encdec_list) > 0

    if 'additional_training_config' in config_eval.process and config_eval.process.additional_training_config is not None:
        assert len(config_eval.process.additional_training_config) == len(
            config_eval.process.additional_trained_model)

        for (config_training_fn, trained_model_fn) in six.moves.zip(
                config_eval.process.additional_training_config,
                config_eval.process.additional_trained_model):
            this_encdec, this_eos_idx, this_src_indexer, this_tgt_indexer = create_and_load_encdec_from_files(
                config_training_fn, trained_model_fn)

            check_if_vocabulary_info_compatible(this_eos_idx, this_src_indexer,
                                                this_tgt_indexer, eos_idx,
                                                src_indexer, tgt_indexer)
            model_infos_list.append(create_filename_infos(trained_model_fn))

            encdec_list.append(this_encdec)

    # Device placement: chainerx uses device strings, classic chainer uses
    # to_gpu with a numeric device id.
    if config_eval.process.use_chainerx:
        if 'gpu' in config_eval.process and config_eval.process.gpu is not None:
            encdec_list = [
                encdec.to_device("cuda:%i" % config_eval.process.gpu)
                for encdec in encdec_list
            ]
        else:
            encdec_list = [
                encdec.to_device("native:0") for encdec in encdec_list
            ]
    else:
        if 'gpu' in config_eval.process and config_eval.process.gpu is not None:
            encdec_list = [
                encdec.to_gpu(config_eval.process.gpu)
                for encdec in encdec_list
            ]

    if 'reverse_training_config' in config_eval.process and config_eval.process.reverse_training_config is not None:
        reverse_encdec, reverse_eos_idx, reverse_src_indexer, reverse_tgt_indexer = create_and_load_encdec_from_files(
            config_eval.process.reverse_training_config,
            config_eval.process.reverse_trained_model)

        if eos_idx != reverse_eos_idx:
            raise Exception("incompatible models")

        if len(src_indexer) != len(reverse_src_indexer):
            raise Exception("incompatible models")

        if len(tgt_indexer) != len(reverse_tgt_indexer):
            raise Exception("incompatible models")

        # Guard the attribute access like every other 'gpu' check in this
        # function (previously this read config_eval.process.gpu unguarded
        # and could fail when the key is absent).
        # NOTE(review): the reverse model still uses to_gpu even when
        # use_chainerx is set, unlike the forward models above, which use
        # to_device — looks inconsistent; confirm intended behavior.
        if 'gpu' in config_eval.process and config_eval.process.gpu is not None:
            reverse_encdec = reverse_encdec.to_gpu(config_eval.process.gpu)
    else:
        reverse_encdec = None

    return encdec_list, eos_idx, src_indexer, tgt_indexer, reverse_encdec, model_infos_list
Пример #4
0
def create_encdec(config_eval):
    """Build the list of encoder/decoder models described by *config_eval*.

    Models may come from three places: the top-level ``training_config`` /
    ``trained_model`` pair, ``process.load_model_config`` entries, and
    ``process.additional_training_config`` / ``additional_trained_model``
    pairs.  All models must share compatible vocabularies; the shared
    ``eos_idx`` and indexers come from the first model loaded.

    Returns:
        (encdec_list, eos_idx, src_indexer, tgt_indexer, reverse_encdec,
         model_infos_list) — ``reverse_encdec`` is ``None`` unless
        ``process.reverse_training_config`` is set.

    Raises:
        Exception: if a reverse model's vocabulary sizes or eos index do not
            match the forward models'.
    """
    encdec_list = []
    eos_idx, src_indexer, tgt_indexer = None, None, None
    model_infos_list = []

    if config_eval.training_config is not None:
        assert config_eval.trained_model is not None
        encdec, eos_idx, src_indexer, tgt_indexer = create_and_load_encdec_from_files(
            config_eval.training_config, config_eval.trained_model)
        model_infos_list.append(create_filename_infos(config_eval.trained_model))
        encdec_list.append(encdec)

    if 'load_model_config' in config_eval.process and config_eval.process.load_model_config is not None:
        for config_filename in config_eval.process.load_model_config:
            log.info(
                "loading model and parameters from config %s" %
                config_filename)
            config_training = train_config.load_config_train(config_filename)
            (encdec, this_eos_idx, this_src_indexer, this_tgt_indexer), model_infos = train.create_encdec_and_indexers_from_config_dict(config_training,
                                                                                                                                        load_config_model="yes",
                                                                                                                                        return_model_infos=True)
            model_infos_list.append(model_infos)
            if eos_idx is None:
                # First model loaded overall: it defines the shared vocabulary.
                assert len(encdec_list) == 0
                assert src_indexer is None
                assert tgt_indexer is None
                eos_idx, src_indexer, tgt_indexer = this_eos_idx, this_src_indexer, this_tgt_indexer
            else:
                check_if_vocabulary_info_compatible(this_eos_idx, this_src_indexer, this_tgt_indexer, eos_idx, src_indexer, tgt_indexer)

            encdec_list.append(encdec)

    assert len(encdec_list) > 0

    if 'additional_training_config' in config_eval.process and config_eval.process.additional_training_config is not None:
        assert len(config_eval.process.additional_training_config) == len(config_eval.process.additional_trained_model)

        for (config_training_fn, trained_model_fn) in zip(config_eval.process.additional_training_config,
                                                          config_eval.process.additional_trained_model):
            this_encdec, this_eos_idx, this_src_indexer, this_tgt_indexer = create_and_load_encdec_from_files(
                config_training_fn, trained_model_fn)

            check_if_vocabulary_info_compatible(this_eos_idx, this_src_indexer, this_tgt_indexer, eos_idx, src_indexer, tgt_indexer)
            model_infos_list.append(create_filename_infos(trained_model_fn))

            encdec_list.append(this_encdec)

    if 'gpu' in config_eval.process and config_eval.process.gpu is not None:
        encdec_list = [encdec.to_gpu(config_eval.process.gpu) for encdec in encdec_list]

    if 'reverse_training_config' in config_eval.process and config_eval.process.reverse_training_config is not None:
        reverse_encdec, reverse_eos_idx, reverse_src_indexer, reverse_tgt_indexer = create_and_load_encdec_from_files(
            config_eval.process.reverse_training_config, config_eval.process.reverse_trained_model)

        if eos_idx != reverse_eos_idx:
            raise Exception("incompatible models")

        if len(src_indexer) != len(reverse_src_indexer):
            raise Exception("incompatible models")

        if len(tgt_indexer) != len(reverse_tgt_indexer):
            raise Exception("incompatible models")

        # Guard the attribute access like the 'gpu' check above (previously
        # this read config_eval.process.gpu unguarded and could fail when the
        # key is absent).
        if 'gpu' in config_eval.process and config_eval.process.gpu is not None:
            reverse_encdec = reverse_encdec.to_gpu(config_eval.process.gpu)
    else:
        reverse_encdec = None

    return encdec_list, eos_idx, src_indexer, tgt_indexer, reverse_encdec, model_infos_list