Ejemplo n.º 1
0
def hparams(hparam_dict=None, metric_dict=None):
    """Build the three hparams-plugin summaries describing one session.

    Args:
        hparam_dict: Mapping of hyperparameter name to value. Strings are
            recorded as string hparams and bools as bool hparams. Python ints
            and floats are recorded as numeric (float64) hparams. Any other
            non-None value is converted with ``make_np`` and its first element
            is recorded as a number. A ``None`` value declares the hparam with
            an unset type and records no value. ``None`` (the default) is
            treated as an empty mapping.
        metric_dict: Mapping whose keys are the metric tags declared in the
            experiment. Only the keys are used. ``None`` is treated as empty.

    Returns:
        Tuple ``(experiment, session_start_info, session_end_info)`` of
        ``Summary`` protos tagged ``_hparams_/experiment``,
        ``_hparams_/session_start_info`` and ``_hparams_/session_end_info``.
    """
    from tensorboardX.proto.plugin_hparams_pb2 import HParamsPluginData, SessionEndInfo, SessionStartInfo
    from tensorboardX.proto.api_pb2 import Experiment, HParamInfo, MetricInfo, MetricName, Status, DataType
    from six import string_types

    PLUGIN_NAME = 'hparams'
    PLUGIN_DATA_VERSION = 0

    EXPERIMENT_TAG = '_hparams_/experiment'
    SESSION_START_INFO_TAG = '_hparams_/session_start_info'
    SESSION_END_INFO_TAG = '_hparams_/session_end_info'

    # TODO: expose other parameters in the future.
    # hp = HParamInfo(name='lr',display_name='learning rate', type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10, max_value=100))  # noqa E501
    # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy', description='', dataset_type=DatasetType.DATASET_VALIDATION)  # noqa E501
    # exp = Experiment(name='123', description='456', time_created_secs=100.0, hparam_infos=[hp], metric_infos=[mt], user='******')  # noqa E501

    hparam_dict = {} if hparam_dict is None else hparam_dict
    metric_dict = {} if metric_dict is None else metric_dict

    def _to_summary(tag, **plugin_fields):
        """Wrap one HParamsPluginData payload in a single-value Summary under ``tag``."""
        # The payload travels in the summary metadata's plugin content.
        # No value field is set on the Summary.Value itself.
        content = HParamsPluginData(version=PLUGIN_DATA_VERSION, **plugin_fields)
        smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(plugin_name=PLUGIN_NAME,
                                                                     content=content.SerializeToString()))
        return Summary(value=[Summary.Value(tag=tag, metadata=smd)])

    hps = []
    session_start = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            # Nothing to record: declare the hparam with an unset type only.
            hps.append(HParamInfo(name=k, type=DataType.DATA_TYPE_UNSET))
            continue

        if isinstance(v, string_types):
            session_start.hparams[k].string_value = v
            hps.append(HParamInfo(name=k, type=DataType.DATA_TYPE_STRING))
            continue

        # Must precede the numeric check: bool is a subclass of int.
        if isinstance(v, bool):
            session_start.hparams[k].bool_value = v
            hps.append(HParamInfo(name=k, type=DataType.DATA_TYPE_BOOL))
            continue

        if not isinstance(v, (int, float)):
            # Non-Python-number values (e.g. tensors / arrays): take the first element.
            v = make_np(v)[0]
        # float() turns NumPy scalars into a plain Python float.
        # Some protobuf implementations reject NumPy scalars for double fields.
        session_start.hparams[k].number_value = float(v)
        hps.append(HParamInfo(name=k, type=DataType.DATA_TYPE_FLOAT64))

    ssi = _to_summary(SESSION_START_INFO_TAG, session_start_info=session_start)

    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
    exp = _to_summary(EXPERIMENT_TAG, experiment=Experiment(hparam_infos=hps, metric_infos=mts))

    sei = _to_summary(SESSION_END_INFO_TAG,
                      session_end_info=SessionEndInfo(status=Status.STATUS_SUCCESS))
    return exp, ssi, sei
Ejemplo n.º 2
0
 def make_hparam_info(hparam):
     """Build an ``HParamInfo`` proto from a dict describing one hyperparameter.

     ``hparam["name"]`` is required. ``hparam.get("type")`` is mapped as follows:
       * ``None`` maps to ``DATA_TYPE_UNSET``.
       * A member of ``string_types`` maps to ``DATA_TYPE_STRING``.
       * ``bool`` maps to ``DATA_TYPE_BOOL``.
       * ``float`` or ``int`` maps to ``DATA_TYPE_FLOAT64``.
       * Any other value is passed through unchanged as the proto's type.

     The keys ``description``, ``display_name``, ``domain_discrete`` and
     ``domain_interval`` are forwarded as-is, or as ``None`` when absent.
     """
     declared = hparam.get("type")
     if declared is None:
         resolved = DataType.DATA_TYPE_UNSET
     elif declared in string_types:
         resolved = DataType.DATA_TYPE_STRING
     elif declared is bool:
         resolved = DataType.DATA_TYPE_BOOL
     elif declared in (float, int):
         resolved = DataType.DATA_TYPE_FLOAT64
     else:
         # NOTE(review): presumably callers may pass a DataType value directly; used verbatim.
         resolved = declared

     optional_fields = {
         key: hparam.get(key)
         for key in ("description", "display_name", "domain_discrete", "domain_interval")
     }
     return HParamInfo(name=hparam["name"], type=resolved, **optional_fields)