コード例 #1
0
ファイル: hparams_demo.py プロジェクト: zhudatu/tensorboard
def create_experiment_summary():
    """Assemble the hparams experiment summary for this demo.

    Declares five hyperparameters (`conv_layers`, `conv_kernel_size`,
    `dense_layers`, `dropout`, `optimizer`) together with their domains,
    and four metrics (accuracy and loss in the `validation` and `train`
    groups), then passes both lists to `hparams_summary.experiment_pb`.

    Returns:
        The value produced by `hparams_summary.experiment_pb` for the
        hyperparameter and metric descriptions built here.
    """

    def as_list_value(choices):
        # Wrap a plain Python sequence in a protobuf `ListValue`.
        list_value = struct_pb2.ListValue()
        list_value.extend(choices)
        return list_value

    # These hyperparameters take integer values but are declared with the
    # FLOAT64 data type.
    integer_choices = (
        ("conv_layers", [1, 2, 3]),
        ("conv_kernel_size", [3, 5]),
        ("dense_layers", [1, 2, 3]),
    )
    hparam_infos = [
        api_pb2.HParamInfo(
            name=hparam_name,
            type=api_pb2.DATA_TYPE_FLOAT64,
            domain_discrete=as_list_value(choices),
        )
        for hparam_name, choices in integer_choices
    ]

    # The only hyperparameter described by a continuous interval.
    dropout_range = api_pb2.Interval(min_value=0.1, max_value=0.4)
    hparam_infos.append(
        api_pb2.HParamInfo(
            name="dropout",
            type=api_pb2.DATA_TYPE_FLOAT64,
            domain_interval=dropout_range,
        )
    )
    hparam_infos.append(
        api_pb2.HParamInfo(
            name="optimizer",
            type=api_pb2.DATA_TYPE_STRING,
            domain_discrete=as_list_value(["adam", "adagrad"]),
        )
    )

    # (group, tag, display name) for every metric, in display order.
    metric_specs = (
        ("validation", "epoch_accuracy", "accuracy (val.)"),
        ("validation", "epoch_loss", "loss (val.)"),
        ("train", "batch_accuracy", "accuracy (train)"),
        ("train", "batch_loss", "loss (train)"),
    )
    metric_infos = [
        api_pb2.MetricInfo(
            name=api_pb2.MetricName(group=group, tag=tag),
            display_name=display_name,
        )
        for group, tag, display_name in metric_specs
    ]

    return hparams_summary.experiment_pb(
        hparam_infos=hparam_infos,
        metric_infos=metric_infos,
    )
コード例 #2
0
ファイル: summary_test.py プロジェクト: jverre/tensorboard-1
 def test_experiment_pb(self):
     # Inputs: a string hyperparameter with a discrete domain and a float
     # hyperparameter with an interval domain.
     string_param = api_pb2.HParamInfo(
         name="param1",
         display_name="display_name1",
         description="foo",
         type=api_pb2.DATA_TYPE_STRING,
         domain_discrete=struct_pb2.ListValue(
             values=[
                 struct_pb2.Value(string_value=choice)
                 for choice in ("a", "b")
             ]
         ),
     )
     float_param = api_pb2.HParamInfo(
         name="param2",
         display_name="display_name2",
         description="bar",
         type=api_pb2.DATA_TYPE_FLOAT64,
         domain_interval=api_pb2.Interval(
             min_value=-100.0, max_value=100.0
         ),
     )
     hparams = [string_param, float_param]

     # Inputs: one metric without a group and one with a group.
     loss_metric = api_pb2.MetricInfo(
         name=api_pb2.MetricName(tag="loss"),
         dataset_type=api_pb2.DATASET_VALIDATION,
     )
     acc_metric = api_pb2.MetricInfo(
         name=api_pb2.MetricName(group="train/", tag="acc"),
         dataset_type=api_pb2.DATASET_TRAINING,
     )
     metrics = [loss_metric, acc_metric]

     created_at = 314159.0

     actual = summary.experiment_pb(
         hparams, metrics, time_created_secs=created_at
     )

     # Expected: a single-value summary whose plugin metadata carries the
     # serialized experiment description built from the same inputs.
     experiment = api_pb2.Experiment(
         time_created_secs=created_at,
         hparam_infos=hparams,
         metric_infos=metrics,
     )
     serialized_plugin_data = plugin_data_pb2.HParamsPluginData(
         version=0,
         experiment=experiment,
     ).SerializeToString()
     metadata = tf.compat.v1.SummaryMetadata(
         plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
             plugin_name="hparams",
             content=serialized_plugin_data,
         )
     )
     expected = tf.compat.v1.Summary(
         value=[
             tf.compat.v1.Summary.Value(
                 tag="_hparams_/experiment",
                 tensor=summary._TF_NULL_TENSOR,
                 metadata=metadata,
             )
         ]
     )

     self.assertEqual(actual, expected)