Ejemplo n.º 1
0
  def __init__(self, writer, hparams, trial_id=None):
    """Create a callback for logging hyperparameters to TensorBoard.

    As with the standard `tf.keras.callbacks.TensorBoard` class, each
    callback object is valid for only one call to `model.fit`.

    Args:
      writer: The `SummaryWriter` object to which hparams should be
        written, or a logdir (as a `str`) to be passed to
        `tf.summary.create_file_writer` to create such a writer.
      hparams: A `dict` mapping hyperparameters to the values used in
        this session. Keys should be the names of `HParam` objects used
        in an experiment, or the `HParam` objects themselves. Values
        should be Python `bool`, `int`, `float`, or `string` values,
        depending on the type of the hyperparameter.
      trial_id: An optional `str` ID for the set of hyperparameter
        values used in this trial. Defaults to a hash of the
        hyperparameters.

    Raises:
      ValueError: If two entries in `hparams` share the same
        hyperparameter name.
      TypeError: If `writer` is `None`, or if the validating dry-run of
        `summary_v2.hparams_pb` rejects the arguments (for example, a
        `trial_id` that is not a `str`).
    """
    # Take a shallow copy so later changes to the caller's dict do not
    # alter what this callback eventually writes.
    self._hparams = dict(hparams)
    self._trial_id = trial_id
    # Defer creating the actual summary until we write it, so that the
    # timestamp is correct. But create a "dry-run" first to fail fast in
    # case the `hparams` are invalid; its result is intentionally unused.
    summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id)
    if writer is None:
      raise TypeError("writer must be a `SummaryWriter` or `str`, not None")
    elif isinstance(writer, str):
      # A plain string is treated as a logdir for a new file writer.
      self._writer = tf.compat.v2.summary.create_file_writer(writer)
    else:
      self._writer = writer
Ejemplo n.º 2
0
 def test_duplicate_hparam_names_from_two_objects(self):
   # Two distinct `HParam` objects that carry the same name must be
   # rejected as conflicting values for one hyperparameter.
   first = hp.HParam("foo")
   second = hp.HParam("foo")
   conflicting = {first: 1, second: 1}
   expected_message = "multiple values specified for hparam 'foo'"
   with six.assertRaisesRegex(self, ValueError, expected_message):
     hp.hparams_pb(conflicting)
Ejemplo n.º 3
0
 def test_duplicate_hparam_names_across_object_and_string(self):
   # A plain string key and an `HParam` key naming the same hyperparameter
   # must be rejected as conflicting values.
   as_string = "foo"
   as_object = hp.HParam("foo")
   conflicting = {as_string: 1, as_object: 1}
   expected_message = "multiple values specified for hparam 'foo'"
   with six.assertRaisesRegex(self, ValueError, expected_message):
     hp.hparams_pb(conflicting)
Ejemplo n.º 4
0
 def test_serialize_tf_linspace_numpy(self):
     # Should be subsumed by `test_serialize_numpy_scalars`; kept as a
     # separate test because pulling scalars out of `tf.linspace` is a
     # common use case.
     ramp = tf.linspace(1.0, 2.0, 5)
     hparams = {
         "f_default": ramp.numpy()[0],
         "f32": tf.cast(ramp, tf.float32).numpy()[0],
         "f64": tf.cast(ramp, tf.float64).numpy()[0],
     }
     hp.hparams_pb(hparams)
Ejemplo n.º 5
0
 def test_serialize_numpy_scalars(self):
     # Every value below is a NumPy scalar rather than a Python builtin;
     # `hparams_pb` must accept each of them.
     hparams = {
         name: np.array([1, 2], dtype=dtype)[0]
         for name, dtype in (("i32", np.int32), ("i64", np.int64))
     }
     float_dtypes = (
         ("f_default", None),
         ("f32", np.float32),
         ("f64", np.float64),
     )
     for name, dtype in float_dtypes:
         hparams[name] = np.linspace(1.0, 2.0, 5, dtype=dtype)[0]
     hparams["bool"] = np.array([False, True])[0]
     hp.hparams_pb(hparams)
Ejemplo n.º 6
0
 def test_consistency_across_string_key_and_object_key(self):
     # Keying a hyperparameter by its `HParam` object or by its plain name
     # must yield the same summary.
     optimizer = hp.HParam("optimizer", hp.Discrete(["adam", "sgd"]))
     learning_rate = hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1))
     mixed_a = {optimizer: "adam", "learning_rate": 0.02}
     mixed_b = {"optimizer": "adam", learning_rate: 0.02}
     start = self.start_time_secs
     summary_a = hp.hparams_pb(mixed_a, start_time_secs=start)
     summary_b = hp.hparams_pb(mixed_b, start_time_secs=start)
     self.assert_hparams_summaries_equal(summary_a, summary_b)
Ejemplo n.º 7
0
 def test_invariant_under_permutation(self):
     # Insertion order of the hparams must not affect the summary; in
     # particular, the group name should be the same.
     forward = {
         "optimizer": "adam",
         "learning_rate": 0.02,
     }
     backward = dict(reversed(list(forward.items())))
     start = self.start_time_secs
     self.assert_hparams_summaries_equal(
         hp.hparams_pb(forward, start_time_secs=start),
         hp.hparams_pb(backward, start_time_secs=start),
     )
Ejemplo n.º 8
0
 def test_pb_explicit_trial_id(self):
     # With an explicit trial ID, the summary check also covers the
     # group name.
     summary = hp.hparams_pb(
         self.hparams,
         trial_id=self.trial_id,
         start_time_secs=self.start_time_secs,
     )
     self._check_summary(summary, check_group_name=True)
Ejemplo n.º 9
0
 def test_pb_is_tensorboard_copy_of_proto(self):
     # The result must be TensorBoard's own `Summary` proto class, and,
     # when TensorFlow is importable, not TensorFlow's `Summary`.
     summary = hp.hparams_pb(self.hparams, start_time_secs=self.start_time_secs)
     self.assertIsInstance(summary, summary_pb2.Summary)
     if tf is None:
         return
     self.assertNotIsInstance(summary, tf.compat.v1.Summary)
Ejemplo n.º 10
0
def _HParamSessionStart(name, hparams):
    """Build an hparams session-start summary for the trial `name`.

    Calls `summary_v2.hparams_pb` with a `trial_id` keyword. If that
    keyword is not supported (presumably an older TensorBoard API, as the
    fallback suggests), falls back to `_legacy_hparams_pb`.

    Args:
      name: The trial ID to record.
      hparams: Mapping of hyperparameters to values, passed through
        unchanged.

    Returns:
      The result of `hparams_pb`, or of `_legacy_hparams_pb` on fallback.

    Raises:
      TypeError: If `hparams_pb` rejects the arguments for any reason
        other than not accepting the `trial_id` keyword (for example, a
        `name` that is not a `str`).
    """
    from tensorboard.plugins.hparams import summary_v2 as hp
    try:
        # pylint: disable=unexpected-keyword-arg
        return hp.hparams_pb(hparams, trial_id=name)
    except TypeError as e:
        message = str(e)
        # Only an unsupported `trial_id` keyword means "use the legacy
        # path". Any other TypeError (e.g. `hparams_pb` validating its
        # inputs) is a genuine error and must not be masked by silently
        # switching implementations.
        if "unexpected keyword argument" not in message or "trial_id" not in message:
            raise
        return _legacy_hparams_pb(hparams, name)
Ejemplo n.º 11
0
 def get_group_name(hparams):
   # Serialize `hparams`, check that the summary holds exactly one value
   # tagged with this plugin's name, and return the group name decoded
   # from its session-start plugin content.
   summary = hp.hparams_pb(hparams)
   values = summary.value
   self.assertEqual(len(values), 1, values)
   plugin_data = values[0].metadata.plugin_data
   self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME)
   return metadata.parse_session_start_info_plugin_data(
       plugin_data.content
   ).group_name
Ejemplo n.º 12
0
    def test_hparams(self):
        summary = hparams_summary.hparams_pb({"optimizer": "adam"})
        # Simulate a legacy event whose values carry no tensor content.
        for entry in summary.value:
            entry.ClearField("tensor")
        old_event = event_pb2.Event(step=0, wall_time=456.75)
        old_event.summary.CopyFrom(summary)

        new_events = self._migrate_event(old_event)
        self.assertLen(new_events, 1)
        migrated_values = new_events[0].summary.value
        self.assertLen(migrated_values, 1)
        migrated = migrated_values[0]
        self.assertEqual(migrated.tensor, hparams_metadata.NULL_TENSOR)
        self.assertEqual(
            migrated.metadata.data_class, summary_pb2.DATA_CLASS_TENSOR
        )
        self.assertEqual(
            migrated.metadata.plugin_data,
            summary.value[0].metadata.plugin_data,
        )
Ejemplo n.º 13
0
 def test_pb_invalid_trial_id(self):
     # A non-string `trial_id` must be rejected with a descriptive TypeError.
     expected_message = "`trial_id` should be a `str`, but got: 12"
     with six.assertRaisesRegex(self, TypeError, expected_message):
         hp.hparams_pb(self.hparams, trial_id=12)
Ejemplo n.º 14
0
 def test_pb_contents(self):
     # Build the summary with a fixed start time and verify its contents.
     summary = hp.hparams_pb(
         self.hparams, start_time_secs=self.start_time_secs
     )
     self._check_summary(summary)
Ejemplo n.º 15
0
def Session(name, hparams):
    """Return an hparams session summary for the trial `name`.

    Calls `hp.hparams_pb` with a `trial_id` keyword. If that keyword is
    not supported (presumably an older TensorBoard API, as the fallback
    suggests), falls back to `LegacySession`.

    Args:
      name: The trial ID to record.
      hparams: Mapping of hyperparameters to values, passed through
        unchanged.

    Returns:
      The result of `hp.hparams_pb`, or of `LegacySession` on fallback.

    Raises:
      TypeError: If `hp.hparams_pb` rejects the arguments for any reason
        other than not accepting the `trial_id` keyword (for example, a
        `name` that is not a `str`).
    """
    try:
        return hp.hparams_pb(hparams, trial_id=name)
    except TypeError as e:
        message = str(e)
        # Only an unsupported `trial_id` keyword justifies the legacy
        # path; other TypeErrors are real input errors and are re-raised
        # instead of being masked by a silent fallback.
        if "unexpected keyword argument" not in message or "trial_id" not in message:
            raise
        return LegacySession(name, hparams)