示例#1
0
def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""):
    """Reconcile hparams loaded from a previous run with the current defaults.

    Steps, in order:
      1. Apply the standard config at ``hparams_path`` (if any) to
         ``default_hparams`` via ``utils.maybe_parse_standard_hparams``.
      2. For old checkpoints that only carry ``num_layers``, back-fill
         ``num_encoder_layers`` / ``num_decoder_layers`` from it.
      3. Add every field that exists in the defaults but not in ``hparams``.
      4. If the defaults have a truthy ``override_loaded_hparams``, overwrite
         every default key whose loaded value differs, logging each change.

    Args:
      hparams: Loaded hparams object (must support ``values()`` and
        ``add_hparam()``); it is mutated in place.
      default_hparams: Hparams object holding the current defaults.
      hparams_path: Optional path passed to
        ``utils.maybe_parse_standard_hparams``.

    Returns:
      The same ``hparams`` object, updated.
    """
    defaults = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)

    # Older checkpoints stored a single shared ``num_layers``; derive the
    # split encoder/decoder layer counts from it when they are absent.
    if hasattr(hparams, "num_layers"):
        for split_key in ("num_encoder_layers", "num_decoder_layers"):
            if not hasattr(hparams, split_key):
                hparams.add_hparam(split_key, hparams.num_layers)

    # Snapshots taken before any additions, so adding fields below does not
    # change what counts as "missing".
    default_values = defaults.values()
    loaded_values = hparams.values()

    # Backward compatibility: fields introduced after the checkpoint was
    # written are copied in from the defaults.
    missing_names = [name for name in default_values if name not in loaded_values]
    for name in missing_names:
        hparams.add_hparam(name, default_values[name])

    # Loaded values win unless the defaults explicitly ask to override them.
    if not getattr(defaults, "override_loaded_hparams", None):
        return hparams

    for name in default_values.keys():
        new_value = default_values[name]
        old_value = getattr(hparams, name)
        if old_value != new_value:
            utils.print_out("# Updating hparams.%s: %s -> %s" %
                            (name, str(old_value), str(new_value)))
            setattr(hparams, name, new_value)
    return hparams
def ensure_compatible_hparams(hparams, default_hparams, flags):
    """Bring hparams loaded from a previous run up to date with the defaults.

    First the standard config at ``flags.hparams_path`` is applied to
    ``default_hparams``. Then any field present in the defaults but missing
    from ``hparams`` is added. Finally, a fixed list of run-specific keys
    (output dir, GPU count, decoding settings, special tokens, ...) is forced
    to the default value whenever it differs.

    Logging (the helper's ``verbose`` flag and the per-key update messages)
    is suppressed when ``flags.chat`` is truthy.

    Args:
      hparams: Loaded hparams object (must support ``values()`` and
        ``add_hparam()``); it is mutated in place.
      default_hparams: Hparams object holding the current defaults.
      flags: Object exposing ``hparams_path`` and ``chat`` attributes.

    Returns:
      The same ``hparams`` object, updated.
    """
    defaults = utils.maybe_parse_standard_hparams(
        default_hparams, flags.hparams_path, verbose=not flags.chat)
    verbose = not flags.chat

    # Snapshots taken before additions; new fields come from the defaults.
    default_values = defaults.values()
    loaded_values = hparams.values()
    missing_names = [name for name in default_values if name not in loaded_values]
    for name in missing_names:
        hparams.add_hparam(name, default_values[name])

    # These keys always take the current value, even when resuming a model.
    refreshed_keys = (
        "out_dir", "num_gpus", "test_prefix", "beam_width",
        "length_penalty_weight", "num_train_steps", "number_token",
        "name_token", "gpe_token", "UNAME", "TOKEN",
    )
    for name in refreshed_keys:
        if name not in default_values:
            continue
        new_value = default_values[name]
        old_value = getattr(hparams, name)
        if old_value != new_value:
            if verbose:
                utils.print_out("# Updating hparams.%s: %s -> %s" %
                                (name, str(old_value), str(new_value)))
            setattr(hparams, name, new_value)
    return hparams
示例#3
0
def create_or_load_hparams(default_hparams, hparams_path, save_hparams=True):
    """Create hparams from the defaults plus an optional standard config.

    Despite its name, this variant never loads hparams from disk. It always
    starts from ``default_hparams``, applies the config at ``hparams_path``
    via ``utils.maybe_parse_standard_hparams``, runs ``extend_hparams`` and
    prints the result.

    Args:
      default_hparams: Hparams object holding the defaults.
      hparams_path: Optional path passed to
        ``utils.maybe_parse_standard_hparams``.
      save_hparams: Unused; accepted only so existing callers that pass it
        keep working. Nothing is written to disk here.

    Returns:
      The parsed and extended hparams object.
    """
    # No previously saved hparams are loaded here, so there is nothing to
    # reconcile via ensure_compatible_hparams; only the "create" path applies.
    hparams = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
    hparams = extend_hparams(hparams)

    utils.print_hparams(hparams)
    return hparams
示例#4
0
def create_hparams_Alveo(default_hparams, hparams_path, save_hparams=True):
    """Create hparams from the defaults plus an optional standard config.

    The defaults are passed through ``utils.maybe_parse_standard_hparams``
    (using ``hparams_path``) and then ``extend_hparams``. Optionally the
    result is saved, and it is always printed. Nothing is loaded from disk
    besides the optional ``hparams_path`` config.

    Args:
      default_hparams: Hparams object holding the defaults.
      hparams_path: Optional path passed to
        ``utils.maybe_parse_standard_hparams``.
      save_hparams: When True, save the final hparams to ``hparams.out_dir``
        and to ``best_<metric>_dir`` for each entry in ``hparams.metrics``.

    Returns:
      The parsed and extended hparams object.
    """
    hparams = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
    hparams = extend_hparams(hparams)

    if save_hparams:
        # Read out_dir from the object actually being saved, so the target
        # stays correct even if the helpers above return a new object
        # instead of mutating default_hparams.
        utils.save_hparams(hparams.out_dir, hparams)
        for metric in hparams.metrics:
            utils.save_hparams(getattr(hparams, "best_" + metric + "_dir"),
                               hparams)

    utils.print_hparams(hparams)
    return hparams
示例#5
0
def create_or_load_hparams(out_dir, default_hparams, hparams_path):
    """Load hparams saved in ``out_dir``, or build them from the defaults.

    If ``utils.load_hparams(out_dir)`` returns a truthy object, it is
    reconciled with ``default_hparams`` via ``ensure_compatible_hparams``.
    Otherwise the defaults are parsed against ``hparams_path`` and passed
    through ``extend_hparams``. In both cases the result is then saved to
    ``out_dir`` and to each ``best_<metric>_dir``, and printed.

    Args:
      out_dir: Directory to load hparams from and save them to.
      default_hparams: Hparams object holding the current defaults.
      hparams_path: Optional standard-config path used when parsing defaults.

    Returns:
      The resulting hparams object.
    """
    loaded = utils.load_hparams(out_dir)
    if loaded:
        hparams = ensure_compatible_hparams(loaded, default_hparams,
                                            hparams_path)
    else:
        parsed = utils.maybe_parse_standard_hparams(default_hparams,
                                                    hparams_path)
        hparams = extend_hparams(parsed)

    # Persist to the run directory, then to each best-metric directory.
    utils.save_hparams(out_dir, hparams)
    for metric_name in hparams.metrics:
        best_dir = getattr(hparams, "best_" + metric_name + "_dir")
        utils.save_hparams(best_dir, hparams)

    utils.print_hparams(hparams)
    return hparams
def create_or_load_hparams(out_dir, default_hparams, flags):
    """Load hparams saved in ``out_dir``, or build them from the defaults.

    Mirrors the path-based variant, but reads ``hparams_path`` and ``chat``
    from ``flags``. When ``flags.chat`` is truthy, every helper is called
    with ``verbose=False`` and the final hparams dump is skipped.

    Args:
      out_dir: Directory to load hparams from and save them to.
      default_hparams: Hparams object holding the current defaults.
      flags: Object exposing ``hparams_path`` and ``chat`` attributes.

    Returns:
      The resulting hparams object.
    """
    verbose = not flags.chat

    loaded = utils.load_hparams(out_dir, verbose=verbose)
    if loaded:
        hparams = ensure_compatible_hparams(loaded, default_hparams, flags)
    else:
        # Nothing saved yet: start from the command-line defaults.
        parsed = utils.maybe_parse_standard_hparams(
            default_hparams, flags.hparams_path, verbose=verbose)
        hparams = extend_hparams(parsed)

    # Persist to the run directory, then to each best-metric directory.
    utils.save_hparams(out_dir, hparams, verbose=verbose)
    for metric_name in hparams.metrics:
        best_dir = getattr(hparams, "best_" + metric_name + "_dir")
        utils.save_hparams(best_dir, hparams, verbose=verbose)

    if verbose:
        utils.print_hparams(hparams)
    return hparams
示例#7
0
def ensure_compatible_hparams(hparams, default_hparams, hparams_path):
    """Make loaded hparams compatible with the current defaults.

    First the standard config at ``hparams_path`` is applied to
    ``default_hparams``. Then every field present in the defaults but
    missing from ``hparams`` is added. If the defaults carry a truthy
    ``override_loaded_hparams``, every default key whose loaded value
    differs is overwritten, and each change is logged.

    Args:
      hparams: Loaded hparams object (must support ``values()`` and
        ``add_hparam()``); it is mutated in place.
      default_hparams: Hparams object holding the current defaults.
      hparams_path: Optional path passed to
        ``utils.maybe_parse_standard_hparams``.

    Returns:
      The same ``hparams`` object, updated.
    """
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)

    # For backward compatibility, fields that exist in the defaults but not
    # in the loaded hparams are added from the defaults.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Defaults that lack ``override_loaded_hparams`` must not raise
    # AttributeError here; treat a missing field as "do not override".
    if getattr(default_hparams, "override_loaded_hparams", False):
        for key in default_config:
            if getattr(hparams, key) != default_config[key]:
                utils.print_out("# Updating hparams.%s: %s -> %s" %
                                (key, str(getattr(
                                    hparams, key)), str(default_config[key])))
                setattr(hparams, key, default_config[key])
    return hparams
示例#8
0
def load_hparams(hparams_path, default_hparams):
    """Return ``default_hparams`` with the standard config at ``hparams_path`` applied.

    This is a thin wrapper around ``utils.maybe_parse_standard_hparams``.
    Note that it takes the path first, the reverse of that helper's
    argument order.
    """
    parsed = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
    return parsed