def hparams(hparam_dict=None, metric_dict=None):
    """Build the three hparams-plugin summaries describing one session.

    Args:
        hparam_dict: Mapping of hyperparameter name to value. Values are handled as follows:
            * strings are recorded as string values;
            * bools are recorded as bool values;
            * ints and floats are recorded as number values;
            * anything else is passed through ``make_np`` and the first element of the result is recorded as a number value.
            ``None`` is treated as an empty mapping.
        metric_dict: Mapping whose keys are registered as metric tags on the
            experiment. Only the keys are read here. ``None`` is treated as
            an empty mapping.

    Returns:
        A tuple ``(experiment, session_start_info, session_end_info)`` of
        ``Summary`` protos tagged ``'_hparams_/experiment'``,
        ``'_hparams_/session_start_info'`` and ``'_hparams_/session_end_info'``
        respectively. Each one carries its payload only in the summary
        metadata's plugin data; the session end status is always
        ``STATUS_SUCCESS``.
    """
    from tensorboardX.proto.plugin_hparams_pb2 import HParamsPluginData, SessionEndInfo, SessionStartInfo
    from tensorboardX.proto.api_pb2 import Experiment, HParamInfo, MetricInfo, MetricName, Status, DataType
    from six import string_types

    PLUGIN_NAME = 'hparams'
    PLUGIN_DATA_VERSION = 0

    EXPERIMENT_TAG = '_hparams_/experiment'
    SESSION_START_INFO_TAG = '_hparams_/session_start_info'
    SESSION_END_INFO_TAG = '_hparams_/session_end_info'

    # Both arguments default to None; treat that as "nothing to log" instead
    # of crashing on .items()/.keys().
    if hparam_dict is None:
        hparam_dict = {}
    if metric_dict is None:
        metric_dict = {}

    def _plugin_summary(tag, plugin_data):
        # Serialize an HParamsPluginData message into summary metadata
        # addressed to the hparams plugin, and wrap it in a Summary under `tag`.
        smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(
            plugin_name=PLUGIN_NAME, content=plugin_data.SerializeToString()))
        return Summary(value=[Summary.Value(tag=tag, metadata=smd)])

    # TODO: expose other HParamInfo / MetricInfo / Experiment fields
    # (display names, value domains, dataset types, descriptions, ...).
    hparam_infos = []
    start_info = SessionStartInfo()
    for name, value in hparam_dict.items():
        if isinstance(value, string_types):
            start_info.hparams[name].string_value = value
            hparam_infos.append(HParamInfo(name=name, type=DataType.DATA_TYPE_STRING))
            continue

        # bool must be tested before the numeric branch: bool subclasses int.
        if isinstance(value, bool):
            start_info.hparams[name].bool_value = value
            hparam_infos.append(HParamInfo(name=name, type=DataType.DATA_TYPE_BOOL))
            continue

        # Plain Python numbers are stored directly; other values are
        # converted via make_np and their first element is used.
        if not isinstance(value, (int, float)):
            value = make_np(value)[0]
        start_info.hparams[name].number_value = value
        hparam_infos.append(HParamInfo(name=name, type=DataType.DATA_TYPE_FLOAT64))

    ssi = _plugin_summary(
        SESSION_START_INFO_TAG,
        HParamsPluginData(session_start_info=start_info, version=PLUGIN_DATA_VERSION))

    metric_infos = [MetricInfo(name=MetricName(tag=tag)) for tag in metric_dict.keys()]
    experiment = Experiment(hparam_infos=hparam_infos, metric_infos=metric_infos)
    exp = _plugin_summary(
        EXPERIMENT_TAG,
        HParamsPluginData(experiment=experiment, version=PLUGIN_DATA_VERSION))

    end_info = SessionEndInfo(status=Status.STATUS_SUCCESS)
    sei = _plugin_summary(
        SESSION_END_INFO_TAG,
        HParamsPluginData(session_end_info=end_info, version=PLUGIN_DATA_VERSION))

    return exp, ssi, sei
def make_metric_info(metric):
    """Build a ``MetricInfo`` proto from a plain metric-description dict.

    Args:
        metric: Mapping that must contain ``'tag'``. It may optionally contain:
            * ``'dataset_type'``: the case-insensitive suffix of a ``DatasetType`` enum name, e.g. ``'validation'`` for ``DATASET_VALIDATION``. If missing, ``None`` or empty, ``'UNKNOWN'`` is used.
            * ``'description'`` and ``'display_name'``: passed through as-is.

    Returns:
        A ``MetricInfo`` message.

    Raises:
        KeyError: if ``'tag'`` is missing.
        ValueError: if ``'dataset_type'`` does not name a ``DatasetType`` value.
    """
    # Local import mirrors hparams(); these proto names are not otherwise
    # visibly imported at module scope.
    from tensorboardX.proto.api_pb2 import DatasetType, MetricInfo, MetricName

    # `or` (rather than a .get default) also covers an explicit None or '',
    # which would otherwise crash on .upper() or build the invalid 'DATASET_'.
    dataset_type = metric.get('dataset_type') or 'UNKNOWN'
    return MetricInfo(
        name=MetricName(tag=metric['tag']),
        dataset_type=DatasetType.Value(f'DATASET_{dataset_type.upper()}'),
        description=metric.get('description'),
        display_name=metric.get('display_name'))