def _build_session_groups(self):
  """Collects this plugin's runs into SessionGroup protobuffers.

  Each run tagged with session-start info represents one session. Sessions
  whose status is allowed by the request are bucketed into groups keyed by
  group name; a group is created lazily the first time one of its sessions
  is encountered. Returns the groups with their sessions deterministically
  sorted and aggregate metrics computed.
  """
  groups_by_name = {}
  run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent(
      metadata.PLUGIN_NAME)
  for run, tag_to_content in run_to_tag_to_content.items():
    # Only runs carrying session-start info count as sessions.
    if metadata.SESSION_START_INFO_TAG not in tag_to_content:
      continue
    start_info = metadata.parse_session_start_info_plugin_data(
        tag_to_content[metadata.SESSION_START_INFO_TAG])
    # end_info stays None when the run never wrote session-end info.
    end_bytes = tag_to_content.get(metadata.SESSION_END_INFO_TAG)
    end_info = (
        metadata.parse_session_end_info_plugin_data(end_bytes)
        if end_bytes is not None else None)
    session = self._build_session(run, start_info, end_info)
    if session.status in self._request.allowed_statuses:
      self._add_session(session, start_info, groups_by_name)

  groups = groups_by_name.values()
  for group in groups:
    # Sort sessions by name so the output order is deterministic.
    group.sessions.sort(key=lambda session: session.name)
    self._aggregate_metrics(group)
  return groups
    def _build_session_groups(self):
        """Collects the experiment's sessions into SessionGroup protobuffers.

        Reads hparams metadata through the data-provider API, builds one
        Session per run that carries session-start info, buckets allowed
        sessions into groups keyed by group name, and returns the groups
        with sessions deterministically sorted and aggregate metrics
        computed.
        """
        groups_by_name = {}
        run_to_tag_to_content = self._context.hparams_metadata(
            self._experiment_id,
            run_tag_filter=provider.RunTagFilter(tags=[
                metadata.SESSION_START_INFO_TAG,
                metadata.SESSION_END_INFO_TAG,
            ]),
        )
        # Runs carrying session-start info are the "sessions"; the runs
        # that actually hold metric data may live in subdirectories, so
        # precompute every (run, tag) pair for each (session, metric)
        # combination before issuing a single batched scalar read.
        session_names = [
            name for (name, tags) in run_to_tag_to_content.items()
            if metadata.SESSION_START_INFO_TAG in tags
        ]
        metric_runs = set()
        metric_tags = set()
        for session_name in session_names:
            for metric_info in self._experiment.metric_infos:
                run, tag = metrics.run_tag_from_session_and_metric(
                    session_name, metric_info.name)
                metric_runs.add(run)
                metric_tags.add(tag)
        all_metric_evals = self._context.read_last_scalars(
            self._experiment_id,
            run_tag_filter=provider.RunTagFilter(
                runs=metric_runs, tags=metric_tags),
        )
        for session_name, tag_to_content in run_to_tag_to_content.items():
            if metadata.SESSION_START_INFO_TAG not in tag_to_content:
                continue
            start_info = metadata.parse_session_start_info_plugin_data(
                tag_to_content[metadata.SESSION_START_INFO_TAG])
            # end_info stays None when the session never wrote end info.
            end_info = None
            end_bytes = tag_to_content.get(metadata.SESSION_END_INFO_TAG)
            if end_bytes is not None:
                end_info = metadata.parse_session_end_info_plugin_data(
                    end_bytes)
            session = self._build_session(session_name, start_info, end_info,
                                          all_metric_evals)
            if session.status in self._request.allowed_statuses:
                self._add_session(session, start_info, groups_by_name)

        groups = groups_by_name.values()
        for group in groups:
            # Sort sessions by name so the output order is deterministic.
            group.sessions.sort(key=operator.attrgetter("name"))
            self._aggregate_metrics(group)
        return groups
# Beispiel #3
# 0
 def _build_session_infos_by_name(self):
     """Returns a dict mapping each session run to its start/end info.

     Only runs carrying a session-start-info tag are included. The
     end_info field of each tuple is None when the run has no
     session-end-info tag.
     """
     infos_by_run = {}
     plugin_contents = self._context.multiplexer.PluginRunToTagToContent(
         metadata.PLUGIN_NAME)
     for run, tag_to_content in plugin_contents.items():
         start_bytes = tag_to_content.get(metadata.SESSION_START_INFO_TAG)
         if start_bytes is None:
             # Runs without start info are not sessions; skip them.
             continue
         end_bytes = tag_to_content.get(metadata.SESSION_END_INFO_TAG)
         end_info = (
             metadata.parse_session_end_info_plugin_data(end_bytes)
             if end_bytes is not None else None)
         infos_by_run[run] = self._SessionInfoTuple(
             start_info=metadata.parse_session_start_info_plugin_data(
                 start_bytes),
             end_info=end_info)
     return infos_by_run
 def _build_session_infos_by_name(self):
   run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent(
       metadata.PLUGIN_NAME)
   result = {}
   for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
     if metadata.SESSION_START_INFO_TAG not in tag_to_content:
       continue
     start_info = metadata.parse_session_start_info_plugin_data(
         tag_to_content[metadata.SESSION_START_INFO_TAG])
     # end_info will be None if the corresponding tag doesn't exist.
     end_info = None
     if metadata.SESSION_END_INFO_TAG in tag_to_content:
       end_info = metadata.parse_session_end_info_plugin_data(
           tag_to_content[metadata.SESSION_END_INFO_TAG])
     result[run] = self._SessionInfoTuple(start_info=start_info,
                                          end_info=end_info)
   return result
# Beispiel #5
# 0
  def test_eager(self):
    """End-to-end check that the Keras callback writes hparams summaries.

    Runs a one-step fit with `time.time` stubbed to an incrementing
    counter, then reads the single resulting events file and verifies
    that exactly two hparams plugin summaries were written: a
    SessionStartInfo followed by a SessionEndInfo.
    """
    # Stand-in clock: each call advances by one second, so timestamps
    # recorded during the fit are strictly increasing and bounded by
    # initial_time/final_time below.
    def mock_time():
      mock_time.time += 1
      return mock_time.time
    mock_time.time = 1556227801.875
    initial_time = mock_time.time
    with mock.patch("time.time", mock_time):
      self._initialize_model(writer=self.logdir)
      self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback])
    final_time = mock_time.time

    # The run should have produced exactly one events file.
    files = os.listdir(self.logdir)
    self.assertEqual(len(files), 1, files)
    events_file = os.path.join(self.logdir, files[0])
    # Collect the raw plugin-data payloads of every hparams summary,
    # in the order they were written.
    plugin_data = []
    for event in tf.compat.v1.train.summary_iterator(events_file):
      if event.WhichOneof("what") != "summary":
        continue
      self.assertEqual(len(event.summary.value), 1, event.summary.value)
      value = event.summary.value[0]
      self.assertEqual(
          value.metadata.plugin_data.plugin_name,
          metadata.PLUGIN_NAME,
      )
      plugin_data.append(value.metadata.plugin_data.content)

    # Expect exactly a session-start record followed by a session-end record.
    self.assertEqual(len(plugin_data), 2, plugin_data)
    (start_plugin_data, end_plugin_data) = plugin_data
    start_pb = metadata.parse_session_start_info_plugin_data(start_plugin_data)
    end_pb = metadata.parse_session_end_info_plugin_data(end_plugin_data)

    # We're not the only callers of `time.time`; Keras calls it
    # internally an unspecified number of times, so we're not guaranteed
    # to know the exact values. Instead, we perform relative checks...
    self.assertGreater(start_pb.start_time_secs, initial_time)
    self.assertLess(start_pb.start_time_secs, end_pb.end_time_secs)
    self.assertLessEqual(start_pb.start_time_secs, final_time)
    # ...and then stub out the times for proto equality checks below.
    start_pb.start_time_secs = 1234.5
    end_pb.end_time_secs = 6789.0

    expected_start_pb = plugin_data_pb2.SessionStartInfo()
    text_format.Merge(
        """
        start_time_secs: 1234.5
        group_name: "my_trial"
        hparams {
          key: "optimizer"
          value {
            string_value: "adam"
          }
        }
        hparams {
          key: "dense_neurons"
          value {
            number_value: 8.0
          }
        }
        """,
        expected_start_pb,
    )
    self.assertEqual(start_pb, expected_start_pb)

    expected_end_pb = plugin_data_pb2.SessionEndInfo()
    text_format.Merge(
        """
        end_time_secs: 6789.0
        status: STATUS_SUCCESS
        """,
        expected_end_pb,
    )
    self.assertEqual(end_pb, expected_end_pb)