def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self):
    """A later SummaryMetadata for a tag cannot introduce a new plugin.

    When a tag carries two `SummaryMetadata` protos whose `plugin_data`
    name different plugins, only the first plugin's content is kept; the
    second plugin never appears in the accumulator's content map.
    """
    logdir = self.get_temp_dir()
    first_metadata = tf.SummaryMetadata(
        display_name='current tagee',
        summary_description='no',
        plugin_data=tf.SummaryMetadata.PluginData(
            plugin_name='outlet', content=b'120v'))
    self._writeMetadata(logdir, first_metadata, nonce='1')
    accumulator = ea.EventAccumulator(logdir)
    accumulator.Reload()
    second_metadata = tf.SummaryMetadata(
        display_name='tagee of the future',
        summary_description='definitely not',
        plugin_data=tf.SummaryMetadata.PluginData(
            plugin_name='plug', content=b'110v'))
    self._writeMetadata(logdir, second_metadata, nonce='2')
    accumulator.Reload()
    # Only the first plugin's tag-to-content mapping survives.
    self.assertEqual(accumulator.PluginTagToContent('outlet'),
                     {'you_are_it': b'120v'})
    # The bandwagon plugin from the second proto was ignored entirely.
    with six.assertRaisesRegex(self, KeyError, 'plug'):
        accumulator.PluginTagToContent('plug')
def test_session_start_pb(self):
    """session_start_pb() must produce the canonical SessionStartInfo summary."""
    start_time_secs = 314160
    expected_info = plugin_data_pb2.SessionStartInfo(
        model_uri="//model/uri",
        group_name="session_group",
        start_time_secs=start_time_secs)
    expected_info.hparams["param1"].string_value = "string"
    # TODO: Fix nondeterminism.
    # expected_info.hparams["param2"].number_value = 5.0
    # expected_info.hparams["param3"].bool_value = False
    expected_summary = tf.Summary(value=[
        tf.Summary.Value(
            tag="_hparams_/session_start_info",
            metadata=tf.SummaryMetadata(
                plugin_data=tf.SummaryMetadata.PluginData(
                    plugin_name="hparams",
                    content=plugin_data_pb2.HParamsPluginData(
                        version=0,
                        session_start_info=expected_info
                    ).SerializeToString())))
    ])
    actual_summary = summary.session_start_pb(
        hparams={
            "param1": "string",
            # "param2":5,
            # "param3":False,
        },
        model_uri="//model/uri",
        group_name="session_group",
        start_time_secs=start_time_secs)
    self.assertEqual(actual_summary, expected_summary)
def add_3dvolume(self, volume, tag, global_step=None, walltime=None):
    """Save a 3-D volume to disk and log a pointer summary for TensorBoard.

    The volume itself is written as a NIfTI file
    (``<tag>_<step-or-timestamp>.nii.gz``) into ``self._log_dir``; the event
    file only records the file's base name as a one-element string tensor
    tagged for the ``tb_3d_volume_plugin``.

    Args:
      volume: Volume data; a ``torch.Tensor`` (detached and converted to a
        numpy array here) or anything ``ants.from_numpy`` accepts.
      tag: Summary tag; also used as the filename prefix.
      global_step: Optional step index. When ``None``, the current timestamp
        is used in the filename instead.
      walltime: Optional wall-clock time forwarded to the event writer.
    """
    filename = tag + "_"
    # Without a step number, use a timestamp so repeated calls don't collide
    # (only to one-second resolution — calls within the same second overwrite).
    if global_step is None:
        filename += datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    else:
        filename += str(global_step)
    if isinstance(volume, torch.Tensor):
        # Detach from the autograd graph and move to host memory first.
        volume = volume.detach().cpu().numpy()
    img = ants.from_numpy(volume)
    ants.image_write(img, os.path.join(self._log_dir, filename + ".nii.gz"))
    # NOTE(review): this reuses TextPluginData as the plugin_data payload for
    # the 3-D volume plugin — presumably the plugin only reads the version
    # field; confirm against the plugin's reader.
    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name="tb_3d_volume_plugin",
        content=TextPluginData(version=0).SerializeToString())
    metadata = tf.SummaryMetadata(plugin_data=plugin_data)
    # The logged tensor holds only the UTF-8 encoded base filename; the plugin
    # is expected to resolve it against the log directory.
    tensor = TensorProto(
        dtype='DT_STRING',
        string_val=[filename.encode(encoding='utf_8')],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]))
    summary = summary_pb2.Summary(value=[
        summary_pb2.Summary.Value(
            tag=tag, metadata=metadata, tensor=tensor)
    ])
    self._file_writer.add_summary(
        summary, global_step=global_step, walltime=walltime)
    # Flush eagerly so the volume reference is visible immediately.
    self._file_writer.flush()
def _testTFSummaryTensor_SizeGuidance(self, plugin_name,
                                      tensor_size_guidance, steps,
                                      expected_count):
    """Write `steps` tensor summaries and check how many survive the guidance.

    Args:
      plugin_name: Plugin name stamped into the summary metadata.
      tensor_size_guidance: Guidance dict handed to the EventAccumulator.
      steps: Number of summary events to emit.
      expected_count: Tensors expected to be retained for the 'scalar' tag.
    """
    generator = _EventGenerator(self, zero_out_timestamps=True)
    writer = tf.summary.FileWriter(self.get_temp_dir())
    writer.event_writer = generator
    with self.test_session() as session:
        metadata = tf.SummaryMetadata(
            plugin_data=tf.SummaryMetadata.PluginData(
                plugin_name=plugin_name, content=b'{}'))
        tf.summary.tensor_summary('scalar', tf.constant(1.0),
                                  summary_metadata=metadata)
        merged = tf.summary.merge_all()
        for step in xrange(steps):
            writer.add_summary(session.run(merged), global_step=step)
    acc = ea.EventAccumulator(
        generator, tensor_size_guidance=tensor_size_guidance)
    acc.Reload()
    retained = acc.Tensors('scalar')
    self.assertEqual(len(retained), expected_count)
def testSummaryMetadata(self):
    """The accumulator surfaces the SummaryMetadata written for a tag."""
    logdir = self.get_temp_dir()
    written_metadata = tf.SummaryMetadata(
        display_name='current tagee', summary_description='no')
    written_metadata.plugin_data.plugin_name = 'outlet'
    self._writeMetadata(logdir, written_metadata)
    accumulator = ea.EventAccumulator(logdir)
    accumulator.Reload()
    self.assertProtoEquals(written_metadata,
                           accumulator.SummaryMetadata('you_are_it'))
def testSummaryMetadata_FirstMetadataWins(self):
    """When a tag sees two SummaryMetadata protos, only the first is kept."""
    logdir = self.get_temp_dir()
    first_metadata = tf.SummaryMetadata(
        display_name='current tagee',
        summary_description='no',
        plugin_data=tf.SummaryMetadata.PluginData(
            plugin_name='outlet', content=b'120v'))
    self._writeMetadata(logdir, first_metadata, nonce='1')
    accumulator = ea.EventAccumulator(logdir)
    accumulator.Reload()
    second_metadata = tf.SummaryMetadata(
        display_name='tagee of the future',
        summary_description='definitely not',
        plugin_data=tf.SummaryMetadata.PluginData(
            plugin_name='plug', content=b'110v'))
    self._writeMetadata(logdir, second_metadata, nonce='2')
    accumulator.Reload()
    # The later proto is ignored; the original metadata is still reported.
    self.assertProtoEquals(first_metadata,
                           accumulator.SummaryMetadata('you_are_it'))
def create_summary_metadata(display_name, description):
    """Build the `tf.SummaryMetadata` proto for text plugin data.

    Args:
      display_name: The display name used in TensorBoard.
      description: The description to show in TensorBoard.

    Returns:
      A `tf.SummaryMetadata` protobuf object.
    """
    plugin_content = plugin_data_pb2.TextPluginData(version=PROTO_VERSION)
    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME,
        content=plugin_content.SerializeToString())
    return tf.SummaryMetadata(
        display_name=display_name,
        summary_description=description,
        plugin_data=plugin_data)
def test_fully_populated_tensor(self):
    """A tensor summary with every field set passes through untouched."""
    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name='font_of_wisdom', content=b'adobe_garamond')
    metadata = tf.SummaryMetadata(plugin_data=plugin_data)
    # Include inf and nan to exercise non-finite float handling.
    tensor = tf.constant([[0.0, 2.0], [float('inf'), float('nan')]])
    op = tf.summary.tensor_summary(
        name='tensorpocalypse',
        tensor=tensor,
        display_name='TENSORPOCALYPSE',
        summary_description='look on my works ye mighty and despair',
        summary_metadata=metadata)
    value = self._value_from_op(op)
    assert value.HasField('tensor'), value
    self._assert_noop(value)
def create_summary_metadata(hparams_plugin_data_pb):
    """Return a tf.SummaryMetadata holding a copy of the given message.

    The copy's version field is overwritten with PLUGIN_DATA_VERSION before
    it is serialized into the metadata's plugin_data.content field.

    Raises:
      TypeError: If the argument is not a
        plugin_data_pb2.HParamsPluginData message.
    """
    if not isinstance(hparams_plugin_data_pb,
                      plugin_data_pb2.HParamsPluginData):
        raise TypeError(
            'Needed an instance of plugin_data_pb2.HParamsPluginData.'
            ' Got: %s' % type(hparams_plugin_data_pb))
    stamped_copy = plugin_data_pb2.HParamsPluginData()
    stamped_copy.CopyFrom(hparams_plugin_data_pb)
    stamped_copy.version = PLUGIN_DATA_VERSION
    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME,
        content=stamped_copy.SerializeToString())
    return tf.SummaryMetadata(plugin_data=plugin_data)
def test_session_end_pb(self):
    """session_end_pb() wraps a SessionEndInfo in the hparams summary."""
    end_time_secs = 1234.0
    end_info = plugin_data_pb2.SessionEndInfo(
        status=api_pb2.STATUS_SUCCESS,
        end_time_secs=end_time_secs)
    plugin_content = plugin_data_pb2.HParamsPluginData(
        version=0, session_end_info=end_info)
    expected_summary = tf.Summary(value=[
        tf.Summary.Value(
            tag="_hparams_/session_end_info",
            metadata=tf.SummaryMetadata(
                plugin_data=tf.SummaryMetadata.PluginData(
                    plugin_name="hparams",
                    content=plugin_content.SerializeToString())))
    ])
    self.assertEqual(
        summary.session_end_pb(api_pb2.STATUS_SUCCESS, end_time_secs),
        expected_summary)
def create_summary_metadata(display_name, description, num_thresholds):
    """Build the `tf.SummaryMetadata` proto for pr_curves plugin data.

    Args:
      display_name: The display name used in TensorBoard.
      description: The description to show in TensorBoard.
      num_thresholds: The number of thresholds to use for PR curves.

    Returns:
      A `tf.SummaryMetadata` protobuf object.
    """
    plugin_content = plugin_data_pb2.PrCurvePluginData(
        version=PROTO_VERSION, num_thresholds=num_thresholds)
    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME,
        content=plugin_content.SerializeToString())
    return tf.SummaryMetadata(
        display_name=display_name,
        summary_description=description,
        plugin_data=plugin_data)
def test_experiment_pb(self):
    """experiment_pb() must serialize hparam and metric infos verbatim."""
    string_domain = struct_pb2.ListValue(values=[
        struct_pb2.Value(string_value='a'),
        struct_pb2.Value(string_value='b'),
    ])
    hparam_infos = [
        api_pb2.HParamInfo(
            name="param1",
            display_name="display_name1",
            description="foo",
            type=api_pb2.DATA_TYPE_STRING,
            domain_discrete=string_domain),
        api_pb2.HParamInfo(
            name="param2",
            display_name="display_name2",
            description="bar",
            type=api_pb2.DATA_TYPE_FLOAT64,
            domain_interval=api_pb2.Interval(
                min_value=-100.0, max_value=100.0)),
    ]
    metric_infos = [
        api_pb2.MetricInfo(
            name=api_pb2.MetricName(tag="loss"),
            dataset_type=api_pb2.DATASET_VALIDATION),
        api_pb2.MetricInfo(
            name=api_pb2.MetricName(group="train/", tag="acc"),
            dataset_type=api_pb2.DATASET_TRAINING),
    ]
    time_created_secs = 314159.0
    experiment = api_pb2.Experiment(
        time_created_secs=time_created_secs,
        hparam_infos=hparam_infos,
        metric_infos=metric_infos)
    expected_summary = tf.Summary(value=[
        tf.Summary.Value(
            tag="_hparams_/experiment",
            metadata=tf.SummaryMetadata(
                plugin_data=tf.SummaryMetadata.PluginData(
                    plugin_name="hparams",
                    content=plugin_data_pb2.HParamsPluginData(
                        version=0,
                        experiment=experiment).SerializeToString())))
    ])
    self.assertEqual(
        summary.experiment_pb(
            hparam_infos, metric_infos, time_created_secs=time_created_secs),
        expected_summary)
def _create_summary_metadata():
    """Return summary metadata that tags values for this plugin by name only."""
    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name=metadata.PLUGIN_NAME)
    return tf.SummaryMetadata(plugin_data=plugin_data)