Esempio n. 1
0
    def testPluginTagToContent_PluginsCannotJumpOnTheBandwagon(self):
        """A later `SummaryMetadata` naming a different plugin is ignored.

        Two `SummaryMetadata` protos are written to the same log directory
        (presumably under the tag 'you_are_it' chosen by `_writeMetadata`).
        The first names plugin 'outlet'; the second names plugin 'plug'.
        Only the first plugin may end up associated with the tag.
        """
        def make_metadata(display_name, description, plugin, payload):
            # Small local factory so both protos are built the same way.
            return tf.SummaryMetadata(
                display_name=display_name,
                summary_description=description,
                plugin_data=tf.SummaryMetadata.PluginData(
                    plugin_name=plugin, content=payload))

        logdir = self.get_temp_dir()
        first = make_metadata('current tagee', 'no', 'outlet', b'120v')
        second = make_metadata(
            'tagee of the future', 'definitely not', 'plug', b'110v')

        self._writeMetadata(logdir, first, nonce='1')
        accumulator = ea.EventAccumulator(logdir)
        accumulator.Reload()
        # The competing metadata is written only after the accumulator has
        # already loaded the first one, then picked up by a second reload.
        self._writeMetadata(logdir, second, nonce='2')
        accumulator.Reload()

        expected_content = {'you_are_it': b'120v'}
        self.assertEqual(accumulator.PluginTagToContent('outlet'),
                         expected_content)
        with six.assertRaisesRegex(self, KeyError, 'plug'):
            accumulator.PluginTagToContent('plug')
Esempio n. 2
0
 def test_session_start_pb(self):
     """`summary.session_start_pb` wraps a SessionStartInfo in a Summary."""
     start_time_secs = 314160
     # TODO: Fix nondeterminism, then also cover a number-valued hparam
     # ("param2": 5 -> number_value 5.0) and a bool-valued hparam
     # ("param3": False -> bool_value False).
     actual = summary.session_start_pb(
         hparams={"param1": "string"},
         model_uri="//model/uri",
         group_name="session_group",
         start_time_secs=start_time_secs)

     start_info = plugin_data_pb2.SessionStartInfo(
         model_uri="//model/uri",
         group_name="session_group",
         start_time_secs=start_time_secs)
     start_info.hparams["param1"].string_value = "string"
     plugin_content = plugin_data_pb2.HParamsPluginData(
         version=0, session_start_info=start_info).SerializeToString()
     expected = tf.Summary(value=[
         tf.Summary.Value(
             tag="_hparams_/session_start_info",
             metadata=tf.SummaryMetadata(
                 plugin_data=tf.SummaryMetadata.PluginData(
                     plugin_name="hparams", content=plugin_content))),
     ])

     self.assertEqual(actual, expected)
Esempio n. 3
0
    def add_3dvolume(self, volume, tag, global_step=None, walltime=None):
        """Save a 3D volume as NIfTI and log a summary that names the file.

        The volume is written with ANTsPy to
        ``os.path.join(self._log_dir, <tag>_<suffix> + ".nii.gz")``. Here
        ``<suffix>`` is ``str(global_step)`` when a step is given; otherwise
        it is the current local time formatted as ``%Y-%m-%d_%H-%M-%S``.
        A one-element string tensor holding ``<tag>_<suffix>`` (without the
        ``.nii.gz`` extension) is then logged under ``tag`` for the
        ``tb_3d_volume_plugin`` plugin, and the file writer is flushed.

        Args:
          volume: A ``torch.Tensor``, which is detached and copied to CPU
            NumPy first, or anything ``ants.from_numpy`` accepts. This is
            presumably a 3D array; the shape is not checked here.
          tag: Summary tag, also used as the filename prefix.
          global_step: Optional step, used for both the filename and the
            summary event.
          walltime: Optional event timestamp forwarded to ``add_summary``.
        """
        if global_step is None:
            suffix = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        else:
            suffix = str(global_step)
        filename = tag + "_" + suffix

        if isinstance(volume, torch.Tensor):
            volume = volume.detach().cpu().numpy()
        image = ants.from_numpy(volume)
        nifti_path = os.path.join(self._log_dir, filename + ".nii.gz")
        ants.image_write(image, nifti_path)

        # The plugin content is a serialized TextPluginData(version=0).
        plugin_content = TextPluginData(version=0).SerializeToString()
        metadata = tf.SummaryMetadata(
            plugin_data=tf.SummaryMetadata.PluginData(
                plugin_name="tb_3d_volume_plugin",
                content=plugin_content))
        single_element = TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)])
        tensor = TensorProto(
            dtype='DT_STRING',
            string_val=[filename.encode(encoding='utf_8')],
            tensor_shape=single_element)
        value = summary_pb2.Summary.Value(
            tag=tag, metadata=metadata, tensor=tensor)

        self._file_writer.add_summary(summary_pb2.Summary(value=[value]),
                                      global_step=global_step,
                                      walltime=walltime)
        self._file_writer.flush()
Esempio n. 4
0
  def _testTFSummaryTensor_SizeGuidance(self,
                                        plugin_name,
                                        tensor_size_guidance,
                                        steps,
                                        expected_count):
    """Check how many tensor events survive the accumulator's size guidance.

    Writes `steps` events of a scalar `tensor_summary` tagged 'scalar',
    whose metadata names `plugin_name`. It then reloads an `EventAccumulator`
    built with `tensor_size_guidance` and asserts that it retains exactly
    `expected_count` tensor events for that tag.

    Args:
      plugin_name: Plugin name recorded in the summary metadata.
      tensor_size_guidance: Passed through to `ea.EventAccumulator`.
      steps: Number of events to write (global steps 0 .. steps - 1).
      expected_count: Expected `len(accumulator.Tensors('scalar'))`.
    """
    event_sink = _EventGenerator(self, zero_out_timestamps=True)
    writer = tf.summary.FileWriter(self.get_temp_dir())
    # Route the writer's events into `event_sink`, which the accumulator
    # below is constructed from.
    writer.event_writer = event_sink
    with self.test_session() as sess:
      summary_metadata = tf.SummaryMetadata(
          plugin_data=tf.SummaryMetadata.PluginData(plugin_name=plugin_name,
                                                    content=b'{}'))
      tf.summary.tensor_summary('scalar', tf.constant(1.0),
                                summary_metadata=summary_metadata)
      merged = tf.summary.merge_all()
      # `range` works on both Python 2 and 3; the `xrange` builtin is
      # Python 2 only.
      for step in range(steps):
        writer.add_summary(sess.run(merged), global_step=step)

    accumulator = ea.EventAccumulator(
        event_sink, tensor_size_guidance=tensor_size_guidance)
    accumulator.Reload()

    tensors = accumulator.Tensors('scalar')
    self.assertEqual(len(tensors), expected_count)
Esempio n. 5
0
 def testSummaryMetadata(self):
     """The accumulator returns the `SummaryMetadata` written for a tag."""
     expected = tf.SummaryMetadata(
         display_name='current tagee',
         summary_description='no',
         plugin_data=tf.SummaryMetadata.PluginData(plugin_name='outlet'))
     logdir = self.get_temp_dir()
     self._writeMetadata(logdir, expected)

     accumulator = ea.EventAccumulator(logdir)
     accumulator.Reload()
     self.assertProtoEquals(expected,
                            accumulator.SummaryMetadata('you_are_it'))
Esempio n. 6
0
    def testSummaryMetadata_FirstMetadataWins(self):
        """When metadata is written twice for a tag, the first write is kept."""
        logdir = self.get_temp_dir()
        original = tf.SummaryMetadata(
            display_name='current tagee',
            summary_description='no',
            plugin_data=tf.SummaryMetadata.PluginData(
                plugin_name='outlet', content=b'120v'))
        later = tf.SummaryMetadata(
            display_name='tagee of the future',
            summary_description='definitely not',
            plugin_data=tf.SummaryMetadata.PluginData(
                plugin_name='plug', content=b'110v'))

        self._writeMetadata(logdir, original, nonce='1')
        accumulator = ea.EventAccumulator(logdir)
        accumulator.Reload()
        # The second write happens after the accumulator has already
        # loaded the first one.
        self._writeMetadata(logdir, later, nonce='2')
        accumulator.Reload()

        self.assertProtoEquals(original,
                               accumulator.SummaryMetadata('you_are_it'))
Esempio n. 7
0
def create_summary_metadata(display_name, description):
    """Create a `tf.SummaryMetadata` proto for text plugin data.

    Args:
      display_name: The display name used in TensorBoard.
      description: The description to show in TensorBoard.

    Returns:
      A `tf.SummaryMetadata` protobuf object. Its `plugin_data` names
      `PLUGIN_NAME` and carries a serialized `TextPluginData` whose
      `version` is `PROTO_VERSION`.
    """
    serialized = plugin_data_pb2.TextPluginData(
        version=PROTO_VERSION).SerializeToString()
    plugin_data = tf.SummaryMetadata.PluginData(plugin_name=PLUGIN_NAME,
                                                content=serialized)
    return tf.SummaryMetadata(display_name=display_name,
                              summary_description=description,
                              plugin_data=plugin_data)
Esempio n. 8
0
 def test_fully_populated_tensor(self):
   """A `tensor_summary` with every metadata argument set yields a tensor.

   Passes `display_name`, `summary_description`, and a `summary_metadata`
   that carries its own plugin data, along with a tensor that contains
   inf and nan. It then checks that the resulting summary value has a
   `tensor` field and passes `self._assert_noop`.
   """
   metadata = tf.SummaryMetadata(
       plugin_data=tf.SummaryMetadata.PluginData(
           plugin_name='font_of_wisdom',
           content=b'adobe_garamond'))
   op = tf.summary.tensor_summary(
       name='tensorpocalypse',
       tensor=tf.constant([[0.0, 2.0], [float('inf'), float('nan')]]),
       display_name='TENSORPOCALYPSE',
       summary_description='look on my works ye mighty and despair',
       summary_metadata=metadata)
   value = self._value_from_op(op)
   # Use the TestCase assertion instead of a bare `assert`, which
   # `python -O` strips. `value` is still reported on failure.
   self.assertTrue(value.HasField('tensor'), value)
   self._assert_noop(value)
Esempio n. 9
0
def create_summary_metadata(hparams_plugin_data_pb):
    """Wrap a copy of an `HParamsPluginData` message in a `tf.SummaryMetadata`.

    The argument itself is not modified. A copy is made, the copy's
    `version` field is set to `PLUGIN_DATA_VERSION`, and the serialized
    copy becomes `plugin_data.content` under plugin name `PLUGIN_NAME`.

    Args:
      hparams_plugin_data_pb: A `plugin_data_pb2.HParamsPluginData` message.

    Returns:
      A `tf.SummaryMetadata` protobuf object.

    Raises:
      TypeError: If `hparams_plugin_data_pb` is not an `HParamsPluginData`.
    """
    expected_type = plugin_data_pb2.HParamsPluginData
    if not isinstance(hparams_plugin_data_pb, expected_type):
        raise TypeError(
            'Needed an instance of plugin_data_pb2.HParamsPluginData.'
            ' Got: %s' % type(hparams_plugin_data_pb))
    stamped = expected_type()
    stamped.CopyFrom(hparams_plugin_data_pb)
    stamped.version = PLUGIN_DATA_VERSION
    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME, content=stamped.SerializeToString())
    return tf.SummaryMetadata(plugin_data=plugin_data)
Esempio n. 10
0
 def test_session_end_pb(self):
     """`summary.session_end_pb` wraps a SessionEndInfo in a Summary."""
     end_time_secs = 1234.0
     actual = summary.session_end_pb(api_pb2.STATUS_SUCCESS, end_time_secs)

     end_info = plugin_data_pb2.SessionEndInfo(
         status=api_pb2.STATUS_SUCCESS,
         end_time_secs=end_time_secs,
     )
     plugin_content = plugin_data_pb2.HParamsPluginData(
         version=0, session_end_info=end_info).SerializeToString()
     expected = tf.Summary(value=[
         tf.Summary.Value(
             tag="_hparams_/session_end_info",
             metadata=tf.SummaryMetadata(
                 plugin_data=tf.SummaryMetadata.PluginData(
                     plugin_name="hparams", content=plugin_content))),
     ])

     self.assertEqual(actual, expected)
Esempio n. 11
0
def create_summary_metadata(display_name, description, num_thresholds):
    """Create a `tf.SummaryMetadata` proto for pr_curves plugin data.

    Args:
      display_name: The display name used in TensorBoard.
      description: The description to show in TensorBoard.
      num_thresholds: The number of thresholds to use for PR curves.

    Returns:
      A `tf.SummaryMetadata` protobuf object. Its `plugin_data` names
      `PLUGIN_NAME` and carries a serialized `PrCurvePluginData` with
      `version=PROTO_VERSION` and the given `num_thresholds`.
    """
    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name=PLUGIN_NAME,
        content=plugin_data_pb2.PrCurvePluginData(
            version=PROTO_VERSION,
            num_thresholds=num_thresholds).SerializeToString())
    return tf.SummaryMetadata(
        display_name=display_name,
        summary_description=description,
        plugin_data=plugin_data)
Esempio n. 12
0
 def test_experiment_pb(self):
     """`summary.experiment_pb` wraps an Experiment in an hparams Summary."""
     string_domain = struct_pb2.ListValue(values=[
         struct_pb2.Value(string_value='a'),
         struct_pb2.Value(string_value='b'),
     ])
     float_range = api_pb2.Interval(min_value=-100.0, max_value=100.0)
     hparam_infos = [
         api_pb2.HParamInfo(
             name="param1",
             display_name="display_name1",
             description="foo",
             type=api_pb2.DATA_TYPE_STRING,
             domain_discrete=string_domain),
         api_pb2.HParamInfo(
             name="param2",
             display_name="display_name2",
             description="bar",
             type=api_pb2.DATA_TYPE_FLOAT64,
             domain_interval=float_range),
     ]
     metric_infos = [
         api_pb2.MetricInfo(
             name=api_pb2.MetricName(tag="loss"),
             dataset_type=api_pb2.DATASET_VALIDATION),
         api_pb2.MetricInfo(
             name=api_pb2.MetricName(group="train/", tag="acc"),
             dataset_type=api_pb2.DATASET_TRAINING),
     ]
     time_created_secs = 314159.0

     actual = summary.experiment_pb(hparam_infos, metric_infos,
                                    time_created_secs=time_created_secs)

     experiment = api_pb2.Experiment(time_created_secs=time_created_secs,
                                     hparam_infos=hparam_infos,
                                     metric_infos=metric_infos)
     plugin_content = plugin_data_pb2.HParamsPluginData(
         version=0, experiment=experiment).SerializeToString()
     expected = tf.Summary(value=[
         tf.Summary.Value(
             tag="_hparams_/experiment",
             metadata=tf.SummaryMetadata(
                 plugin_data=tf.SummaryMetadata.PluginData(
                     plugin_name="hparams", content=plugin_content))),
     ])

     self.assertEqual(actual, expected)
Esempio n. 13
0
def _create_summary_metadata():
  """Return a `tf.SummaryMetadata` tagged with this plugin's name.

  Only `plugin_data.plugin_name` is set, to `metadata.PLUGIN_NAME`. The
  plugin `content` and all other metadata fields keep their defaults.
  """
  plugin_data = tf.SummaryMetadata.PluginData(
      plugin_name=metadata.PLUGIN_NAME)
  return tf.SummaryMetadata(plugin_data=plugin_data)