def testToFromProto(self):
  """Checks that Params survive a ToProto() / FromProto() round trip."""
  nested = _params.Params()
  nested_defaults = (
      ('float_val', 2.71),
      ('string_val', 'rosalie et adrien'),
      ('bool_val', True),
      # NOTE: `({...})` is a parenthesized dict, not a 1-tuple.
      ('list_of_tuples_of_dicts', [({'string_key': 1729})]),
      ('range', range(1, 3)),
  )
  for name, default in nested_defaults:
    nested.Define(name, default, '')

  outer = _params.Params()
  outer_defaults = (
      ('integer_val', 1),
      ('cls_type', type(int)),
      ('inner', nested),
      ('empty_list', []),
      ('empty_tuple', ()),
      ('empty_dict', {}),
      ('enum', TestEnum.B),
      ('proto', hyperparams_pb2.HyperparamValue(int_val=42)),
  )
  for name, default in outer_defaults:
    outer.Define(name, default, '')

  restored = _params.InstantiableParams.FromProto(outer.ToProto())

  for name in ('integer_val', 'cls_type'):
    self.assertEqual(getattr(outer, name), getattr(restored, name))

  # Floats are compared with a tolerance rather than for exact equality.
  self.assertNear(outer.inner.float_val, restored.inner.float_val, 1e-6)
  for name in ('string_val', 'bool_val', 'list_of_tuples_of_dicts'):
    self.assertEqual(
        getattr(outer.inner, name), getattr(restored.inner, name))
  # A range is rebuilt as a list.
  self.assertEqual([1, 2], restored.inner.range)

  for name in ('empty_list', 'empty_tuple', 'empty_dict', 'enum', 'proto'):
    self.assertEqual(getattr(outer, name), getattr(restored, name))
def _ToParamValue(val):
  """Serializes to HyperparamValue proto.

  Args:
    val: The value to serialize. Supported types are Params, list, tuple,
      dict, type, tf.DType, str, bool, integer, float and None.

  Returns:
    A HyperparamValue whose oneof field matching the type of `val` is set.
    None is represented by leaving every oneof field unset.

  Raises:
    AttributeError: If `val` is of an unsupported type.
  """
  param_pb = hyperparams_pb2.HyperparamValue()
  if isinstance(val, Params):
    param_pb.param_val.CopyFrom(_ToParam(val))
  elif isinstance(val, list):
    # Mark the field as present even when there are no items; otherwise an
    # empty list would be indistinguishable from None (no oneof field set).
    param_pb.list_val.SetInParent()
    param_pb.list_val.items.extend([_ToParamValue(v) for v in val])
  elif isinstance(val, tuple):
    # Same presence marking as for lists, so that () is not read as None.
    param_pb.tuple_val.SetInParent()
    param_pb.tuple_val.items.extend([_ToParamValue(v) for v in val])
  elif isinstance(val, dict):
    # Same presence marking as for lists, so that {} is not read as None.
    param_pb.dict_val.SetInParent()
    for k, v in val.items():
      param_pb.dict_val.items[k].CopyFrom(_ToParamValue(v))
  elif isinstance(val, type):
    # Types are stored as '<module name>/<class name>'.
    param_pb.type_val = inspect.getmodule(val).__name__ + '/' + val.__name__
  elif isinstance(val, tf.DType):
    param_pb.dtype_val = val.name
  elif isinstance(val, str):
    param_pb.string_val = val
  elif isinstance(val, bool):
    # Must be tested before the integer branch: bool is a subclass of int.
    param_pb.bool_val = val
  elif isinstance(val, six.integer_types):
    param_pb.int_val = val
  elif isinstance(val, float):
    param_pb.float_val = val
  elif val is None:
    # We represent a NoneType by the absence of any of the oneof.
    pass
  else:
    raise AttributeError('Unsupported type: %s' % type(val))
  return param_pb
def _ToParamValue(key: str, val: Any) -> hyperparams_pb2.HyperparamValue:
  """Serializes to HyperparamValue proto.

  Args:
    key: Path of `val` inside the enclosing Params. It is extended with
      `[i]` / `[k]` for nested elements and passed as `prefix` to `_ToParam`
      for nested Params.
    val: The value to serialize.

  Returns:
    A HyperparamValue whose oneof field matching the type of `val` is set.
    None is represented by leaving every oneof field unset. Values of any
    other unsupported type are stored as `repr(val)` in `string_repr_val`.
  """
  param_pb = hyperparams_pb2.HyperparamValue()
  # dataclasses.is_dataclass() is also True for dataclass *classes*. Only
  # instances carry field values; classes must fall through to `type_val`.
  is_dataclass_instance = (
      dataclasses.is_dataclass(val) and not isinstance(val, type))
  if isinstance(val, Params):
    param_pb.param_val.CopyFrom(_ToParam(val, prefix=key))
  elif isinstance(val, (list, range)):
    # A range is serialized as the list of values it produces.
    param_pb.list_val.items.extend(
        [_ToParamValue(f'{key}[{i}]', v) for i, v in enumerate(val)])
  elif is_dataclass_instance or _IsNamedTuple(val):
    # This branch must precede the plain tuple branch: named tuples are
    # tuples too and would otherwise lose their type and field names.
    val_cls = type(val)
    items = (
        val.__dict__.items()
        if is_dataclass_instance else val._asdict().items())
    param_pb.named_tuple_val.type = (
        inspect.getmodule(val_cls).__name__ + '/' + val_cls.__name__)
    param_pb.named_tuple_val.items.extend(
        [_ToParamValue(f'{key}[{k}]', v) for k, v in items])
  elif isinstance(val, tuple):
    param_pb.tuple_val.items.extend(
        [_ToParamValue(f'{key}[{i}]', v) for i, v in enumerate(val)])
  # Only dicts where all keys are str can be stored as dict_val.
  elif isinstance(val, dict) and all(isinstance(k, str) for k in val):
    # Mark the field as present so that an empty dict is not read as None.
    param_pb.dict_val.SetInParent()
    for k, v in val.items():
      param_pb.dict_val.items[k].CopyFrom(_ToParamValue(f'{key}[{k}]', v))
  elif isinstance(val, type):
    # Types are stored as '<module name>/<class name>'.
    param_pb.type_val = inspect.getmodule(val).__name__ + '/' + val.__name__
  elif isinstance(val, tf.DType):
    param_pb.dtype_val = val.name
  elif isinstance(val, str):
    param_pb.string_val = val
  elif isinstance(val, bool):
    # Must be tested before the int branch: bool is a subclass of int.
    param_pb.bool_val = val
  elif isinstance(val, int):
    param_pb.int_val = val
  elif isinstance(val, float):
    param_pb.float_val = val
  elif isinstance(val, enum.Enum):
    enum_cls = type(val)
    param_pb.enum_val.type = (
        inspect.getmodule(enum_cls).__name__ + '/' + enum_cls.__name__)
    param_pb.enum_val.name = val.name
  elif isinstance(val, message.Message):
    proto_cls = type(val)
    param_pb.proto_val.type = (
        inspect.getmodule(proto_cls).__name__ + '/' + proto_cls.__name__)
    param_pb.proto_val.val = val.SerializeToString()
  elif val is None:
    # We represent a NoneType by the absence of any of the oneof.
    pass
  else:
    # Anything else (including dicts with non-str keys) is kept only as its
    # repr() string.
    param_pb.string_repr_val = repr(val)
  return param_pb
def _ToParamValue(val):
  """Serializes to HyperparamValue proto.

  Args:
    val: The value to serialize. Supported types are Params, list, range,
      dataclass instance, named tuple, tuple, dict, type, tf.DType, str,
      bool, int, float, enum.Enum, proto message and None.

  Returns:
    A HyperparamValue whose oneof field matching the type of `val` is set.
    None is represented by leaving every oneof field unset.

  Raises:
    AttributeError: If `val` is of an unsupported type.
  """
  param_pb = hyperparams_pb2.HyperparamValue()
  # dataclasses.is_dataclass() is also True for dataclass *classes*. Only
  # instances carry field values; classes must fall through to `type_val`.
  is_dataclass_instance = (
      dataclasses.is_dataclass(val) and not isinstance(val, type))
  if isinstance(val, Params):
    param_pb.param_val.CopyFrom(_ToParam(val))
  elif isinstance(val, (list, range)):
    # A range is serialized as the list of values it produces.
    param_pb.list_val.items.extend([_ToParamValue(v) for v in val])
  elif is_dataclass_instance or _IsNamedTuple(val):
    # Checked before the plain tuple branch because named tuples are tuples.
    val_cls = type(val)
    vals = (
        val.__dict__.values()
        if is_dataclass_instance else val._asdict().values())
    param_pb.named_tuple_val.type = (
        inspect.getmodule(val_cls).__name__ + '/' + val_cls.__name__)
    param_pb.named_tuple_val.items.extend([_ToParamValue(v) for v in vals])
  elif isinstance(val, tuple):
    param_pb.tuple_val.items.extend([_ToParamValue(v) for v in val])
  elif isinstance(val, dict):
    # Mark the field as present so that an empty dict is not read as None.
    param_pb.dict_val.SetInParent()
    for k, v in val.items():
      param_pb.dict_val.items[k].CopyFrom(_ToParamValue(v))
  elif isinstance(val, type):
    # Types are stored as '<module name>/<class name>'.
    param_pb.type_val = inspect.getmodule(val).__name__ + '/' + val.__name__
  elif isinstance(val, tf.DType):
    param_pb.dtype_val = val.name
  elif isinstance(val, str):
    param_pb.string_val = val
  elif isinstance(val, bool):
    # Must be tested before the int branch: bool is a subclass of int.
    param_pb.bool_val = val
  elif isinstance(val, int):
    param_pb.int_val = val
  elif isinstance(val, float):
    param_pb.float_val = val
  elif isinstance(val, enum.Enum):
    enum_cls = type(val)
    param_pb.enum_val.type = (
        inspect.getmodule(enum_cls).__name__ + '/' + enum_cls.__name__)
    param_pb.enum_val.name = val.name
  elif isinstance(val, message.Message):
    proto_cls = type(val)
    param_pb.proto_val.type = (
        inspect.getmodule(proto_cls).__name__ + '/' + proto_cls.__name__)
    param_pb.proto_val.val = val.SerializeToString()
  elif val is None:
    # We represent a NoneType by the absence of any of the oneof.
    pass
  else:
    raise AttributeError('Unsupported type: %s for value %s' %
                         (type(val), val))
  return param_pb
def testToFromProto(self):
  """Checks that Params survive a ToProto() / FromProto() round trip."""

  def _AssertFieldsMatch(lhs, rhs, names):
    # Compares the named attributes of two Params pairwise, in order.
    for name in names:
      self.assertEqual(getattr(lhs, name), getattr(rhs, name))

  nested = hyperparams.Params()
  nested.Define('float_val', 2.71, '')
  nested.Define('string_val', 'rosalie et adrien', '')
  nested.Define('bool_val', True, '')
  # NOTE: `({...})` is a parenthesized dict, not a 1-tuple.
  nested.Define('list_of_tuples_of_dicts', [({'string_key': 1729})], '')
  nested.Define('range', range(1, 3), '')

  outer = hyperparams.Params()
  outer.Define('integer_val', 1, '')
  outer.Define('cls_type', type(int), '')
  outer.Define('inner', nested, '')
  outer.Define('empty_list', [], '')
  outer.Define('empty_tuple', (), '')
  outer.Define('empty_dict', {}, '')
  outer.Define('enum', TestEnum.B, '')
  outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')
  outer.Define('dataclass', TestDataClass(a=[42], b=tf.float32), '')
  outer.Define('namedtuple',
               tf.io.FixedLenSequenceFeature([42], tf.float32), '')
  outer.Define('symbol_x', symbolic.Symbol('x'), '')
  # symbol_2x is an expression built from the symbol defined just above.
  outer.Define('symbol_2x', outer.symbol_x * 2, '')

  serialized = outer.ToProto()
  restored = hyperparams.InstantiableParams.FromProto(serialized)

  self.assertNotIn('cls', restored)
  _AssertFieldsMatch(outer, restored, ('integer_val', 'cls_type'))
  # Floats are compared with a tolerance rather than for exact equality.
  self.assertNear(outer.inner.float_val, restored.inner.float_val, 1e-6)
  _AssertFieldsMatch(
      outer.inner, restored.inner,
      ('string_val', 'bool_val', 'list_of_tuples_of_dicts'))
  # A range is rebuilt as a list.
  self.assertEqual([1, 2], restored.inner.range)
  _AssertFieldsMatch(
      outer, restored,
      ('empty_list', 'empty_tuple', 'empty_dict', 'enum', 'proto',
       'dataclass', 'namedtuple'))

  # With x bound to 42, the rebuilt `x * 2` expression must evaluate to 84.
  symbol_values = {restored.symbol_x: 42}
  with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES, symbol_values):
    self.assertEqual(symbolic.ToStatic(restored.symbol_2x), 84)
def _ToParamValue(val):
  """Converts one hyperparameter value into a HyperparamValue proto.

  Args:
    val: The value to convert. Supported types are Params, list, range,
      tuple, dict, type, tf.DType, str, bool, integer, float, proto message
      and None.

  Returns:
    A HyperparamValue whose oneof field matching the type of `val` is set.
    None is represented by leaving every oneof field unset.

  Raises:
    AttributeError: If `val` is of an unsupported type.
  """

  def _QualifiedName(cls):
    # Classes are identified as '<module name>/<class name>'.
    return '/'.join((inspect.getmodule(cls).__name__, cls.__name__))

  result = hyperparams_pb2.HyperparamValue()

  if val is None:
    # We represent a NoneType by the absence of any of the oneof.
    return result

  if isinstance(val, Params):
    result.param_val.CopyFrom(_ToParam(val))
    return result

  if isinstance(val, (list, range)):
    # Copying in an empty message marks the field as set even for an empty
    # sequence. A range is serialized as the list of values it produces.
    result.list_val.CopyFrom(hyperparams_pb2.HyperparamRepeated())
    result.list_val.items.extend([_ToParamValue(item) for item in val])
    return result

  if isinstance(val, tuple):
    result.tuple_val.CopyFrom(hyperparams_pb2.HyperparamRepeated())
    result.tuple_val.items.extend([_ToParamValue(item) for item in val])
    return result

  if isinstance(val, dict):
    result.dict_val.CopyFrom(hyperparams_pb2.Hyperparam())
    for name, item in val.items():
      result.dict_val.items[name].CopyFrom(_ToParamValue(item))
    return result

  if isinstance(val, type):
    result.type_val = _QualifiedName(val)
  elif isinstance(val, tf.DType):
    result.dtype_val = val.name
  elif isinstance(val, str):
    result.string_val = val
  elif isinstance(val, bool):
    # Must be tested before the integer branch: bool is a subclass of int.
    result.bool_val = val
  elif isinstance(val, six.integer_types):
    result.int_val = val
  elif isinstance(val, float):
    result.float_val = val
  elif isinstance(val, message.Message):
    result.proto_val.CopyFrom(hyperparams_pb2.ProtoVal())
    result.proto_val.type = _QualifiedName(type(val))
    result.proto_val.val = val.SerializeToString()
  else:
    raise AttributeError('Unsupported type: %s' % type(val))
  return result
def testToText(self):
  """Checks ToText() output and that FromText() overrides are applied."""
  outer = hyperparams.Params()
  outer.Define('foo', 1, '')
  inner = hyperparams.Params()
  inner.Define('bar', 2.71, '')
  inner.Define('baz', 'hello', '')
  outer.Define('inner', inner, '')
  outer.Define('tau', False, '')
  outer.Define('dtype', tf.float32, '')
  outer.Define('dtype2', tf.int32, '')
  outer.Define('seqlen', [10, inner, 30], '')
  outer.Define('tuple', (1, None), '')
  outer.Define('list_of_params', [inner.Copy()], '')
  outer.Define('class', TestClass1, '')
  outer.Define('plain_dict', {'a': 10}, '')
  outer.Define('complex_dict', {'a': 10, 'b': inner}, '')
  outer.Define('complex_dict_escape', {'a': 'abc"\'\ndef'}, '')
  outer.Define('some_class', complex(0, 1), '')
  outer.Define('optional_bool', None, '')
  outer.Define('enum', TestEnum.B, '')
  outer.Define('dataclass', TestDataClass(a=[42], b=tf.float32), '')
  outer.Define('namedtuple', TestNamedTuple([42], tf.float32), '')
  outer.Define('namedtuple2',
               tf.io.FixedLenSequenceFeature([42], tf.float32), '')
  # Arbitrarily use HyperparameterValue as some example proto.
  outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')

  # The expected text starts with a newline, hence the '\n' prefix.
  self.assertEqual(
      '\n' + outer.ToText(), r"""
class : type/__main__/TestClass1
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'hello'}}
complex_dict_escape : {'a': 'abc"\'\ndef'}
dataclass : {'a': [42], 'b': 'float32'}
dtype : float32
dtype2 : int32
enum : TestEnum.B
foo : 1
inner.bar : 2.71
inner.baz : 'hello'
list_of_params[0].bar : 2.71
list_of_params[0].baz : 'hello'
namedtuple : {'a': [42], 'b': 'float32'}
namedtuple2 : {'allow_missing': False, 'default_value': 'NoneType', 'dtype': 'float32', 'shape': [42]}
optional_bool : NoneType
plain_dict : {'a': 10}
proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/int_val: 42
seqlen : [10, {'bar': 2.71, 'baz': 'hello'}, 30]
some_class : complex
tau : False
tuple : (1, 'NoneType')
""")

  outer.FromText("""
      dataclass : {'a': 27, 'b': 'int32'}
      dtype2 : float32
      inner.baz : 'world'
      # foo : 123
      optional_bool : true
      list_of_params[0].bar : 2.72
      seqlen : [1, 2.0, '3', [4]]
      plain_dict : {'x': 0.3}
      class : type/__main__/TestClass2
      tau : true
      tuple : (2, 3)
      enum : TestEnum.A
      # Note dtypes and other non-POD are represented as strings.
      namedtuple : {'a': 27, 'b': 'int32'}
      namedtuple2 : {'allow_missing': True, 'default_value': 'NoneType', 'dtype': 'int32', 'shape': [43]}
      proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/string_val: "a/b"
      """)

  # Note that the 'hello' has turned into 'world'!
  self.assertEqual(
      '\n' + outer.ToText(), r"""
class : type/__main__/TestClass2
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'world'}}
complex_dict_escape : {'a': 'abc"\'\ndef'}
dataclass : {'a': 27, 'b': 'int32'}
dtype : float32
dtype2 : float32
enum : TestEnum.A
foo : 1
inner.bar : 2.71
inner.baz : 'world'
list_of_params[0].bar : 2.72
list_of_params[0].baz : 'hello'
namedtuple : {'a': 27, 'b': 'int32'}
namedtuple2 : {'allow_missing': True, 'default_value': 'NoneType', 'dtype': 'int32', 'shape': [43]}
optional_bool : True
plain_dict : {'x': 0.3}
proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/string_val: "a/b"
seqlen : [1, 2.0, '3', [4]]
some_class : complex
tau : True
tuple : (2, 3)
""")

  # The dtype names given as strings in the text must come back as tf.DType.
  self.assertEqual(outer.dataclass.b, tf.int32)
  self.assertEqual(outer.namedtuple.b, tf.int32)
  self.assertEqual(outer.namedtuple2.dtype, tf.int32)
  # assertIsNone takes only the value under test; a second positional
  # argument would be used as the failure message, not compared against.
  self.assertIsNone(outer.namedtuple2.default_value)
def testToText(self):
  """Checks ToText() output and that FromText() overrides are applied."""
  expected_initial_text = r"""
class : type/__main__/TestClass1
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'hello'}}
complex_dict_escape : {'a': 'abc"\'\ndef'}
dtype : float32
dtype2 : int32
enum : TestEnum.B
foo : 1
inner.bar : 2.71
inner.baz : 'hello'
list_of_params[0].bar : 2.71
list_of_params[0].baz : 'hello'
optional_bool : NoneType
plain_dict : {'a': 10}
proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/int_val: 42
seqlen : [10, {'bar': 2.71, 'baz': 'hello'}, 30]
some_class : complex
tau : False
tuple : (1, 'NoneType')
"""
  overrides_text = """
      dtype2 : float32
      inner.baz : 'world'
      # foo : 123
      optional_bool : true
      list_of_params[0].bar : 2.72
      seqlen : [1, 2.0, '3', [4]]
      plain_dict : {'x': 0.3}
      class : type/__main__/TestClass2
      tau : true
      tuple : (2, 3)
      enum : TestEnum.A
      proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/string_val: "a/b"
      """
  expected_updated_text = r"""
class : type/__main__/TestClass2
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'world'}}
complex_dict_escape : {'a': 'abc"\'\ndef'}
dtype : float32
dtype2 : float32
enum : TestEnum.A
foo : 1
inner.bar : 2.71
inner.baz : 'world'
list_of_params[0].bar : 2.72
list_of_params[0].baz : 'hello'
optional_bool : True
plain_dict : {'x': 0.3}
proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/string_val: "a/b"
seqlen : [1, 2.0, '3', [4]]
some_class : complex
tau : True
tuple : (2, 3)
"""

  params = _params.Params()
  params.Define('foo', 1, '')
  nested = _params.Params()
  nested.Define('bar', 2.71, '')
  nested.Define('baz', 'hello', '')
  params.Define('inner', nested, '')
  params.Define('tau', False, '')
  params.Define('dtype', tf.float32, '')
  params.Define('dtype2', tf.int32, '')
  params.Define('seqlen', [10, nested, 30], '')
  params.Define('tuple', (1, None), '')
  params.Define('list_of_params', [nested.Copy()], '')
  params.Define('class', TestClass1, '')
  params.Define('plain_dict', {'a': 10}, '')
  params.Define('complex_dict', {'a': 10, 'b': nested}, '')
  params.Define('complex_dict_escape', {'a': 'abc"\'\ndef'}, '')
  params.Define('some_class', complex(0, 1), '')
  params.Define('optional_bool', None, '')
  params.Define('enum', TestEnum.B, '')
  # Arbitrarily use HyperparameterValue as some example proto.
  params.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')

  # The expected text starts with a newline, hence the '\n' prefix.
  initial_text = '\n' + params.ToText()
  self.assertEqual(initial_text, expected_initial_text)

  params.FromText(overrides_text)

  # Note that the 'hello' has turned into 'world'!
  updated_text = '\n' + params.ToText()
  self.assertEqual(updated_text, expected_updated_text)