def to_proto(self, export_scope=None):  # pylint: disable=unused-argument
    """Converts a `HParams` object to a `HParamDef` protocol buffer.

    Every hyperparameter registered in `self._hparam_types` is written into
    the `hparam` map of the returned proto under its own name. The field
    that receives the value is chosen by `HParams._get_kind_name`. For kinds
    whose name starts with `bytes`, each value is first converted with
    `compat.as_bytes`.

    Args:
        export_scope: Optional `string`. Name scope to remove. Not used by
            this implementation.

    Returns:
        A `HParamDef` protocol buffer.
    """
    hparam_proto = hparam_pb2.HParamDef()
    # Iterating items() directly: every name being visited is a key of the
    # dict, so no fallback lookup is needed.
    for name, (param_type, is_list) in self._hparam_types.items():
        kind = HParams._get_kind_name(param_type, is_list)
        to_bytes = kind.startswith('bytes')
        value = getattr(self, name)
        if is_list:
            if to_bytes:
                items = [compat.as_bytes(v) for v in value]
            else:
                items = list(value)
            # List kinds expose a repeated `value` field that is extended
            # in place.
            getattr(hparam_proto.hparam[name], kind).value.extend(items)
        else:
            if to_bytes:
                value = compat.as_bytes(value)
            setattr(hparam_proto.hparam[name], kind, value)
    return hparam_proto
def load_hparams(checkpoint_dir: str) -> tf.contrib.training.HParams:
    """Load saved hyperparameters from a checkpoint.

    Reads `hparams.pbtxt` from `checkpoint_dir`, parses it as a text-format
    `HParamDef`, and rebuilds the hyperparameters through `create_hparams`.

    Args:
        checkpoint_dir: Directory containing `hparams.pbtxt`.

    Returns:
        The `HParams` object produced by `create_hparams`.
    """
    pbtxt_path = os.path.join(checkpoint_dir, 'hparams.pbtxt')
    saved_def = hparam_pb2.HParamDef()
    with tf.gfile.GFile(pbtxt_path, 'r') as pbtxt_file:
        text_format.Merge(pbtxt_file.read(), saved_def)
    saved_values = tf.contrib.training.HParams(saved_def).values()
    # Route the saved values through create_hparams so that any hparams not
    # present in the file are filled in with their default values.
    return create_hparams(**saved_values)
def _handle_hps_proto_file():
    """Copy hyperparameters from `FLAGS.hps_proto_file` into default flags.

    Does nothing when `FLAGS.hps_proto_file` is empty. Otherwise the file is
    parsed as a text-format `HParamDef` and converted with
    `contrib_training.HParams.from_proto`. The result is passed through
    `_maybe_upgrade_hparams`. Each resulting hparam whose flag still reports
    `using_default_value` is logged and assigned to that flag's `value`.
    Flags not at their default value are left untouched.

    NOTE(review): `FLAGS[name]` is looked up for every hparam name, so the
    proto is presumably expected to contain only names that are defined
    flags. Verify what happens for unknown names.
    """
    proto_path = FLAGS.hps_proto_file
    if not proto_path:
        return

    proto = hparam_pb2.HParamDef()
    with tf.gfile.GFile(proto_path) as proto_file:
        text_format.Parse(proto_file.read(), proto)

    upgraded = _maybe_upgrade_hparams(
        contrib_training.HParams.from_proto(proto))

    for flag_name, flag_value in upgraded.values().items():
        flag = FLAGS[flag_name]
        if not flag.using_default_value:
            continue
        logging.info('hps_proto FLAGS.%s = %r', flag_name, flag_value)
        flag.value = flag_value
def load_hparams(hparams_path: str) -> tf.contrib.training.HParams:
    """Reads hparams protobuf from file.

    The file is parsed as a text-format `HParamDef` with `text_format.Merge`
    and wrapped in a `tf.contrib.training.HParams`.

    Args:
        hparams_path: Path to hparams file.

    Returns:
        Hparams object.
    """
    parsed_def = hparam_pb2.HParamDef()
    with tf.gfile.GFile(hparams_path, 'r') as pbtxt:
        text_format.Merge(pbtxt.read(), parsed_def)
    return tf.contrib.training.HParams(parsed_def)