Beispiel #1
0
  def to_proto(self, export_scope=None):  # pylint: disable=unused-argument
    """Converts a `HParams` object to a `HParamDef` protocol buffer.

    Every name recorded in `self._hparam_types` is written into the entry
    `hparam[name]` of the returned proto. The field used inside that entry
    is the kind returned by `HParams._get_kind_name(param_type, is_list)`.
    Values whose kind starts with 'bytes' are converted with
    `compat.as_bytes` first.

    Args:
      export_scope: Optional `string`. Name scope to remove. Not used by
        this implementation; accepted for interface compatibility.

    Returns:
      A `HParamDef` protocol buffer.
    """
    hparam_proto = hparam_pb2.HParamDef()
    # Every key comes straight from the dict, so unpack the recorded
    # (type, is_list) pair directly instead of a `.get()` with a dead default.
    for name, (param_type, is_list) in self._hparam_types.items():
      kind = HParams._get_kind_name(param_type, is_list)
      is_bytes = kind.startswith('bytes')
      value = getattr(self, name)  # Read the attribute exactly once.

      if is_list:
        if is_bytes:
          v_list = [compat.as_bytes(v) for v in value]
        else:
          v_list = list(value)
        # List kinds are appended to the `value` field of the sub-message
        # named by `kind`.
        getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
      else:
        if is_bytes:
          value = compat.as_bytes(value)
        setattr(hparam_proto.hparam[name], kind, value)

    return hparam_proto
Beispiel #2
0
def load_hparams(checkpoint_dir: str) -> tf.contrib.training.HParams:
    """Restore the hyperparameters saved next to a checkpoint.

    Reads ``hparams.pbtxt`` from ``checkpoint_dir`` as a text-format
    ``HParamDef``, and wraps it in ``tf.contrib.training.HParams``. The
    resulting values are then passed as keyword arguments to
    ``create_hparams``.

    Args:
        checkpoint_dir: Directory that contains ``hparams.pbtxt``.

    Returns:
        Whatever ``create_hparams`` returns for the saved values.
    """
    pbtxt_file = os.path.join(checkpoint_dir, 'hparams.pbtxt')
    with tf.gfile.GFile(pbtxt_file, 'r') as reader:
        contents = reader.read()
    proto = hparam_pb2.HParamDef()
    text_format.Merge(contents, proto)
    saved = tf.contrib.training.HParams(proto)
    # Rebuilding through create_hparams is intended to give defaults to any
    # hparams that are missing from the saved file.
    return create_hparams(**saved.values())
Beispiel #3
0
def _handle_hps_proto_file():
    """Copy hyperparameters from ``FLAGS.hps_proto_file`` onto default flags.

    Does nothing when ``FLAGS.hps_proto_file`` is empty. Otherwise:

    1. The file is parsed as a text-format ``HParamDef``.
    2. The result is converted with ``HParams.from_proto`` and passed through
       ``_maybe_upgrade_hparams``.
    3. Each resulting value is assigned to the flag of the same name, but only
       when that flag still reports ``using_default_value``. Every assignment
       is logged.

    NOTE(review): ``FLAGS[name]`` is presumably expected to raise if a
    hyperparameter has no matching flag; confirm this is intended.
    """
    if not FLAGS.hps_proto_file:
        return

    proto = hparam_pb2.HParamDef()
    with tf.gfile.GFile(FLAGS.hps_proto_file) as proto_file:
        text_format.Parse(proto_file.read(), proto)

    upgraded = _maybe_upgrade_hparams(
        contrib_training.HParams.from_proto(proto))
    for flag_name, flag_value in upgraded.values().items():
        flag = FLAGS[flag_name]
        if not flag.using_default_value:
            # Leave flags that were already set to something else untouched.
            continue
        logging.info('hps_proto FLAGS.%s = %r', flag_name, flag_value)
        flag.value = flag_value
Beispiel #4
0
def load_hparams(hparams_path: str) -> tf.contrib.training.HParams:
  """Builds an `HParams` object from a text-format `HParamDef` file.

  Args:
    hparams_path: Path of the text-format `HParamDef`, opened with
      `tf.gfile.GFile`.

  Returns:
    A `tf.contrib.training.HParams` constructed from the parsed proto.
  """
  with tf.gfile.GFile(hparams_path, 'r') as reader:
    text = reader.read()
  proto = hparam_pb2.HParamDef()
  text_format.Merge(text, proto)
  return tf.contrib.training.HParams(proto)