def _build_session_groups(self):
    """Returns a list of SessionGroups protobuffers from the summary data.

    Every run that has hparams plugin data under
    `metadata.SESSION_START_INFO_TAG` is treated as one session. Sessions
    whose status is not in `self._request.allowed_statuses` are dropped.

    Returns:
      A list of SessionGroup protobuffers. The sessions inside each group
      are sorted by name and each group's metrics have been aggregated via
      `self._aggregate_metrics`.
    """
    # Algorithm: We keep a dict 'groups_by_name' mapping a SessionGroup name
    # (str) to a SessionGroup protobuffer. We traverse the runs associated with
    # the plugin--each representing a single session. We form a Session
    # protobuffer from each run and add it to the relevant SessionGroup object
    # in the 'groups_by_name' dict. We create the SessionGroup object, if this
    # is the first session of that group we encounter.
    groups_by_name = {}
    run_to_tag_to_content = self._context.multiplexer.PluginRunToTagToContent(
        metadata.PLUGIN_NAME)
    for (run, tag_to_content) in six.iteritems(run_to_tag_to_content):
        # Runs without a session-start record are not sessions; skip them.
        if metadata.SESSION_START_INFO_TAG not in tag_to_content:
            continue
        start_info = metadata.parse_session_start_info_plugin_data(
            tag_to_content[metadata.SESSION_START_INFO_TAG])
        # end_info stays None when the session has no end record.
        end_info = None
        if metadata.SESSION_END_INFO_TAG in tag_to_content:
            end_info = metadata.parse_session_end_info_plugin_data(
                tag_to_content[metadata.SESSION_END_INFO_TAG])
        session = self._build_session(run, start_info, end_info)
        if session.status in self._request.allowed_statuses:
            self._add_session(session, start_info, groups_by_name)

    # Compute the session group's aggregated metrics for each group.
    # `dict.values()` is only a view on Python 3; materialize it so callers
    # really get the list the docstring promises (indexable, sortable in
    # place).
    groups = list(groups_by_name.values())
    for group in groups:
        # We sort the sessions in a group so that the order is deterministic.
        group.sessions.sort(key=operator.attrgetter('name'))
        self._aggregate_metrics(group)
    return groups
def _build_session_groups(self):
    """Returns a list of SessionGroups protobuffers from the summary data.

    Session metadata is read through `self._context.hparams_metadata` and
    the latest metric values through `self._context.read_last_scalars`.
    Sessions whose status is not in `self._request.allowed_statuses` are
    dropped.

    Returns:
      A list of SessionGroup protobuffers. The sessions inside each group
      are sorted by name and each group's metrics have been aggregated via
      `self._aggregate_metrics`.
    """
    # Algorithm: We keep a dict 'groups_by_name' mapping a SessionGroup name
    # (str) to a SessionGroup protobuffer. We traverse the runs associated with
    # the plugin--each representing a single session. We form a Session
    # protobuffer from each run and add it to the relevant SessionGroup object
    # in the 'groups_by_name' dict. We create the SessionGroup object, if this
    # is the first session of that group we encounter.
    groups_by_name = {}
    run_to_tag_to_content = self._context.hparams_metadata(
        self._experiment_id,
        run_tag_filter=provider.RunTagFilter(tags=[
            metadata.SESSION_START_INFO_TAG,
            metadata.SESSION_END_INFO_TAG,
        ]),
    )
    # The TensorBoard runs with session start info are the
    # "sessions", which are not necessarily the runs that actually
    # contain metrics (may be in subdirectories).
    session_names = [
        run for (run, tags) in run_to_tag_to_content.items()
        if metadata.SESSION_START_INFO_TAG in tags
    ]
    # Collect every (run, tag) that could hold a metric for any session so
    # that all metric values are fetched with a single read.
    metric_runs = set()
    metric_tags = set()
    for session_name in session_names:
        for metric in self._experiment.metric_infos:
            metric_name = metric.name
            (run, tag) = metrics.run_tag_from_session_and_metric(
                session_name, metric_name)
            metric_runs.add(run)
            metric_tags.add(tag)
    all_metric_evals = self._context.read_last_scalars(
        self._experiment_id,
        run_tag_filter=provider.RunTagFilter(runs=metric_runs, tags=metric_tags),
    )
    for (session_name, tag_to_content) in run_to_tag_to_content.items():
        # Runs without a session-start record are not sessions; skip them.
        if metadata.SESSION_START_INFO_TAG not in tag_to_content:
            continue
        start_info = metadata.parse_session_start_info_plugin_data(
            tag_to_content[metadata.SESSION_START_INFO_TAG])
        # end_info stays None when the session has no end record.
        end_info = None
        if metadata.SESSION_END_INFO_TAG in tag_to_content:
            end_info = metadata.parse_session_end_info_plugin_data(
                tag_to_content[metadata.SESSION_END_INFO_TAG])
        session = self._build_session(session_name, start_info, end_info,
                                      all_metric_evals)
        if session.status in self._request.allowed_statuses:
            self._add_session(session, start_info, groups_by_name)

    # Compute the session group's aggregated metrics for each group.
    # `dict.values()` is only a view on Python 3; materialize it so callers
    # really get the list the docstring promises (indexable, sortable in
    # place).
    groups = list(groups_by_name.values())
    for group in groups:
        # We sort the sessions in a group so that the order is deterministic.
        group.sessions.sort(key=operator.attrgetter("name"))
        self._aggregate_metrics(group)
    return groups
def _compute_hparam_infos(self, hparams_run_to_tag_to_content):
    """Derives api_pb2.HParamInfo messages from the recorded session starts.

    Every SessionStartInfo found in the given run/tag data contributes its
    hparam values; the values are pooled per hparam name and each pool is
    handed to `_compute_hparam_info_from_values`, which is responsible for
    deducing a type that fits all the values and for choosing the domain
    (discrete for string hparams with few enough distinct values).

    Args:
      hparams_run_to_tag_to_content: A mapping whose values are themselves
        mappings from tag name to hparams plugin content. Only the values
        are inspected here.

    Returns:
      A list of api_pb2.HParamInfo messages. Hparams for which the helper
      returns None are left out.
    """
    start_tag = metadata.SESSION_START_INFO_TAG

    # Pool the observed values by hparam name.
    values_by_hparam = collections.defaultdict(list)
    for contents in hparams_run_to_tag_to_content.values():
        if start_tag in contents:
            session_start = metadata.parse_session_start_info_plugin_data(
                contents[start_tag])
            for (hparam_name, hparam_value) in six.iteritems(
                    session_start.hparams):
                values_by_hparam[hparam_name].append(hparam_value)

    # Turn each pool into an HParamInfo, discarding the ones that could not
    # be built.
    candidate_infos = (
        self._compute_hparam_info_from_values(hparam_name, observed_values)
        for (hparam_name, observed_values) in six.iteritems(values_by_hparam)
    )
    return [info for info in candidate_infos if info is not None]
def get_group_name(hparams):
    """Returns the group name stored in the summary built from `hparams`.

    Relies on `self` (for assertions) and `hp` from the enclosing scope.
    Asserts that the summary holds exactly one value and that the value is
    tagged with this plugin's name before parsing its content.
    """
    summary_values = hp.hparams_pb(hparams).value
    self.assertEqual(len(summary_values), 1, summary_values)
    plugin_data = summary_values[0].metadata.plugin_data
    self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME)
    session_start = metadata.parse_session_start_info_plugin_data(
        plugin_data.content)
    return session_start.group_name
def _check_summary(self, summary_pb):
    """Asserts that `summary_pb` carries exactly the expected hparams proto.

    Args:
      summary_pb: The summary proto to check. It must hold a single value
        whose plugin data belongs to this plugin.
    """
    summary_values = summary_pb.value
    self.assertEqual(len(summary_values), 1, summary_values)
    plugin_data = summary_values[0].metadata.plugin_data
    self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME)

    expected_pb = self.expected_session_start_pb
    parsed_pb = metadata.parse_session_start_info_plugin_data(
        plugin_data.content)
    # The group name is verified elsewhere, so overwrite it with the
    # expected one to take it out of the comparison.
    parsed_pb.group_name = expected_pb.group_name
    self.assertEqual(parsed_pb, expected_pb)
def _build_session_infos_by_name(self):
    """Maps each session's run name to its parsed start and end info.

    Returns:
      A dict from run name to `self._SessionInfoTuple`. Only runs that have
      a session-start record are included; `end_info` is None for runs
      without a session-end record.
    """
    start_tag = metadata.SESSION_START_INFO_TAG
    end_tag = metadata.SESSION_END_INFO_TAG
    contents_by_run = self._context.multiplexer.PluginRunToTagToContent(
        metadata.PLUGIN_NAME)

    infos_by_run = {}
    for (run_name, contents) in six.iteritems(contents_by_run):
        if start_tag in contents:
            parsed_start = metadata.parse_session_start_info_plugin_data(
                contents[start_tag])
            parsed_end = (
                metadata.parse_session_end_info_plugin_data(contents[end_tag])
                if end_tag in contents
                else None
            )
            infos_by_run[run_name] = self._SessionInfoTuple(
                start_info=parsed_start, end_info=parsed_end)
    return infos_by_run
def _check_summary(self, summary_pb, check_group_name=False):
    """Asserts that `summary_pb` carries exactly the expected hparams proto.

    Args:
      summary_pb: The summary proto to check. It must hold a single value
        whose plugin data belongs to this plugin and whose tensor equals
        `metadata.NULL_TENSOR`.
      check_group_name: When False (the default), the `group_name` field is
        excluded from the comparison; when True it must match as well.
    """
    summary_values = summary_pb.value
    self.assertEqual(len(summary_values), 1, summary_values)
    summary_value = summary_values[0]
    plugin_data = summary_value.metadata.plugin_data
    self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME)

    # The tensor is compared after a serialize/parse round trip into
    # `tensor_pb2.TensorProto` (presumably to normalize the proto class
    # before comparing -- confirm).
    reparsed_tensor = tensor_pb2.TensorProto.FromString(
        summary_value.tensor.SerializeToString())
    self.assertEqual(reparsed_tensor, metadata.NULL_TENSOR)

    expected_pb = self.expected_session_start_pb
    parsed_pb = metadata.parse_session_start_info_plugin_data(
        plugin_data.content)
    if not check_group_name:
        # The group name is normally verified elsewhere, so overwrite it
        # with the expected one to take it out of the comparison.
        parsed_pb.group_name = expected_pb.group_name
    self.assertEqual(parsed_pb, expected_pb)
def test_eager(self):
    """Fits the model with the callback and checks the events it wrote.

    Expects a single events file in the log directory containing exactly two
    hparams summaries: a session start followed by a session end.
    """
    # Fake clock: every call advances the time by one second.
    clock = {"now": 1556227801.875}

    def fake_time():
        clock["now"] += 1
        return clock["now"]

    initial_time = clock["now"]
    with mock.patch("time.time", fake_time):
        self._initialize_model(writer=self.logdir)
        self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback])
    final_time = clock["now"]

    logdir_entries = os.listdir(self.logdir)
    self.assertEqual(len(logdir_entries), 1, logdir_entries)
    events_path = os.path.join(self.logdir, logdir_entries[0])

    # Gather the plugin payload of every summary event, in file order.
    plugin_contents = []
    for event in tf.compat.v1.train.summary_iterator(events_path):
        if event.WhichOneof("what") == "summary":
            summary_values = event.summary.value
            self.assertEqual(len(summary_values), 1, summary_values)
            plugin_data = summary_values[0].metadata.plugin_data
            self.assertEqual(plugin_data.plugin_name, metadata.PLUGIN_NAME)
            plugin_contents.append(plugin_data.content)

    self.assertEqual(len(plugin_contents), 2, plugin_contents)
    (start_content, end_content) = plugin_contents
    start_pb = metadata.parse_session_start_info_plugin_data(start_content)
    end_pb = metadata.parse_session_end_info_plugin_data(end_content)

    # `time.time` is also called by Keras internally, an unspecified number
    # of times, so the exact timestamps are unknown. Check their relative
    # order first...
    self.assertGreater(start_pb.start_time_secs, initial_time)
    self.assertLess(start_pb.start_time_secs, end_pb.end_time_secs)
    self.assertLessEqual(start_pb.start_time_secs, final_time)

    # ...then pin them to fixed values so that whole-proto equality can be
    # asserted below.
    start_pb.start_time_secs = 1234.5
    end_pb.end_time_secs = 6789.0

    expected_start_pb = plugin_data_pb2.SessionStartInfo()
    text_format.Merge(
        """ start_time_secs: 1234.5 group_name: "my_trial" hparams { key: "optimizer" value { string_value: "adam" } } hparams { key: "dense_neurons" value { number_value: 8.0 } } """,
        expected_start_pb,
    )
    self.assertEqual(start_pb, expected_start_pb)

    expected_end_pb = plugin_data_pb2.SessionEndInfo()
    text_format.Merge(
        """ end_time_secs: 6789.0 status: STATUS_SUCCESS """,
        expected_end_pb,
    )
    self.assertEqual(end_pb, expected_end_pb)