def create_experiment_summary():
    """Build the `api_pb2.Experiment` summary describing this experiment.

    Declares each searched hyperparameter together with its domain (a
    discrete set of candidates or a numeric interval). Declares the train and
    validation metrics to display. Passes both lists to
    `hparams_summary.experiment_pb` and returns its result.
    """

    def _discrete(name, dtype, values):
        # Candidate values are wrapped in a `struct_pb2.ListValue` domain.
        candidates = struct_pb2.ListValue()
        candidates.extend(values)
        return api_pb2.HParamInfo(
            name=name,
            type=dtype,
            domain_discrete=candidates,
        )

    float64 = api_pb2.DATA_TYPE_FLOAT64
    hparam_infos = [
        # conv_layers, conv_kernel_size and dense_layers hold integer values
        # but are declared with the FLOAT64 data type.
        _discrete("conv_layers", float64, [1, 2, 3]),
        _discrete("conv_kernel_size", float64, [3, 5]),
        _discrete("dense_layers", float64, [1, 2, 3]),
        api_pb2.HParamInfo(
            name="dropout",
            type=float64,
            domain_interval=api_pb2.Interval(min_value=0.1, max_value=0.4),
        ),
        _discrete("optimizer", api_pb2.DATA_TYPE_STRING, ["adam", "adagrad"]),
    ]

    # (group, tag, display name) for every metric shown in the hparams view.
    metric_specs = (
        ("validation", "epoch_accuracy", "accuracy (val.)"),
        ("validation", "epoch_loss", "loss (val.)"),
        ("train", "batch_accuracy", "accuracy (train)"),
        ("train", "batch_loss", "loss (train)"),
    )
    metric_infos = [
        api_pb2.MetricInfo(
            name=api_pb2.MetricName(group=group, tag=tag),
            display_name=label,
        )
        for group, tag, label in metric_specs
    ]

    return hparams_summary.experiment_pb(
        hparam_infos=hparam_infos,
        metric_infos=metric_infos,
    )
def test_experiment_pb(self):
    """`summary.experiment_pb` wraps the infos in a `_hparams_/experiment` summary."""
    string_domain = struct_pb2.ListValue(values=[
        struct_pb2.Value(string_value=letter) for letter in ("a", "b")
    ])
    hparam_infos = [
        api_pb2.HParamInfo(
            name="param1",
            display_name="display_name1",
            description="foo",
            type=api_pb2.DATA_TYPE_STRING,
            domain_discrete=string_domain,
        ),
        api_pb2.HParamInfo(
            name="param2",
            display_name="display_name2",
            description="bar",
            type=api_pb2.DATA_TYPE_FLOAT64,
            domain_interval=api_pb2.Interval(min_value=-100.0, max_value=100.0),
        ),
    ]
    metric_infos = [
        api_pb2.MetricInfo(
            name=api_pb2.MetricName(tag="loss"),
            dataset_type=api_pb2.DATASET_VALIDATION,
        ),
        api_pb2.MetricInfo(
            name=api_pb2.MetricName(group="train/", tag="acc"),
            dataset_type=api_pb2.DATASET_TRAINING,
        ),
    ]
    time_created_secs = 314159.0

    # Compute the actual summary first, matching the original argument order.
    actual = summary.experiment_pb(
        hparam_infos, metric_infos, time_created_secs=time_created_secs
    )

    # Expected plugin payload: the serialized HParamsPluginData carrying the experiment.
    expected_content = plugin_data_pb2.HParamsPluginData(
        version=0,
        experiment=api_pb2.Experiment(
            time_created_secs=time_created_secs,
            hparam_infos=hparam_infos,
            metric_infos=metric_infos,
        ),
    ).SerializeToString()
    expected_metadata = tf.compat.v1.SummaryMetadata(
        plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
            plugin_name="hparams",
            content=expected_content,
        )
    )
    expected = tf.compat.v1.Summary(value=[
        tf.compat.v1.Summary.Value(
            tag="_hparams_/experiment",
            tensor=summary._TF_NULL_TENSOR,
            metadata=expected_metadata,
        )
    ])

    self.assertEqual(actual, expected)