Пример #1
0
 def test_session_end_pb(self):
     end_time_secs = 1234.0
     self.assertEqual(
         summary.session_end_pb(api_pb2.STATUS_SUCCESS, end_time_secs),
         tf.compat.v1.Summary(
             value=[
                 tf.compat.v1.Summary.Value(
                     tag="_hparams_/session_end_info",
                     tensor=summary._TF_NULL_TENSOR,
                     metadata=tf.compat.v1.SummaryMetadata(
                         plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
                             plugin_name="hparams",
                             content=(
                                 plugin_data_pb2.HParamsPluginData(
                                     version=0,
                                     session_end_info=(
                                         plugin_data_pb2.SessionEndInfo(
                                             status=api_pb2.STATUS_SUCCESS,
                                             end_time_secs=end_time_secs,
                                         )
                                     ),
                                 ).SerializeToString()
                             ),
                         )
                     ),
                 )
             ]
         ),
     )
Пример #2
0
 def test_experiment_pb(self):
     hparam_infos = [
         api_pb2.HParamInfo(
             name="param1",
             display_name="display_name1",
             description="foo",
             type=api_pb2.DATA_TYPE_STRING,
             domain_discrete=struct_pb2.ListValue(values=[
                 struct_pb2.Value(string_value="a"),
                 struct_pb2.Value(string_value="b"),
             ]),
         ),
         api_pb2.HParamInfo(
             name="param2",
             display_name="display_name2",
             description="bar",
             type=api_pb2.DATA_TYPE_FLOAT64,
             domain_interval=api_pb2.Interval(min_value=-100.0,
                                              max_value=100.0),
         ),
     ]
     metric_infos = [
         api_pb2.MetricInfo(
             name=api_pb2.MetricName(tag="loss"),
             dataset_type=api_pb2.DATASET_VALIDATION,
         ),
         api_pb2.MetricInfo(
             name=api_pb2.MetricName(group="train/", tag="acc"),
             dataset_type=api_pb2.DATASET_TRAINING,
         ),
     ]
     time_created_secs = 314159.0
     self.assertEqual(
         summary.experiment_pb(hparam_infos,
                               metric_infos,
                               time_created_secs=time_created_secs),
         tf.compat.v1.Summary(value=[
             tf.compat.v1.Summary.Value(
                 tag="_hparams_/experiment",
                 tensor=summary._TF_NULL_TENSOR,
                 metadata=tf.compat.v1.SummaryMetadata(
                     plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
                         plugin_name="hparams",
                         content=(plugin_data_pb2.HParamsPluginData(
                             version=0,
                             experiment=api_pb2.Experiment(
                                 time_created_secs=time_created_secs,
                                 hparam_infos=hparam_infos,
                                 metric_infos=metric_infos,
                             ),
                         ).SerializeToString()),
                     )),
             )
         ]),
     )