Exemplo n.º 1
0
class PipelineSpecBuilderTest(parameterized.TestCase):
    """Tests for private helpers of ``pipeline_spec_builder``.

    Also contains ``testStruct``, which exercises the mapping/sequence style
    API of ``struct_pb2.Struct`` and ``struct_pb2.ListValue`` directly.
    """

    def setUp(self):
        """Disable diff truncation so failing proto comparisons print in full."""
        # Chain to the base class so any fixture state it prepares is set up.
        super().setUp()
        self.maxDiff = None

    @parameterized.parameters(
        {
            'channel':
            pipeline_channel.PipelineParameterChannel(
                name='output1', task_name='task1', channel_type='String'),
            'expected':
            'pipelinechannel--task1-output1',
        },
        {
            'channel':
            pipeline_channel.PipelineArtifactChannel(
                name='output1', task_name='task1', channel_type='Artifact'),
            'expected':
            'pipelinechannel--task1-output1',
        },
        {
            'channel':
            pipeline_channel.PipelineParameterChannel(name='param1',
                                                      channel_type='String'),
            'expected':
            'pipelinechannel--param1',
        },
    )
    def test_additional_input_name_for_pipeline_channel(
            self, channel, expected):
        """Check the input name derived for each parameterized channel.

        Args:
          channel: The pipeline channel passed to
            ``_additional_input_name_for_pipeline_channel``.
          expected: The input name the helper is expected to return.
        """
        self.assertEqual(
            expected,
            pipeline_spec_builder._additional_input_name_for_pipeline_channel(
                channel))

    @parameterized.parameters(
        {
            'parameter_type': pipeline_spec_pb2.ParameterType.NUMBER_INTEGER,
            'default_value': None,
            'expected': struct_pb2.Value(),
        },
        {
            'parameter_type': pipeline_spec_pb2.ParameterType.NUMBER_INTEGER,
            'default_value': 1,
            'expected': struct_pb2.Value(number_value=1),
        },
        {
            'parameter_type': pipeline_spec_pb2.ParameterType.NUMBER_DOUBLE,
            'default_value': 1.2,
            'expected': struct_pb2.Value(number_value=1.2),
        },
        {
            'parameter_type': pipeline_spec_pb2.ParameterType.STRING,
            'default_value': 'text',
            'expected': struct_pb2.Value(string_value='text'),
        },
        {
            'parameter_type': pipeline_spec_pb2.ParameterType.BOOLEAN,
            'default_value': True,
            'expected': struct_pb2.Value(bool_value=True),
        },
        {
            'parameter_type': pipeline_spec_pb2.ParameterType.BOOLEAN,
            'default_value': False,
            'expected': struct_pb2.Value(bool_value=False),
        },
        {
            'parameter_type':
            pipeline_spec_pb2.ParameterType.STRUCT,
            'default_value': {
                'a': 1,
                'b': 2,
            },
            'expected':
            struct_pb2.Value(struct_value=struct_pb2.Struct(
                fields={
                    'a': struct_pb2.Value(number_value=1),
                    'b': struct_pb2.Value(number_value=2),
                })),
        },
        {
            'parameter_type':
            pipeline_spec_pb2.ParameterType.LIST,
            'default_value': ['a', 'b'],
            'expected':
            struct_pb2.Value(list_value=struct_pb2.ListValue(values=[
                struct_pb2.Value(string_value='a'),
                struct_pb2.Value(string_value='b'),
            ])),
        },
        {
            'parameter_type':
            pipeline_spec_pb2.ParameterType.LIST,
            'default_value': [{
                'a': 1,
                'b': 2
            }, {
                'a': 10,
                'b': 20
            }],
            'expected':
            struct_pb2.Value(list_value=struct_pb2.ListValue(values=[
                struct_pb2.Value(struct_value=struct_pb2.Struct(
                    fields={
                        'a': struct_pb2.Value(number_value=1),
                        'b': struct_pb2.Value(number_value=2),
                    })),
                struct_pb2.Value(struct_value=struct_pb2.Struct(
                    fields={
                        'a': struct_pb2.Value(number_value=10),
                        'b': struct_pb2.Value(number_value=20),
                    })),
            ])),
        },
    )
    def test_fill_in_component_input_default_value(self, parameter_type,
                                                   default_value, expected):
        """Check the default value written into a component input spec.

        Builds a ``ComponentSpec`` with a single parameter ``input1`` of the
        given type, calls ``_fill_in_component_input_default_value`` on it and
        compares the parameter's resulting ``default_value`` with ``expected``.

        Args:
          parameter_type: ``ParameterType`` enum value for ``input1``.
          default_value: Python value handed to the helper as the default.
          expected: The ``struct_pb2.Value`` the spec should hold afterwards.
        """
        component_spec = pipeline_spec_pb2.ComponentSpec(
            input_definitions=pipeline_spec_pb2.ComponentInputsSpec(
                parameters={
                    'input1':
                    pipeline_spec_pb2.ComponentInputsSpec.ParameterSpec(
                        parameter_type=parameter_type)
                }))
        pipeline_spec_builder._fill_in_component_input_default_value(
            component_spec=component_spec,
            input_name='input1',
            default_value=default_value)

        self.assertEqual(
            expected,
            component_spec.input_definitions.parameters['input1'].
            default_value,
        )

    def testStruct(self):
        """Exercise the dict-like and list-like API of ``Struct`` / ``ListValue``.

        Covers item assignment, nested struct/list creation, binary and text
        round trips, in-place list mutation, empty containers and deletion.
        """
        # ``collections.Mapping`` / ``collections.Sequence`` were removed in
        # Python 3.10; the abstract base classes live in ``collections.abc``.
        from collections import abc as collections_abc

        struct = struct_pb2.Struct()
        self.assertIsInstance(struct, collections_abc.Mapping)
        self.assertEqual(0, len(struct))
        struct_class = struct.__class__

        # Populate one entry of every supported value kind.
        struct['key1'] = 5
        struct['key2'] = 'abc'
        struct['key3'] = True
        struct.get_or_create_struct('key4')['subkey'] = 11.0
        struct_list = struct.get_or_create_list('key5')
        self.assertIsInstance(struct_list, collections_abc.Sequence)
        struct_list.extend([6, 'seven', True, False, None])
        struct_list.add_struct()['subkey2'] = 9
        struct['key6'] = {'subkey': {}}
        struct['key7'] = [2, False]

        self.assertEqual(7, len(struct))
        self.assertIsInstance(struct, well_known_types.Struct)
        self.assertEqual(5, struct['key1'])
        self.assertEqual('abc', struct['key2'])
        self.assertIs(True, struct['key3'])
        self.assertEqual(11, struct['key4']['subkey'])
        inner_struct = struct_class()
        inner_struct['subkey2'] = 9
        self.assertEqual([6, 'seven', True, False, None, inner_struct],
                         list(struct['key5'].items()))
        self.assertEqual({}, dict(struct['key6']['subkey'].fields))
        self.assertEqual([2, False], list(struct['key7'].items()))

        # Binary round trip must yield an equal message.
        serialized = struct.SerializeToString()
        struct2 = struct_pb2.Struct()
        struct2.ParseFromString(serialized)

        self.assertEqual(struct, struct2)
        for key, value in struct.items():
            self.assertIn(key, struct)
            self.assertIn(key, struct2)
            self.assertEqual(value, struct2[key])

        self.assertEqual(7, len(struct.keys()))
        self.assertEqual(7, len(struct.values()))
        for key in struct.keys():
            self.assertIn(key, struct)
            self.assertIn(key, struct2)
            self.assertEqual(struct[key], struct2[key])

        # keys(), values() and items() must agree on their first element.
        item = (next(iter(struct.keys())), next(iter(struct.values())))
        self.assertEqual(item, next(iter(struct.items())))

        self.assertIsInstance(struct2, well_known_types.Struct)
        self.assertEqual(5, struct2['key1'])
        self.assertEqual('abc', struct2['key2'])
        self.assertIs(True, struct2['key3'])
        self.assertEqual(11, struct2['key4']['subkey'])
        self.assertEqual([6, 'seven', True, False, None, inner_struct],
                         list(struct2['key5'].items()))

        # Index-based access and mutation on the parsed copy's list.
        struct_list = struct2['key5']
        self.assertEqual(6, struct_list[0])
        self.assertEqual('seven', struct_list[1])
        self.assertEqual(True, struct_list[2])
        self.assertEqual(False, struct_list[3])
        self.assertIsNone(struct_list[4])
        self.assertEqual(inner_struct, struct_list[5])

        struct_list[1] = 7
        self.assertEqual(7, struct_list[1])

        struct_list.add_list().extend([1, 'two', True, False, None])
        self.assertEqual([1, 'two', True, False, None],
                         list(struct_list[6].items()))
        struct_list.extend([{
            'nested_struct': 30
        }, ['nested_list', 99], {}, []])
        self.assertEqual(11, len(struct_list.values))
        self.assertEqual(30, struct_list[7]['nested_struct'])
        self.assertEqual('nested_list', struct_list[8][0])
        self.assertEqual(99, struct_list[8][1])
        self.assertEqual({}, dict(struct_list[9].fields))
        self.assertEqual([], list(struct_list[10].items()))
        struct_list[0] = {'replace': 'set'}
        struct_list[1] = ['replace', 'set']
        self.assertEqual('set', struct_list[0]['replace'])
        self.assertEqual(['replace', 'set'], list(struct_list[1].items()))

        # Text-format round trip of the original struct.
        text_serialized = str(struct)
        struct3 = struct_pb2.Struct()
        text_format.Merge(text_serialized, struct3)
        self.assertEqual(struct, struct3)

        # 'key3' previously held a bool; requesting a struct replaces it.
        struct.get_or_create_struct('key3')['replace'] = 12
        self.assertEqual(12, struct['key3']['replace'])

        # Tests empty list.
        struct.get_or_create_list('empty_list')
        empty_list = struct['empty_list']
        self.assertEqual([], list(empty_list.items()))
        list2 = struct_pb2.ListValue()
        list2.add_list()
        empty_list = list2[0]
        self.assertEqual([], list(empty_list.items()))

        # Tests empty struct.
        struct.get_or_create_struct('empty_struct')
        empty_struct = struct['empty_struct']
        self.assertEqual({}, dict(empty_struct.fields))
        list2.add_struct()
        empty_struct = list2[1]
        self.assertEqual({}, dict(empty_struct.fields))

        # Deletion from both the struct and a nested list.
        self.assertEqual(9, len(struct))
        del struct['key3']
        del struct['key4']
        self.assertEqual(7, len(struct))
        self.assertEqual(6, len(struct['key5']))
        del struct['key5'][1]
        self.assertEqual(5, len(struct['key5']))
        self.assertEqual([6, True, False, None, inner_struct],
                         list(struct['key5'].items()))
Exemplo n.º 3
0
 def discrete_domain(values):
     """Return a new ``struct_pb2.ListValue`` extended with ``values``."""
     list_value = struct_pb2.ListValue()
     list_value.extend(values)
     return list_value
Exemplo n.º 4
0
def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Outputs three `Summary` protocol buffers needed by hparams plugin.

    `Experiment` keeps the metadata of an experiment, such as the name of the
      hyperparameters and the name of the metrics.
    `SessionStartInfo` keeps key-value pairs of the hyperparameters
    `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS

    Args:
      hparam_dict: A dictionary that contains names of the hyperparameters
        and their values. Entries whose value is ``None`` are skipped.
      metric_dict: A dictionary that contains names of the metrics
        and their values. Only the names are used here.
      hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
        contains names of the hyperparameters and all discrete values they can hold

    Returns:
      The `Summary` protobufs for Experiment, SessionStartInfo and
        SessionEndInfo

    Raises:
      TypeError: If ``hparam_dict``, ``metric_dict`` or
        ``hparam_domain_discrete`` is not a dictionary, or if an entry of
        ``hparam_domain_discrete`` is not a list whose items have the same
        type as the matching ``hparam_dict`` value.
      ValueError: If a hyperparameter value is not an int, float, str, bool
        or ``torch.Tensor``.
    """
    import torch
    from tensorboard.plugins.hparams.api_pb2 import (
        Experiment, HParamInfo, MetricInfo, MetricName, Status, DataType
    )
    from tensorboard.plugins.hparams.metadata import (
        PLUGIN_NAME,
        PLUGIN_DATA_VERSION,
        EXPERIMENT_TAG,
        SESSION_START_INFO_TAG,
        SESSION_END_INFO_TAG
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData, SessionEndInfo, SessionStartInfo
    )

    # TODO: expose other parameters in the future.
    # hp = HParamInfo(name='lr',display_name='learning rate',
    # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
    # max_value=100))
    # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
    # description='', dataset_type=DatasetType.DATASET_VALIDATION)
    # exp = Experiment(name='123', description='456', time_created_secs=100.0,
    # hparam_infos=[hp], metric_infos=[mt], user='******')

    if not isinstance(hparam_dict, dict):
        logging.warning('parameter: hparam_dict should be a dictionary, nothing logged.')
        raise TypeError('parameter: hparam_dict should be a dictionary, nothing logged.')
    if not isinstance(metric_dict, dict):
        logging.warning('parameter: metric_dict should be a dictionary, nothing logged.')
        raise TypeError('parameter: metric_dict should be a dictionary, nothing logged.')

    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                "parameter: hparam_domain_discrete[{}] should be a list of same type as "
                "hparam_dict[{}].".format(k, k)
            )

    def _discrete_domain(name, value_field):
        # Returns None when no domain was supplied for ``name`` so that
        # HParamInfo.domain_discrete is left unset.
        if name not in hparam_domain_discrete:
            return None
        return struct_pb2.ListValue(
            values=[
                struct_pb2.Value(**{value_field: d})
                for d in hparam_domain_discrete[name]
            ]
        )

    def _as_summary(tag, plugin_data):
        # Wraps serialized hparams plugin data in a single-value Summary.
        metadata = SummaryMetadata(
            plugin_data=SummaryMetadata.PluginData(
                plugin_name=PLUGIN_NAME,
                content=plugin_data.SerializeToString()
            )
        )
        return Summary(value=[Summary.Value(tag=tag, metadata=metadata)])

    hps = []
    ssi = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue

        if isinstance(v, torch.Tensor):
            # Tensors are logged as plain numbers and never get a domain.
            v = make_np(v)[0]
            ssi.hparams[k].number_value = v
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue

        # bool must be tested before int/float: bool is a subclass of int, so
        # checking numbers first would record True/False as 1.0/0.0.
        if isinstance(v, bool):
            value_field, data_type = "bool_value", "DATA_TYPE_BOOL"
        elif isinstance(v, (int, float)):
            value_field, data_type = "number_value", "DATA_TYPE_FLOAT64"
        elif isinstance(v, str):
            value_field, data_type = "string_value", "DATA_TYPE_STRING"
        else:
            raise ValueError('value should be one of int, float, str, bool, or torch.Tensor')

        setattr(ssi.hparams[k], value_field, v)
        hps.append(
            HParamInfo(
                name=k,
                type=DataType.Value(data_type),
                domain_discrete=_discrete_domain(k, value_field),
            )
        )

    session_start_summary = _as_summary(
        SESSION_START_INFO_TAG,
        HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION),
    )

    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
    experiment_summary = _as_summary(
        EXPERIMENT_TAG,
        HParamsPluginData(
            experiment=Experiment(hparam_infos=hps, metric_infos=mts),
            version=PLUGIN_DATA_VERSION,
        ),
    )

    sei = SessionEndInfo(status=Status.Value('STATUS_SUCCESS'))
    session_end_summary = _as_summary(
        SESSION_END_INFO_TAG,
        HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION),
    )

    return experiment_summary, session_start_summary, session_end_summary
Exemplo n.º 5
0
 def test_experiment_pb(self):
     """Check that ``summary.experiment_pb`` wraps an Experiment as expected.

     Builds two hparam infos (one with a discrete string domain, one with a
     float interval) and two metric infos, then compares the produced
     summary against a hand-built ``tf.compat.v1.Summary``.
     """
     discrete_values = struct_pb2.ListValue(
         values=[
             struct_pb2.Value(string_value="a"),
             struct_pb2.Value(string_value="b"),
         ]
     )
     string_param = api_pb2.HParamInfo(
         name="param1",
         display_name="display_name1",
         description="foo",
         type=api_pb2.DATA_TYPE_STRING,
         domain_discrete=discrete_values,
     )
     float_param = api_pb2.HParamInfo(
         name="param2",
         display_name="display_name2",
         description="bar",
         type=api_pb2.DATA_TYPE_FLOAT64,
         domain_interval=api_pb2.Interval(min_value=-100.0, max_value=100.0),
     )
     infos = [string_param, float_param]

     metrics = [
         api_pb2.MetricInfo(
             name=api_pb2.MetricName(tag="loss"),
             dataset_type=api_pb2.DATASET_VALIDATION,
         ),
         api_pb2.MetricInfo(
             name=api_pb2.MetricName(group="train/", tag="acc"),
             dataset_type=api_pb2.DATASET_TRAINING,
         ),
     ]
     created_at = 314159.0

     # Call the code under test before building the expected value, matching
     # the evaluation order of a single inline assertEqual call.
     actual = summary.experiment_pb(
         infos, metrics, time_created_secs=created_at
     )

     expected_content = plugin_data_pb2.HParamsPluginData(
         version=0,
         experiment=api_pb2.Experiment(
             time_created_secs=created_at,
             hparam_infos=infos,
             metric_infos=metrics,
         ),
     ).SerializeToString()
     expected_metadata = tf.compat.v1.SummaryMetadata(
         plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
             plugin_name="hparams",
             content=expected_content,
         )
     )
     expected = tf.compat.v1.Summary(
         value=[
             tf.compat.v1.Summary.Value(
                 tag="_hparams_/experiment",
                 tensor=summary._TF_NULL_TENSOR,
                 metadata=expected_metadata,
             )
         ]
     )

     self.assertEqual(actual, expected)