def create_or_load_hparams(load_dir, default_hparams, hparams_path,
                           save_hparams):
    """Load hparams saved in ``load_dir``, or build them from defaults.

    Args:
      load_dir: directory searched for previously saved hparams.
      default_hparams: hparams used when nothing is found in ``load_dir``;
        its ``out_dir`` is also where the result is saved.
      hparams_path: standard hparams config forwarded to
        ``utils.maybe_parse_standard_hparams`` (fresh run) or
        ``ensure_compatible_hparams`` (loaded run).
      save_hparams: when truthy, persist the result to
        ``default_hparams.out_dir`` and to every ``best_<metric>_dir``.

    Returns:
      The resulting hparams object, after printing it.
    """
    loaded = utils.load_hparams(load_dir)
    if loaded:
        # Reconcile the stored values with the current defaults/config.
        result = ensure_compatible_hparams(loaded, default_hparams,
                                           hparams_path)
        result = process_input_path(result)
    else:
        # Fresh run: apply the standard config file on top of the defaults,
        # resolve input paths, then derive the extended fields.
        result = utils.maybe_parse_standard_hparams(default_hparams,
                                                    hparams_path)
        result = process_input_path(result)
        result = extend_hparams(result)

    if save_hparams:
        utils.save_hparams(default_hparams.out_dir, result)
        for metric in result.metrics:
            best_dir = getattr(result, "best_" + metric + "_dir")
            utils.save_hparams(best_dir, result)

    utils.print_hparams(result)
    return result
def create_or_load_hparams(out_dir, default_hparams, flags):
    """Load hparams from ``out_dir`` or build them from ``default_hparams``.

    Args:
      out_dir: directory holding (or receiving) the saved hparams.
      default_hparams: hparams used when nothing is saved in ``out_dir``.
      flags: parsed command-line flags; ``flags.chat`` silences all
        logging/printing and ``flags.hparams_path`` names an optional
        standard hparams config.

    Returns:
      The resulting hparams, saved to ``out_dir`` and every
      ``best_<metric>_dir``.
    """
    verbose = not flags.chat

    hparams = utils.load_hparams(out_dir, verbose=verbose)
    if hparams:
        hparams = ensure_compatible_hparams(hparams, default_hparams, flags)
    else:
        # Nothing saved yet: start from the command-line defaults.
        parsed = utils.maybe_parse_standard_hparams(default_hparams,
                                                    flags.hparams_path,
                                                    verbose=verbose)
        hparams = extend_hparams(parsed)

    utils.save_hparams(out_dir, hparams, verbose=verbose)
    for metric in hparams.metrics:
        metric_dir = getattr(hparams, "best_" + metric + "_dir")
        utils.save_hparams(metric_dir, hparams, verbose=verbose)

    if verbose:
        utils.print_hparams(hparams)
    return hparams
Example #3
0
def create_or_load_hparams(out_dir, default_hparams, hparams_path):
    """Create hparams or load hparams from out_dir.

    Args:
      out_dir: directory holding (or receiving) the saved hparams.
      default_hparams: hparams used when nothing is saved in ``out_dir``.
      hparams_path: standard hparams config passed to
        ``utils.maybe_parse_standard_hparams`` (fresh run) or
        ``ensure_compatible_hparams`` (loaded run).

    Returns:
      The resulting hparams, after saving them to ``out_dir`` and to every
      ``best_<metric>_dir`` and printing them.

    When the module-level ``FLAGS.inference_input_file`` is set, the data and
    vocab paths are overridden to point into ``<out_dir>/../data``.
    ``out_dir`` and the ``best_<metric>_dir`` entries are reset under
    ``out_dir``.
    """
    hparams = utils.load_hparams(out_dir)
    if not hparams:
        hparams = default_hparams
        hparams = utils.maybe_parse_standard_hparams(hparams, hparams_path)
        hparams = extend_hparams(hparams)
    else:
        hparams = ensure_compatible_hparams(hparams, default_hparams,
                                            hparams_path)

    if FLAGS.inference_input_file:
        def data_path(name):
            # Produces exactly os.path.join(out_dir, "../data/<name>").
            return os.path.join(out_dir, "../data/" + name)

        hparams.src_vocab_file = data_path("vocab.cor")
        hparams.tgt_vocab_file = data_path("vocab.man")
        hparams.out_dir = out_dir
        hparams.best_bleu_dir = os.path.join(out_dir, "best_bleu")
        # The save loop below writes to every best_<metric>_dir, so keep all
        # of them under the new out_dir (not only best_bleu_dir).
        for metric in hparams.metrics:
            setattr(hparams, "best_" + metric + "_dir",
                    os.path.join(out_dir, "best_" + metric))
        hparams.train_prefix = data_path("train")
        hparams.dev_prefix = data_path("dev_test")
        hparams.vocab_prefix = data_path("vocab")
        hparams.rc_vocab_file = data_path("vocab.cor")
        hparams.test_prefix = data_path("test")

    # Save HParams
    utils.save_hparams(out_dir, hparams)

    for metric in hparams.metrics:
        utils.save_hparams(getattr(hparams, "best_" + metric + "_dir"), hparams)

    # Print HParams
    utils.print_hparams(hparams)
    return hparams
Example #4
0
def create_or_load_hparams(out_dir, default_hparams, save_hparams=True):
    """Return hparams loaded from ``out_dir``, or extended defaults.

    Args:
      out_dir: directory searched for (and optionally receiving) hparams.
      default_hparams: passed through ``extend_hparams`` when nothing is
        loaded from ``out_dir``.
      save_hparams: when truthy, write the result back to ``out_dir``.

    Returns:
      The resulting hparams, after printing them.
    """
    hparams = utils.load_hparams(out_dir) or extend_hparams(default_hparams)

    if save_hparams:
        utils.save_hparams(out_dir, hparams)

    utils.print_hparams(hparams)
    return hparams
Example #5
0
def run_prediction(input_file_path, output_file_path):
    """Word-segment ``input_file_path`` and decode it with the saved model.

    The input is segmented with ``word_split``/``jieba_split`` into the
    intermediate file ``'input_file'``. Hparams and the model are loaded
    from ``'jb_attention'``, and the inference iterator is decoded with a
    batch size of 1 until it is exhausted.

    Args:
      input_file_path: raw input text file.
      output_file_path: currently unused.
        NOTE(review): nothing is written here; callers presumably expect
        translations in this file — confirm and implement.

    Returns:
      A list with the first value returned by each ``model.decode`` call,
      one entry per batch.

    Raises:
      ValueError: if no hparams are found or the attention architecture is
        unknown.
    """
    infile = 'input_file'
    word_split(input_file_path, infile, jieba_split)

    model_dir = 'jb_attention'
    hparams = utils.load_hparams(model_dir)
    if not hparams:
        raise ValueError('No hparams found in %s' % model_dir)
    hparams.inference_indices = list(range(150))
    sample_src_dataset = inference.load_data(infile)

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif (hparams.encoder_type == 'gnmt'
          or hparams.attention_architecture in ['gnmt', 'gnmt_v2']):
        model_creator = gnmt_model.GNMTModel
    elif hparams.attention_architecture == 'standard':
        model_creator = attention_model.AttentionModel
    else:
        raise ValueError('Unknown attention architecture %s' %
                         (hparams.attention_architecture))

    infer_model = model_helper.create_infer_model(model_creator,
                                                  hparams,
                                                  scope=None)

    config_proto = utils.get_config_proto(
        log_device_placement=hparams.log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)

    infer_sess = tf.Session(target='',
                            config=config_proto,
                            graph=infer_model.graph)
    # Always release the session's resources, even if decoding fails.
    try:
        with infer_model.graph.as_default():
            loaded_infer_model, _ = model_helper.create_or_load_model(
                infer_model.model, model_dir, infer_sess, 'infer')

        iterator_feed_dict = {
            infer_model.src_placeholder: sample_src_dataset,
            infer_model.batch_size_placeholder: 1,
        }
        infer_sess.run(infer_model.iterator.initializer,
                       feed_dict=iterator_feed_dict)

        outputs = []
        while True:
            try:
                # decode() lives on the loaded model, not on the
                # InferModel bundle returned by create_infer_model.
                nmt_outputs, _ = loaded_infer_model.decode(infer_sess)
            except tf.errors.OutOfRangeError:
                # The iterator is exhausted: every sentence was decoded.
                break
            outputs.append(nmt_outputs)
    finally:
        infer_sess.close()
    return outputs
Example #6
0
    def __init__(self, data_dir, id_speaker, num_translations_per_input=1,
                 scope=None):
        """Load hparams from ``data_dir`` and build a GNMT inference model.

        Args:
          data_dir: directory holding the saved hparams and checkpoints.
          id_speaker: identifier stored as ``self.id``.
          num_translations_per_input: stored unchanged on the instance.
          scope: forwarded to ``model_helper.create_infer_model``.

        Raises:
          ValueError: if no hparams could be loaded from ``data_dir``.
        """
        self.data_dir = data_dir
        self.hparams = utils.load_hparams(self.data_dir)
        if not self.hparams:
            # Presumably load_hparams returns None when nothing is saved;
            # fail clearly instead of with an AttributeError on None below.
            raise ValueError("No hparams found in %s" % data_dir)
        self.hparams.infer_batch_size = 1
        # None when data_dir contains no checkpoint.
        self.ckpt = tf.train.latest_checkpoint(data_dir)

        self.id = id_speaker
        self.num_translations_per_input = num_translations_per_input
        model_creator = gnmt_model.GNMTModel

        self.infer_model = model_helper.create_infer_model(model_creator,
                                                           self.hparams,
                                                           scope)
Example #7
0
def create_or_load_hparams(out_dir, default_hparams, hparams_path):
    """Create hparams or load hparams from out_dir.

    Args:
      out_dir: directory holding (or receiving) the saved hparams.
      default_hparams: hparams used when nothing is saved in ``out_dir``.
      hparams_path: standard hparams config passed to
        ``utils.maybe_parse_standard_hparams`` (fresh run) or
        ``ensure_compatible_hparams`` (loaded run).

    Returns:
      The resulting hparams, after saving them to ``out_dir`` and to each
      ``best_<metric>_dir`` and printing them.
    """
    hparams = utils.load_hparams(out_dir)
    if not hparams:
        hparams = default_hparams
        hparams = utils.maybe_parse_standard_hparams(hparams, hparams_path)
        hparams = extend_hparams(hparams)
    else:
        hparams = ensure_compatible_hparams(hparams, default_hparams,
                                            hparams_path)

    # Save HParams
    utils.save_hparams(out_dir, hparams)

    # One copy per metric, into that metric's own best-checkpoint directory.
    for metric in hparams.metrics:
        utils.save_hparams(getattr(hparams, "best_" + metric + "_dir"),
                           hparams)

    # Print HParams
    utils.print_hparams(hparams)
    return hparams
Example #8
0
def modelInit(out_dir='tmp/nmt_attention_model'):
    """Load hparams from ``out_dir`` and build the matching inference model.

    Args:
      out_dir: trained model directory holding hparams and checkpoints.

    Returns:
      ``(infer_model, ckpt, hparams)``. ``ckpt`` is the latest checkpoint
      path in ``out_dir``, or ``None`` if there is none.

    Raises:
      ValueError: if no hparams are found or the attention architecture is
        unknown.
    """
    print("loading parameters")
    hparams = utils.load_hparams(out_dir)
    if not hparams:
        raise ValueError("No hparams found in %s" % out_dir)
    ckpt = tf.train.latest_checkpoint(out_dir)

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture %s" %
                         hparams.attention_architecture)

    infer_model = model_helper.create_infer_model(model_creator,
                                                  hparams,
                                                  scope=None)
    return infer_model, ckpt, hparams
Example #9
0
def create_or_load_hparams(out_dir,
                           default_hparams,
                           hparams_path,
                           save_hparams=True):
    """Return hparams from ``out_dir``, or freshly built ones if absent.

    Args:
      out_dir: directory searched for (and optionally receiving) hparams.
      default_hparams: starting point when nothing is saved in ``out_dir``.
      hparams_path: standard hparams config handed to
        ``utils.maybe_parse_standard_hparams`` (fresh run) or
        ``ensure_compatible_hparams`` (loaded run).
      save_hparams: when truthy, write the result to ``out_dir`` and to each
        ``best_<metric>_dir``.

    Returns:
      The resulting hparams, after printing them.
    """
    existing = utils.load_hparams(out_dir)
    if existing:
        hparams = ensure_compatible_hparams(existing, default_hparams,
                                            hparams_path)
    else:
        parsed = utils.maybe_parse_standard_hparams(default_hparams,
                                                    hparams_path)
        hparams = extend_hparams(parsed)

    if save_hparams:
        utils.save_hparams(out_dir, hparams)
        for metric in hparams.metrics:
            metric_dir = getattr(hparams, "best_" + metric + "_dir")
            utils.save_hparams(metric_dir, hparams)

    utils.print_hparams(hparams)
    return hparams
Example #10
0
def _count_lines(path):
    """Return the number of lines in the UTF-8 text file at ``path``."""
    # Stream the file instead of materializing every line with readlines().
    with open(path, 'r', encoding='utf-8') as f:
        return sum(1 for _ in f)


def create_or_load_hparams(out_dir, default_hparams):
    """
    Create hparams or load hparams from out_dir.

    If ``utils.load_hparams(out_dir)`` returns nothing, ``default_hparams``
    is extended in place (via ``add_hparam``) with:
      * ``best_bleu``/``best_bleu_dir`` and ``avg_best_bleu``/
        ``avg_best_bleu_dir``; both directories are created if missing.
      * ``num_train_steps``: the smaller line count of the train source and
        target files, times ``epochs``.
      * Encoder/decoder layer counts and residual-layer counts.
      * Source/target vocab sizes and files, checked by
        ``vocab_utils.check_vocab``.
      * ``src_embed_file``/``tgt_embed_file``, set only when the files
        derived from ``embed_prefix`` exist.
    In both cases the hparams are then saved to ``out_dir``.

    Args:
      out_dir: directory to load saved hparams from and save them to.
      default_hparams: hparams to extend when nothing is saved yet. They are
        mutated in place.

    Returns:
      The loaded or newly extended hparams.

    Raises:
      ValueError: if new hparams are built and ``vocab_prefix`` is empty.
    """

    hparams = utils.load_hparams(out_dir)
    if not hparams:
        hparams = default_hparams

        # Best-checkpoint bookkeeping. exist_ok lets a rerun into an
        # existing out_dir (without saved hparams) proceed.
        # NOTE(review): best_bleu_dir is rooted at ``out_dir`` while
        # avg_best_bleu_dir is rooted at ``hparams.out_dir``; presumably the
        # same directory — confirm.
        hparams.add_hparam("best_bleu", 0)
        best_bleu_dir = os.path.join(out_dir, "best_bleu")
        hparams.add_hparam("best_bleu_dir", best_bleu_dir)
        os.makedirs(best_bleu_dir, exist_ok=True)
        hparams.add_hparam("avg_best_bleu", 0)
        avg_best_bleu_dir = os.path.join(hparams.out_dir, "avg_best_bleu")
        hparams.add_hparam("avg_best_bleu_dir", avg_best_bleu_dir)
        os.makedirs(avg_best_bleu_dir, exist_ok=True)

        # Set num_train_steps.
        # NOTE(review): this is sentence pairs * epochs, not batches * epochs;
        # confirm whether a division by batch size is intended.
        train_src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
        train_tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
        train_src_steps = _count_lines(train_src_file)
        train_tgt_steps = _count_lines(train_tgt_file)
        hparams.add_hparam(
            "num_train_steps",
            min(train_src_steps, train_tgt_steps) * hparams.epochs)

        # Set encoder/decoder layers
        hparams.add_hparam("num_encoder_layers", hparams.num_layers)
        hparams.add_hparam("num_decoder_layers", hparams.num_layers)

        # Set residual layers.
        # The first unidirectional layer (after the bi-directional layer) in
        # the GNMT encoder can't have a residual connection because its input
        # is the concatenation of fw_cell and bw_cell's outputs, hence
        # (layers - 2). Clamp at 0 so single-layer models never get a
        # negative count.
        num_encoder_residual_layers = max(hparams.num_encoder_layers - 2, 0)
        num_decoder_residual_layers = max(hparams.num_decoder_layers - 1, 0)

        # Compatible for GNMT models
        if hparams.num_encoder_layers == hparams.num_decoder_layers:
            num_decoder_residual_layers = num_encoder_residual_layers

        hparams.add_hparam("num_encoder_residual_layers",
                           num_encoder_residual_layers)
        hparams.add_hparam("num_decoder_residual_layers",
                           num_decoder_residual_layers)

        # Vocab
        # Get vocab file names first
        if hparams.vocab_prefix:
            src_vocab_file = hparams.vocab_prefix + "." + hparams.src
            tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
        else:
            raise ValueError("hparams.vocab_prefix must be provided.")
        # Source vocab
        src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
            src_vocab_file,
            hparams.out_dir,
            sos=hparams.sos,
            eos=hparams.eos,
            unk=vocab_utils.UNK)
        # Target vocab
        if hparams.share_vocab:
            utils.log("Using source vocab for target")
            tgt_vocab_file = src_vocab_file
            tgt_vocab_size = src_vocab_size
        else:
            tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
                tgt_vocab_file,
                hparams.out_dir,
                sos=hparams.sos,
                eos=hparams.eos,
                unk=vocab_utils.UNK)
        hparams.add_hparam("src_vocab_size", src_vocab_size)
        hparams.add_hparam("tgt_vocab_size", tgt_vocab_size)
        hparams.add_hparam("src_vocab_file", src_vocab_file)
        hparams.add_hparam("tgt_vocab_file", tgt_vocab_file)

        # Pretrained embeddings: only used when the derived files exist.
        hparams.add_hparam("src_embed_file", "")
        hparams.add_hparam("tgt_embed_file", "")
        if hparams.embed_prefix:
            src_embed_file = hparams.embed_prefix + "." + hparams.src
            tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt
            if os.path.exists(src_embed_file):
                hparams.src_embed_file = src_embed_file
            if os.path.exists(tgt_embed_file):
                hparams.tgt_embed_file = tgt_embed_file

    # Save HParams
    utils.save_hparams(out_dir, hparams)

    return hparams