コード例 #1
0
def ensure_compatible_hparams(hparams, default_hparams, hparams_path):
    """Reconcile hparams restored from disk with this run's defaults.

    Args:
      hparams: hparams object loaded from a previous run; mutated in place.
      default_hparams: hparams built for the current run. They are first
        passed through ``utils.maybe_parse_standard_hparams`` together with
        ``hparams_path``.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams``.

    Returns:
      ``hparams`` after two steps. First, every field that exists only in
      the defaults is added. Second, a fixed set of run-specific keys is
      overwritten with the default value whenever the two differ; each
      change is logged.
    """
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)
    defaults = default_hparams.values()
    loaded = hparams.values()

    # Backfill fields that the stored hparams do not have yet. `loaded` is
    # a snapshot, so computing the list up front sees the same keys.
    missing = [name for name in defaults if name not in loaded]
    for name in missing:
        hparams.add_hparam(name, defaults[name])

    # These keys always take the value from the current run.
    refreshed_keys = ("out_dir", "num_gpus", "test_prefix", "beam_width",
                      "length_penalty_weight", "num_train_steps")
    for name in refreshed_keys:
        if name not in defaults:
            continue
        old_value = getattr(hparams, name)
        new_value = defaults[name]
        if old_value != new_value:
            utils.print_out("# Updating hparams.%s: %s -> %s" %
                            (name, str(old_value), str(new_value)))
            setattr(hparams, name, new_value)
    return hparams
コード例 #2
0
def create_or_load_hparams(default_hparams, hparams_path):
  """Build this run's hparams from the defaults and print them.

  Despite the name, this variant never loads anything from an output
  directory. It runs the defaults through
  ``utils.maybe_parse_standard_hparams`` with ``hparams_path``, then through
  ``extend_hparams``, and prints the result.

  Args:
    default_hparams: base hparams for this run.
    hparams_path: forwarded to ``utils.maybe_parse_standard_hparams``.

  Returns:
    The hparams object returned by ``extend_hparams``.
  """
  parsed = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
  extended = extend_hparams(parsed)
  utils.print_hparams(extended)
  return extended
コード例 #3
0
def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""):
    """Make sure the loaded hparams are compatible with new changes.

    Args:
      hparams: hparams object loaded from a previous run; mutated in place.
      default_hparams: hparams for the current run. They are passed through
        ``utils.maybe_parse_standard_hparams`` with ``hparams_path`` before
        use.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams``.

    Returns:
      The updated ``hparams`` object:
        * ``num_encoder_layers`` / ``num_decoder_layers`` are derived from
          ``num_layers`` when only the latter is present.
        * Fields that exist only in the defaults are added.
        * Values are overwritten from the defaults for either every default
          key (when ``override_loaded_hparams`` is truthy) or the keys in the
          module-level ``INFERENCE_KEYS``. Keys absent from the defaults are
          skipped.
    """
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)

    # Set num encoder/decoder layers (for old checkpoints that only carry
    # a single `num_layers` value).
    if hasattr(hparams, "num_layers"):
        if not hasattr(hparams, "num_encoder_layers"):
            hparams.add_hparam("num_encoder_layers", hparams.num_layers)
        if not hasattr(hparams, "num_decoder_layers"):
            hparams.add_hparam("num_decoder_layers", hparams.num_layers)

    # For compatibility, if there are new fields in default_hparams,
    # add them to the current hparams.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Update all hparams' keys if override_loaded_hparams=True.
    if getattr(default_hparams, "override_loaded_hparams", None):
        overwritten_keys = default_config.keys()
    else:
        # For inference
        overwritten_keys = INFERENCE_KEYS

    for key in overwritten_keys:
        # INFERENCE_KEYS is defined elsewhere and may name a field the
        # current defaults lack. Skip such keys instead of raising KeyError,
        # matching the `key in default_config` guard in sibling variants.
        if key not in default_config:
            continue
        if getattr(hparams, key) != default_config[key]:
            utils.print_out("# Updating hparams.%s: %s -> %s" %
                            (key, str(getattr(hparams, key)),
                             str(default_config[key])))
            setattr(hparams, key, default_config[key])
    return hparams
コード例 #4
0
def create_or_load_hparams(load_dir, default_hparams, hparams_path,
                           save_hparams):
    """Return hparams restored from ``load_dir`` or freshly built from defaults.

    Args:
      load_dir: directory passed to ``utils.load_hparams``.
      default_hparams: hparams for the current run; also supplies the
        ``out_dir`` used when saving.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams`` /
        ``ensure_compatible_hparams``.
      save_hparams: if truthy, the result is written to
        ``default_hparams.out_dir`` and to each ``best_<metric>_dir``.

    Returns:
      The resulting hparams object (also printed).
    """
    loaded = utils.load_hparams(load_dir)
    if loaded:
        # Existing run: reconcile the stored values with the current ones.
        hparams = ensure_compatible_hparams(loaded, default_hparams,
                                            hparams_path)
        hparams = process_input_path(hparams)
    else:
        # Fresh run: apply any standard hparams file to the defaults, then
        # resolve input paths and derive the extra fields.
        hparams = utils.maybe_parse_standard_hparams(default_hparams,
                                                     hparams_path)
        hparams = process_input_path(hparams)
        hparams = extend_hparams(hparams)

    if save_hparams:
        # The primary copy goes to the current run's out_dir (taken from
        # default_hparams), not to load_dir.
        utils.save_hparams(default_hparams.out_dir, hparams)
        for metric in hparams.metrics:
            best_dir = getattr(hparams, "best_" + metric + "_dir")
            utils.save_hparams(best_dir, hparams)

    utils.print_hparams(hparams)
    return hparams
def create_or_load_hparams(out_dir, default_hparams, flags):
    """Return hparams restored from ``out_dir`` or built from defaults.

    Logging is suppressed throughout when ``flags.chat`` is truthy.

    Args:
      out_dir: directory to load from and save to.
      default_hparams: hparams for the current run.
      flags: flags object. This function reads ``flags.chat`` and
        ``flags.hparams_path`` and passes the whole object to
        ``ensure_compatible_hparams``.

    Returns:
      The resulting hparams object, after it has been saved to ``out_dir``
      and to each ``best_<metric>_dir``.
    """
    verbose = not flags.chat
    hparams = utils.load_hparams(out_dir, verbose=verbose)
    if hparams:
        # NOTE(review): `flags` is passed where other variants take
        # hparams_path; presumably this project's ensure_compatible_hparams
        # accepts the flags object — verify.
        hparams = ensure_compatible_hparams(hparams, default_hparams, flags)
    else:
        # Start from the command-line defaults, then apply the standard
        # hparams file (if any) and derive the extra fields.
        hparams = utils.maybe_parse_standard_hparams(default_hparams,
                                                     flags.hparams_path,
                                                     verbose=verbose)
        hparams = extend_hparams(hparams)

    utils.save_hparams(out_dir, hparams, verbose=verbose)
    for metric in hparams.metrics:
        best_dir = getattr(hparams, "best_" + metric + "_dir")
        utils.save_hparams(best_dir, hparams, verbose=verbose)

    if verbose:
        utils.print_hparams(hparams)
    return hparams
コード例 #6
0
def create_or_load_hparams(out_dir,
                           default_hparams,
                           hparams_path,
                           save_hparams=True):
    """Create fresh hparams for this run and optionally save them.

    Despite the name, this variant never loads previously saved hparams. It
    always starts from ``default_hparams``, applies
    ``utils.maybe_parse_standard_hparams`` with ``hparams_path``, and runs
    ``extend_hparams``.

    Args:
      out_dir: directory the hparams are saved to when ``save_hparams``.
      default_hparams: base hparams for this run.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams``.
      save_hparams: if truthy, save to ``out_dir`` and to each
        ``best_<metric>_dir``.

    Returns:
      The resulting hparams object (also printed).
    """
    print('[new hparams]\n')
    hparams = default_hparams
    hparams = utils.maybe_parse_standard_hparams(hparams, hparams_path)
    hparams = extend_hparams(hparams)
    # NOTE: loading existing hparams from out_dir (utils.load_hparams followed
    # by ensure_compatible_hparams) was disabled in this variant. The dead
    # string literal that held that code has been removed, so every call
    # builds new hparams.

    # Save HParams
    if save_hparams:
        utils.save_hparams(out_dir, hparams)
        for metric in hparams.metrics:
            utils.save_hparams(getattr(hparams, "best_" + metric + "_dir"),
                               hparams)

    # Print HParams
    utils.print_hparams(hparams)
    return hparams
コード例 #7
0
ファイル: interface.py プロジェクト: jiniaoxu/chichat
def create_or_load_hparams(out_dir, default_hparams, hparams_path):
    """Return hparams restored from ``out_dir`` or built from defaults.

    When the module-level ``FLAGS.inference_input_file`` is set, several
    vocab, data and output locations are re-pointed. Each becomes
    ``out_dir`` itself or a path joined onto ``out_dir`` (e.g.
    ``../data/vocab.cor``). The hparams are then saved to ``out_dir`` and to
    each ``best_<metric>_dir``, and printed.

    Args:
      out_dir: directory to load from, save to, and resolve paths against.
      default_hparams: hparams for the current run.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams`` /
        ``ensure_compatible_hparams``.

    Returns:
      The resulting hparams object.
    """
    hparams = utils.load_hparams(out_dir)
    if hparams:
        hparams = ensure_compatible_hparams(hparams, default_hparams, hparams_path)
    else:
        hparams = utils.maybe_parse_standard_hparams(default_hparams,
                                                     hparams_path)
        hparams = extend_hparams(hparams)

    if FLAGS.inference_input_file:
        hparams.out_dir = out_dir
        # Attribute name -> location relative to out_dir.
        # NOTE(review): `rc_vocab_file` is set to the same file as
        # `src_vocab_file`; confirm this attribute name is intended.
        relative_locations = (
            ("src_vocab_file", "../data/vocab.cor"),
            ("tgt_vocab_file", "../data/vocab.man"),
            ("best_bleu_dir", "best_bleu"),
            ("train_prefix", "../data/train"),
            ("dev_prefix", "../data/dev_test"),
            ("vocab_prefix", "../data/vocab"),
            ("rc_vocab_file", "../data/vocab.cor"),
            ("test_prefix", "../data/test"),
        )
        for attr_name, relative in relative_locations:
            setattr(hparams, attr_name, os.path.join(out_dir, relative))

    utils.save_hparams(out_dir, hparams)
    for metric in hparams.metrics:
        best_dir = getattr(hparams, "best_" + metric + "_dir")
        utils.save_hparams(best_dir, hparams)

    utils.print_hparams(hparams)
    return hparams
コード例 #8
0
def ensure_compatible_hparams(hparams, default_hparams, hparams_path):
  """Bring hparams restored from a checkpoint in line with this run's defaults.

  First, fields that exist only in the defaults are added to ``hparams``.
  The defaults are first passed through ``utils.maybe_parse_standard_hparams``
  with ``hparams_path``. Then every key in ``refreshed`` below that is
  present in the defaults is overwritten with the default value whenever the
  two differ. Each change is logged via ``utils.print_out``.

  Args:
    hparams: hparams object loaded from a previous run; mutated in place.
    default_hparams: hparams for the current run.
    hparams_path: forwarded to ``utils.maybe_parse_standard_hparams``.

  Returns:
    The updated ``hparams`` object.
  """
  current = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
  current_values = current.values()
  stored_values = hparams.values()

  # `stored_values` is a snapshot, so precomputing the missing keys is
  # equivalent to checking them one by one.
  for name in [k for k in current_values if k not in stored_values]:
    hparams.add_hparam(name, current_values[name])

  # Keys whose value always comes from the current run. The tuple order is
  # the order in which update messages are printed.
  # NOTE(review): "train_threadhold" looks like a misspelling of
  # "train_threshold", but it must match the hparam name defined elsewhere.
  refreshed = (
      "out_dir", "num_gpus", "beam_width", "length_penalty_weight",
      "num_train_steps",
      "train_data", "train_kb", "dev_data", "dev_kb",
      "infer_src_data", "infer_tar_data", "infer_kb",
      "self_play_eval_data", "self_play_eval_kb",
      "self_play_train_data", "self_play_train_kb",
      "vocab_file",
      "max_dialogue_len", "max_inference_len", "num_kb_fields_per_entry",
      "len_action", "self_play_model_dir", "max_dialogue_turns",
      "train_threadhold", "reward_discount", "do_selfplay",
      "self_play_batch_size", "self_play_update_batch_size",
      "self_play_eval_batch_size", "inference_output_file", "task_type",
      "self_play_pretrain_dir", "learning_rate",
      "colocate_gradients_with_ops", "immutable_model_reload_freq",
      "optimizer", "self_play_loss_method", "self_play_variable_method",
      "self_play_sl_multiplier", "batch_size", "log_device_placement",
      "metrics", "self_play_immutable_gpu", "learning_rate2",
      "learning_rate3", "infer_batch_size", "steps_per_stats",
      "train_reward_type", "rl_training",
      "dev_infer_src_data", "dev_infer_tar_data", "dev_infer_kb",
      "dev_self_play_eval_data", "dev_self_play_eval_kb",
      "test_infer_src_data", "test_infer_tar_data", "test_infer_kb",
      "test_self_play_eval_data", "test_self_play_eval_kb",
      "eval_prefix", "eval_forever", "selfplay_eval_output_file",
      "num_self_play_train_steps", "codalab",
  )
  for name in refreshed:
    if name not in current_values:
      continue
    previous = getattr(hparams, name)
    if previous != current_values[name]:
      utils.print_out(
          "# Updating hparams.%s: %s -> %s" %
          (name, str(previous), str(current_values[name])))
      setattr(hparams, name, current_values[name])
  return hparams
コード例 #9
0
def ensure_compatible_hparams(hparams, default_hparams, hparams_path=None):
    """Make loaded hparams compatible with the current defaults.

    Args:
      hparams: hparams object loaded from a previous run; mutated in place.
      default_hparams: hparams for the current run. They are passed through
        ``utils.maybe_parse_standard_hparams`` with ``hparams_path`` before
        use.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams``.

    Returns:
      ``hparams`` with every field missing from it copied from the defaults.
      If the defaults have a truthy ``override_loaded_hparams``, every
      default value that differs is also copied over, with each change
      logged.
    """
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)

    # Add fields introduced since the loaded hparams were saved.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # getattr with a default: defaults built without this field skip the
    # override instead of raising AttributeError.
    if getattr(default_hparams, "override_loaded_hparams", None):
        for key in default_config:
            if getattr(hparams, key) != default_config[key]:
                utils.print_out("# Updating hparams.%s: %s -> %s" %
                                (key, str(getattr(
                                    hparams, key)), str(default_config[key])))
                setattr(hparams, key, default_config[key])
    return hparams
コード例 #10
0
def create_or_load_hparams(out_dir, default_hparams, hparams_path):
    """Create hparams or load hparams from out_dir.

    If ``utils.load_hparams(out_dir)`` returns something truthy, it is
    reconciled with the defaults via ``ensure_compatible_hparams``.
    Otherwise the defaults are run through
    ``utils.maybe_parse_standard_hparams`` and ``extend_hparams``. The result
    is saved to ``out_dir`` and to the ``best_<metric>_dir`` of every entry
    in ``hparams.metrics``, then printed.

    Args:
      out_dir: directory to load from and save to.
      default_hparams: hparams for the current run.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams`` /
        ``ensure_compatible_hparams``.

    Returns:
      The resulting hparams object.
    """
    hparams = utils.load_hparams(out_dir)
    if not hparams:
        hparams = default_hparams
        hparams = utils.maybe_parse_standard_hparams(hparams, hparams_path)
        hparams = extend_hparams(hparams)
    else:
        hparams = ensure_compatible_hparams(hparams, default_hparams,
                                            hparams_path)

    # Save HParams
    utils.save_hparams(out_dir, hparams)

    # Save a copy into each metric's best-model directory. Previously every
    # iteration wrote to `best_bleu_dir` regardless of `metric`.
    for metric in hparams.metrics:
        utils.save_hparams(getattr(hparams, "best_" + metric + "_dir"),
                           hparams)

    # Print HParams
    utils.print_hparams(hparams)
    return hparams
コード例 #11
0
def create_or_load_hparams(out_dir,
                           default_hparams,
                           hparams_path,
                           save_hparams=True):
    """Return hparams restored from ``out_dir`` or built from the defaults.

    Args:
      out_dir: directory to load from and (optionally) save to.
      default_hparams: hparams for the current run.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams`` /
        ``ensure_compatible_hparams``.
      save_hparams: if truthy, save to ``out_dir`` and to each
        ``best_<metric>_dir``.

    Returns:
      The resulting hparams object (also printed).
    """
    restored = utils.load_hparams(out_dir)
    if restored:
        hparams = ensure_compatible_hparams(restored, default_hparams,
                                            hparams_path)
    else:
        parsed = utils.maybe_parse_standard_hparams(default_hparams,
                                                    hparams_path)
        hparams = extend_hparams(parsed)

    if save_hparams:
        utils.save_hparams(out_dir, hparams)
        for metric in hparams.metrics:
            best_dir = getattr(hparams, "best_" + metric + "_dir")
            utils.save_hparams(best_dir, hparams)

    utils.print_hparams(hparams)
    return hparams
コード例 #12
0
ファイル: nmt_eval.py プロジェクト: zmxdream/parallax
def ensure_compatible_hparams(hparams, default_hparams, hparams_path):
    """Make sure the loaded hparams are compatible with new changes.

    Args:
      hparams: hparams object loaded from a previous run; mutated in place.
      default_hparams: hparams for the current run. They are passed through
        ``utils.maybe_parse_standard_hparams`` with ``hparams_path`` before
        use.
      hparams_path: forwarded to ``utils.maybe_parse_standard_hparams``.

    Returns:
      ``hparams`` with every field missing from it copied from the defaults.
      If the defaults have a truthy ``override_loaded_hparams``, every
      default value that differs is also copied over, with each change
      logged.
    """
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)

    # For compatibility, if there are new fields in default_hparams,
    # add them to the current hparams.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Update all hparams' keys if override_loaded_hparams=True. getattr with
    # a default avoids AttributeError when the defaults lack this field.
    if getattr(default_hparams, "override_loaded_hparams", None):
        for key in default_config:
            if getattr(hparams, key) != default_config[key]:
                utils.print_out("# Updating hparams.%s: %s -> %s" %
                                (key, str(getattr(
                                    hparams, key)), str(default_config[key])))
                setattr(hparams, key, default_config[key])
    return hparams