Example #1
  def _build_session_groups(self):
    """Returns a list of SessionGroups protobuffers from the summary data."""

    # Algorithm: We keep a dict 'groups_by_name' mapping a SessionGroup name
    # (str) to a SessionGroup protobuffer. We traverse the runs associated with
    # the plugin--each representing a single session. We form a Session
    # protobuffer from each run and add it to the relevant SessionGroup object
    # in the 'groups_by_name' dict. We create the SessionGroup object if this
    # is the first session of that group that we encounter.
    groups_by_name = {}
    run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent(
        metadata.PLUGIN_NAME)
    for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
      if metadata.SESSION_START_INFO_TAG not in tag_to_content:
        continue
      start_info = metadata.parse_session_start_info_plugin_data(
          tag_to_content[metadata.SESSION_START_INFO_TAG])
      end_info = None
      if metadata.SESSION_END_INFO_TAG in tag_to_content:
        end_info = metadata.parse_session_end_info_plugin_data(
            tag_to_content[metadata.SESSION_END_INFO_TAG])
      session = self._build_session(run, start_info, end_info)
      if session.status in self._request.allowed_statuses:
        self._add_session(session, start_info, groups_by_name)

    # Compute the session group's aggregated metrics for each group.
    groups = groups_by_name.values()
    for group in groups:
      # We sort the sessions in a group so that the order is deterministic.
      group.sessions.sort(key=operator.attrgetter('name'))
      self._aggregate_metrics(group)
    return groups
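
The grouping step the comment above describes is straightforward: each run contributes one Session, which is appended to the SessionGroup keyed by the group name from its SessionStartInfo, and the group is created lazily on first encounter. A minimal standalone sketch of that idea, using plain dicts and strings instead of the real Session/SessionGroup protobuffers (the names here are illustrative, not the plugin's API):

def group_sessions(sessions):
    """sessions: iterable of (group_name, session_name) pairs."""
    groups_by_name = {}
    for (group_name, session_name) in sessions:
        # Create the group lazily the first time its name is seen.
        group = groups_by_name.setdefault(
            group_name, {"name": group_name, "sessions": []})
        group["sessions"].append(session_name)
    # Sort each group's sessions so the result is deterministic.
    for group in groups_by_name.values():
        group["sessions"].sort()
    return list(groups_by_name.values())
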
Example #2
    def _build_session_groups(self):
        """Returns a list of SessionGroups protobuffers from the summary
        data."""

        # Algorithm: We keep a dict 'groups_by_name' mapping a SessionGroup name
        # (str) to a SessionGroup protobuffer. We traverse the runs associated with
        # the plugin--each representing a single session. We form a Session
        # protobuffer from each run and add it to the relevant SessionGroup object
        # in the 'groups_by_name' dict. We create the SessionGroup object if this
        # is the first session of that group that we encounter.
        groups_by_name = {}
        run_to_tag_to_content = self._context.hparams_metadata(
            self._experiment_id,
            run_tag_filter=provider.RunTagFilter(tags=[
                metadata.SESSION_START_INFO_TAG,
                metadata.SESSION_END_INFO_TAG,
            ]),
        )
        # The TensorBoard runs with session start info are the
        # "sessions", which are not necessarily the runs that actually
        # contain metrics (may be in subdirectories).
        session_names = [
            run for (run, tags) in run_to_tag_to_content.items()
            if metadata.SESSION_START_INFO_TAG in tags
        ]
        metric_runs = set()
        metric_tags = set()
        for session_name in session_names:
            for metric in self._experiment.metric_infos:
                metric_name = metric.name
                (run, tag) = metrics.run_tag_from_session_and_metric(
                    session_name, metric_name)
                metric_runs.add(run)
                metric_tags.add(tag)
        all_metric_evals = self._context.read_last_scalars(
            self._experiment_id,
            run_tag_filter=provider.RunTagFilter(runs=metric_runs,
                                                 tags=metric_tags),
        )
        for (session_name, tag_to_content) in run_to_tag_to_content.items():
            if metadata.SESSION_START_INFO_TAG not in tag_to_content:
                continue
            start_info = metadata.parse_session_start_info_plugin_data(
                tag_to_content[metadata.SESSION_START_INFO_TAG])
            end_info = None
            if metadata.SESSION_END_INFO_TAG in tag_to_content:
                end_info = metadata.parse_session_end_info_plugin_data(
                    tag_to_content[metadata.SESSION_END_INFO_TAG])
            session = self._build_session(session_name, start_info, end_info,
                                          all_metric_evals)
            if session.status in self._request.allowed_statuses:
                self._add_session(session, start_info, groups_by_name)

        # Compute the session group's aggregated metrics for each group.
        groups = groups_by_name.values()
        for group in groups:
            # We sort the sessions in a group so that the order is deterministic.
            group.sessions.sort(key=operator.attrgetter("name"))
            self._aggregate_metrics(group)
        return groups
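
The comment above points out that a session's metric values may live in runs nested under the session run. A hypothetical sketch of that run/tag derivation, purely for illustration (the actual rule is whatever metrics.run_tag_from_session_and_metric implements and may differ):

def run_tag_for_metric(session_name, metric_group, metric_tag):
    # Illustrative assumption: metrics written under a subdirectory of the
    # session run use "<session>/<group>" as their run name.
    run = session_name if not metric_group else session_name + "/" + metric_group
    return (run, metric_tag)

# e.g. run_tag_for_metric("session_1", "validation", "accuracy")
#      -> ("session_1/validation", "accuracy")
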
Example #3
    def _compute_hparam_infos(self, hparams_run_to_tag_to_content):
        """Computes a list of api_pb2.HParamInfo from the current run, tag
        info.

        Finds all the SessionStartInfo messages and collects the hparams values
        appearing in each one. For each hparam attempts to deduce a type that fits
        all its values. Finally, sets the 'domain' of the resulting HParamInfo
        to be discrete if the type is string and the number of distinct values is
        small enough.

        Returns:
          A list of api_pb2.HParamInfo messages.
        """
        # Construct a dict mapping an hparam name to its list of values.
        hparams = collections.defaultdict(list)
        for tag_to_content in hparams_run_to_tag_to_content.values():
            if metadata.SESSION_START_INFO_TAG not in tag_to_content:
                continue
            start_info = metadata.parse_session_start_info_plugin_data(
                tag_to_content[metadata.SESSION_START_INFO_TAG])
            for (name, value) in six.iteritems(start_info.hparams):
                hparams[name].append(value)

        # Try to construct an HParamInfo for each hparam from its name and list
        # of values.
        result = []
        for (name, values) in six.iteritems(hparams):
            hparam_info = self._compute_hparam_info_from_values(name, values)
            if hparam_info is not None:
                result.append(hparam_info)
        return result
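
The type and domain deduction the docstring describes can be pictured with a small standalone sketch; this operates on plain Python values rather than protobuf Value messages and is not the plugin's actual implementation:

def infer_type_and_domain(values, max_discrete=10):
    types = {type(v) for v in values}
    if types <= {bool}:
        return ("bool", None)
    if types <= {bool, int, float}:
        return ("float64", None)
    # Fall back to string; expose a discrete domain only if the number of
    # distinct values is small enough.
    distinct = sorted({str(v) for v in values})
    domain = distinct if len(distinct) <= max_discrete else None
    return ("string", domain)
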
Example #4
 def get_group_name(hparams):
   summary_pb = hp.hparams_pb(hparams)
   values = summary_pb.value
   self.assertEqual(len(values), 1, values)
   actual_value = values[0]
   self.assertEqual(
       actual_value.metadata.plugin_data.plugin_name,
       metadata.PLUGIN_NAME,
   )
   plugin_content = actual_value.metadata.plugin_data.content
   info = metadata.parse_session_start_info_plugin_data(plugin_content)
   return info.group_name
Example #5
 def _check_summary(self, summary_pb):
     """Test that a summary contains exactly the expected hparams PB."""
     values = summary_pb.value
     self.assertEqual(len(values), 1, values)
     actual_value = values[0]
     self.assertEqual(
         actual_value.metadata.plugin_data.plugin_name,
         metadata.PLUGIN_NAME,
     )
     plugin_content = actual_value.metadata.plugin_data.content
     info_pb = metadata.parse_session_start_info_plugin_data(plugin_content)
     # Ignore the `group_name` field; its properties are checked separately.
     info_pb.group_name = self.expected_session_start_pb.group_name
     self.assertEqual(info_pb, self.expected_session_start_pb)
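
A typical call site for this helper might look like the following, assuming the test builds expected_session_start_pb in setUp and obtains summary_pb from the hparams summary API (the specific hparams are illustrative):

# Hypothetical usage inside a test method:
summary_pb = hp.hparams_pb({"optimizer": "adam", "dense_neurons": 8})
self._check_summary(summary_pb)
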
Example #6
 def _build_session_infos_by_name(self):
     run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent(
         metadata.PLUGIN_NAME)
     result = {}
     for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
         if metadata.SESSION_START_INFO_TAG not in tag_to_content:
             continue
         start_info = metadata.parse_session_start_info_plugin_data(
             tag_to_content[metadata.SESSION_START_INFO_TAG])
         # end_info will be None if the corresponding tag doesn't exist.
         end_info = None
         if metadata.SESSION_END_INFO_TAG in tag_to_content:
             end_info = metadata.parse_session_end_info_plugin_data(
                 tag_to_content[metadata.SESSION_END_INFO_TAG])
         result[run] = self._SessionInfoTuple(start_info=start_info,
                                              end_info=end_info)
     return result
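
The _SessionInfoTuple used here is presumably just a named tuple pairing the parsed start and end infos; a plausible definition, stated as an assumption rather than copied from the plugin:

import collections

# Assumed shape of the per-run value stored in 'result'; the real class may
# define this tuple elsewhere or under a different name.
_SessionInfoTuple = collections.namedtuple(
    "_SessionInfoTuple", ["start_info", "end_info"])
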
Example #7
 def _check_summary(self, summary_pb, check_group_name=False):
     """Test that a summary contains exactly the expected hparams PB."""
     values = summary_pb.value
     self.assertEqual(len(values), 1, values)
     actual_value = values[0]
     self.assertEqual(
         actual_value.metadata.plugin_data.plugin_name,
         metadata.PLUGIN_NAME,
     )
     self.assertEqual(
         tensor_pb2.TensorProto.FromString(
             actual_value.tensor.SerializeToString()),
         metadata.NULL_TENSOR,
     )
     plugin_content = actual_value.metadata.plugin_data.content
     info_pb = metadata.parse_session_start_info_plugin_data(plugin_content)
     # Usually ignore the `group_name` field; its properties are checked
     # separately.
     if not check_group_name:
         info_pb.group_name = self.expected_session_start_pb.group_name
     self.assertEqual(info_pb, self.expected_session_start_pb)
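
When a test also wants to pin down the exact group name, it could pass check_group_name=True; a hypothetical call, with the expected group name chosen purely for illustration:

# Hypothetical usage: also verify that group_name matches exactly.
self.expected_session_start_pb.group_name = "my_trial"
self._check_summary(summary_pb, check_group_name=True)
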
Example #8
  def test_eager(self):
    def mock_time():
      mock_time.time += 1
      return mock_time.time
    mock_time.time = 1556227801.875
    initial_time = mock_time.time
    with mock.patch("time.time", mock_time):
      self._initialize_model(writer=self.logdir)
      self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback])
    final_time = mock_time.time

    files = os.listdir(self.logdir)
    self.assertEqual(len(files), 1, files)
    events_file = os.path.join(self.logdir, files[0])
    plugin_data = []
    for event in tf.compat.v1.train.summary_iterator(events_file):
      if event.WhichOneof("what") != "summary":
        continue
      self.assertEqual(len(event.summary.value), 1, event.summary.value)
      value = event.summary.value[0]
      self.assertEqual(
          value.metadata.plugin_data.plugin_name,
          metadata.PLUGIN_NAME,
      )
      plugin_data.append(value.metadata.plugin_data.content)

    self.assertEqual(len(plugin_data), 2, plugin_data)
    (start_plugin_data, end_plugin_data) = plugin_data
    start_pb = metadata.parse_session_start_info_plugin_data(start_plugin_data)
    end_pb = metadata.parse_session_end_info_plugin_data(end_plugin_data)

    # We're not the only callers of `time.time`; Keras calls it
    # internally an unspecified number of times, so we're not guaranteed
    # to know the exact values. Instead, we perform relative checks...
    self.assertGreater(start_pb.start_time_secs, initial_time)
    self.assertLess(start_pb.start_time_secs, end_pb.end_time_secs)
    self.assertLessEqual(start_pb.start_time_secs, final_time)
    # ...and then stub out the times for proto equality checks below.
    start_pb.start_time_secs = 1234.5
    end_pb.end_time_secs = 6789.0

    expected_start_pb = plugin_data_pb2.SessionStartInfo()
    text_format.Merge(
        """
        start_time_secs: 1234.5
        group_name: "my_trial"
        hparams {
          key: "optimizer"
          value {
            string_value: "adam"
          }
        }
        hparams {
          key: "dense_neurons"
          value {
            number_value: 8.0
          }
        }
        """,
        expected_start_pb,
    )
    self.assertEqual(start_pb, expected_start_pb)

    expected_end_pb = plugin_data_pb2.SessionEndInfo()
    text_format.Merge(
        """
        end_time_secs: 6789.0
        status: STATUS_SUCCESS
        """,
        expected_end_pb,
    )
    self.assertEqual(end_pb, expected_end_pb)
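
The mock_time helper above keeps its state on a function attribute so that every patched time.time() call advances by one second. A similar fake clock can be built with itertools.count; this is a stylistic alternative, not part of the original test:

import itertools

# Each call to the patched time.time() returns a value one second later,
# starting one second after the initial timestamp used above.
_ticks = itertools.count(1556227802.875, 1.0)

def fake_time():
    return next(_ticks)

# with mock.patch("time.time", fake_time):
#     ...
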