Ejemplo n.º 1
0
 def test_session_start_pb(self):
     """Checks that session_start_pb wraps a SessionStartInfo in a tf.Summary."""
     start_secs = 314160
     actual_summary = summary.session_start_pb(
         hparams={"param1": "string"},
         model_uri="//model/uri",
         group_name="session_group",
         start_time_secs=start_secs)

     expected_info = plugin_data_pb2.SessionStartInfo(
         model_uri="//model/uri",
         group_name="session_group",
         start_time_secs=start_secs)
     expected_info.hparams["param1"].string_value = "string"
     # TODO: Fix nondeterminism.
     # Only a single hparam is exercised; the "param2" (number 5.0) and
     # "param3" (bool False) cases stay disabled until that is fixed.
     # NOTE(review): presumably the nondeterminism comes from comparing
     # serialized bytes of a proto map, whose entry order may vary -- confirm.
     expected_content = plugin_data_pb2.HParamsPluginData(
         version=0,
         session_start_info=expected_info).SerializeToString()
     expected_summary = tf.Summary(value=[
         tf.Summary.Value(
             tag="_hparams_/session_start_info",
             metadata=tf.SummaryMetadata(
                 plugin_data=tf.SummaryMetadata.PluginData(
                     plugin_name="hparams",
                     content=expected_content)))
     ])

     self.assertEqual(actual_summary, expected_summary)
Ejemplo n.º 2
0
    def setUp(self):
        """Builds the hyperparameter fixtures and the expected start proto."""
        self.logdir = os.path.join(self.get_temp_dir(), "logs")

        # `HParam` keys with various domains/metadata, plus one plain string
        # key ("dropout"); `normalized_hparams` below is the name-keyed form.
        learning_rate = hp.HParam("learning_rate", hp.RealInterval(1e-2, 1e-1))
        dense_layers = hp.HParam("dense_layers", hp.IntInterval(2, 7))
        optimizer = hp.HParam("optimizer", hp.Discrete(["adam", "sgd"]))
        who_knows_what = hp.HParam("who_knows_what")
        magic = hp.HParam(
            "magic",
            hp.Discrete([False, True]),
            display_name="~*~ Magic ~*~",
            description="descriptive",
        )
        self.hparams = {
            learning_rate: 0.02,
            dense_layers: 5,
            optimizer: "adam",
            who_knows_what: "???",
            magic: True,
            "dropout": 0.3,
        }
        self.normalized_hparams = dict(
            learning_rate=0.02,
            dense_layers=5,
            optimizer="adam",
            who_knows_what="???",
            magic=True,
            dropout=0.3,
        )
        self.start_time_secs = 123.45
        self.trial_id = "psl27"

        # The SessionStartInfo that summaries for `self.hparams` should carry.
        self.expected_session_start_pb = expected = (
            plugin_data_pb2.SessionStartInfo()
        )
        text_format.Merge(
            """
            hparams { key: "learning_rate" value { number_value: 0.02 } }
            hparams { key: "dense_layers" value { number_value: 5 } }
            hparams { key: "optimizer" value { string_value: "adam" } }
            hparams { key: "who_knows_what" value { string_value: "???" } }
            hparams { key: "magic" value { bool_value: true } }
            hparams { key: "dropout" value { number_value: 0.3 } }
            """,
            expected,
        )
        expected.group_name = self.trial_id
        expected.start_time_secs = self.start_time_secs
Ejemplo n.º 3
0
def hparams_pb(hparams, trial_id=None, start_time_secs=None):
    # NOTE: Keep docs in sync with `hparams` above.
    """Create a summary encoding hyperparameter values for a single trial.

    Args:
      hparams: A `dict` mapping hyperparameters to the values used in this
        trial. Keys should be the names of `HParam` objects used in an
        experiment, or the `HParam` objects themselves. Values should be
        Python `bool`, `int`, `float`, or `string` values, depending on
        the type of the hyperparameter.
      trial_id: An optional `str` ID for the set of hyperparameter values
        used in this trial. Defaults to a hash of the hyperparameters.
      start_time_secs: The time that this trial started training, as
        seconds since epoch. Defaults to the current time.

    Returns:
      A TensorBoard `summary_pb2.Summary` message.
    """
    started_at = time.time() if start_time_secs is None else start_time_secs
    normalized = _normalize_hparams(hparams)
    info = plugin_data_pb2.SessionStartInfo(
        group_name=_derive_session_group_name(trial_id, normalized),
        start_time_secs=started_at,
    )
    for name in sorted(normalized):
        value = normalized[name]
        # `bool` must be tested before the numeric types: it subclasses `int`.
        if isinstance(value, bool):
            field_name = "bool_value"
        elif isinstance(value, (float, int)):
            field_name = "number_value"
        elif isinstance(value, six.string_types):
            field_name = "string_value"
        else:
            raise TypeError(
                "hparams[%r] = %r, of unsupported type %r"
                % (name, value, type(value))
            )
        setattr(info.hparams[name], field_name, value)

    plugin_data = plugin_data_pb2.HParamsPluginData(session_start_info=info)
    return _summary_pb(metadata.SESSION_START_INFO_TAG, plugin_data)
Ejemplo n.º 4
0
def session_start_pb(hparams,
                     model_uri="",
                     monitor_url="",
                     group_name="",
                     start_time_secs=None):
  """Creates a summary that contains training session metadata information.

  One such summary per training session should be created. Each should have
  a different run.

  Arguments:
    hparams: A dictionary with string keys. Describes the hyperparameter values
             used in the session, mapping each hyperparameter name to its value.
             Supported value types are bool, int, float, or str.
    model_uri: See the comment for the field with the same name of
               plugin_data_pb2.SessionStartInfo.
    monitor_url: See the comment for the field with the same name of
                 plugin_data_pb2.SessionStartInfo.
    group_name:  See the comment for the field with the same name of
                 plugin_data_pb2.SessionStartInfo.
    start_time_secs: float. The time to use as the session start time.
                     Represented as seconds since the UNIX epoch. If None uses
                     the current time.
  Returns:
    Returns the summary protobuffer mentioned above.
  Raises:
    TypeError: If a value in `hparams` is not of a supported type.
  """
  if start_time_secs is None:
    start_time_secs = time.time()
  session_start_info = plugin_data_pb2.SessionStartInfo(
      model_uri=model_uri,
      monitor_url=monitor_url,
      group_name=group_name,
      start_time_secs=start_time_secs)
  for (hp_name, hp_val) in six.iteritems(hparams):
    # `bool` is a subclass of `int`, so it must be checked before the numeric
    # branch; otherwise booleans would be stored as number_value and the
    # bool_value branch would never be reached.
    if isinstance(hp_val, bool):
      session_start_info.hparams[hp_name].bool_value = hp_val
    elif isinstance(hp_val, (float, int)):
      session_start_info.hparams[hp_name].number_value = hp_val
    elif isinstance(hp_val, six.string_types):
      session_start_info.hparams[hp_name].string_value = hp_val
    else:
      raise TypeError('hparams[%s]=%s has type: %s which is not supported' %
                      (hp_name, hp_val, type(hp_val)))
  return _summary(metadata.SESSION_START_INFO_TAG,
                  plugin_data_pb2.HParamsPluginData(
                      session_start_info=session_start_info))
Ejemplo n.º 5
0
def session_start_pb(hparams,
                     model_uri="",
                     monitor_url="",
                     group_name="",
                     start_time_secs=None):
    """Constructs a SessionStartInfo protobuffer.

    Creates a summary that contains a training session metadata information.
    One such summary per training session should be created. Each should have
    a different run.

    Args:
      hparams: A dictionary with string keys. Describes the hyperparameter values
               used in the session, mapping each hyperparameter name to its value.
               Supported value types are `bool`, `int`, `float`, `str`, `list`,
               `tuple`.
               The type of value must correspond to the type of hyperparameter
               (defined in the corresponding api_pb2.HParamInfo member of the
               Experiment protobuf) as follows:

                +-----------------+---------------------------------+
                |Hyperparameter   | Allowed (Python) value types    |
                |type             |                                 |
                +-----------------+---------------------------------+
                |DATA_TYPE_BOOL   | bool                            |
                |DATA_TYPE_FLOAT64| int, float                      |
                |DATA_TYPE_STRING | six.string_types, tuple, list   |
                +-----------------+---------------------------------+

               Tuple and list instances will be converted to their string
               representation.
      model_uri: See the comment for the field with the same name of
                 plugin_data_pb2.SessionStartInfo.
      monitor_url: See the comment for the field with the same name of
                   plugin_data_pb2.SessionStartInfo.
      group_name:  See the comment for the field with the same name of
                   plugin_data_pb2.SessionStartInfo.
      start_time_secs: float. The time to use as the session start time.
                       Represented as seconds since the UNIX epoch. If None uses
                       the current time.
    Returns:
      The summary protobuffer mentioned above.
    Raises:
      TypeError: If a value in `hparams` is not of a supported type.
    """
    if start_time_secs is None:
        start_time_secs = time.time()
    session_start_info = plugin_data_pb2.SessionStartInfo(
        model_uri=model_uri,
        monitor_url=monitor_url,
        group_name=group_name,
        start_time_secs=start_time_secs,
    )
    for (hp_name, hp_val) in six.iteritems(hparams):
        # `bool` is a subclass of `int`, so it must be checked before the
        # numeric branch; otherwise booleans would be stored as number_value,
        # contradicting the DATA_TYPE_BOOL row of the table above.
        if isinstance(hp_val, bool):
            session_start_info.hparams[hp_name].bool_value = hp_val
        elif isinstance(hp_val, (float, int)):
            session_start_info.hparams[hp_name].number_value = hp_val
        elif isinstance(hp_val, six.string_types):
            session_start_info.hparams[hp_name].string_value = hp_val
        elif isinstance(hp_val, (list, tuple)):
            session_start_info.hparams[hp_name].string_value = str(hp_val)
        else:
            raise TypeError(
                "hparams[%s]=%s has type: %s which is not supported" %
                (hp_name, hp_val, type(hp_val)))
    return _summary(
        metadata.SESSION_START_INFO_TAG,
        plugin_data_pb2.HParamsPluginData(
            session_start_info=session_start_info),
    )
Ejemplo n.º 6
0
  def test_eager(self):
    """Fits a model with the hparams callback and checks the emitted protos."""
    # Fake clock: each call advances one second from a fixed origin. A
    # one-element list lets the closure mutate the shared value.
    clock = [1556227801.875]

    def fake_time():
      clock[0] += 1
      return clock[0]

    initial_time = clock[0]
    with mock.patch("time.time", fake_time):
      self._initialize_model(writer=self.logdir)
      self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback])
    final_time = clock[0]

    files = os.listdir(self.logdir)
    self.assertEqual(len(files), 1, files)
    events_path = os.path.join(self.logdir, files[0])

    # Collect the plugin payload of every summary event, checking that each
    # one holds exactly one value tagged with this plugin's name.
    contents = []
    for event in tf.compat.v1.train.summary_iterator(events_path):
      if event.WhichOneof("what") == "summary":
        values = event.summary.value
        self.assertEqual(len(values), 1, values)
        plugin_data = values[0].metadata.plugin_data
        self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME)
        contents.append(plugin_data.content)

    self.assertEqual(len(contents), 2, contents)
    start_pb = metadata.parse_session_start_info_plugin_data(contents[0])
    end_pb = metadata.parse_session_end_info_plugin_data(contents[1])

    # Keras also calls `time.time` an unspecified number of times, so the
    # exact timestamps are unknown; check ordering only...
    self.assertGreater(start_pb.start_time_secs, initial_time)
    self.assertLess(start_pb.start_time_secs, end_pb.end_time_secs)
    self.assertLessEqual(start_pb.start_time_secs, final_time)
    # ...then pin the timestamps so the full-proto comparisons are exact.
    start_pb.start_time_secs = 1234.5
    end_pb.end_time_secs = 6789.0

    want_start = plugin_data_pb2.SessionStartInfo()
    text_format.Merge(
        """
        start_time_secs: 1234.5
        group_name: "my_trial"
        hparams {
          key: "optimizer"
          value {
            string_value: "adam"
          }
        }
        hparams {
          key: "dense_neurons"
          value {
            number_value: 8.0
          }
        }
        """,
        want_start,
    )
    self.assertEqual(start_pb, want_start)

    want_end = plugin_data_pb2.SessionEndInfo()
    text_format.Merge(
        """
        end_time_secs: 6789.0
        status: STATUS_SUCCESS
        """,
        want_end,
    )
    self.assertEqual(end_pb, want_end)