Beispiel #1
0
    def testToFromProto(self):
        """Checks that Params survive a ToProto()/FromProto() round trip.

        Covers scalars, a type, nested Params, nested containers, empty
        containers, a range, an enum and an embedded proto message.
        """
        outer = _params.Params()
        outer.Define('integer_val', 1, '')
        outer.Define('cls_type', type(int), '')
        inner = _params.Params()
        inner.Define('float_val', 2.71, '')
        inner.Define('string_val', 'rosalie et adrien', '')
        inner.Define('bool_val', True, '')
        # `({...},)` is a one-element tuple; without the trailing comma the
        # parentheses are only grouping and the element would be a bare dict,
        # so the tuple serialization path would never be exercised.
        inner.Define('list_of_tuples_of_dicts', [({'string_key': 1729},)],
                     '')
        inner.Define('range', range(1, 3), '')
        outer.Define('inner', inner, '')
        outer.Define('empty_list', [], '')
        outer.Define('empty_tuple', (), '')
        outer.Define('empty_dict', {}, '')
        outer.Define('enum', TestEnum.B, '')
        outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')

        rebuilt_outer = _params.InstantiableParams.FromProto(outer.ToProto())

        self.assertEqual(outer.integer_val, rebuilt_outer.integer_val)
        self.assertEqual(outer.cls_type, rebuilt_outer.cls_type)
        self.assertNear(outer.inner.float_val, rebuilt_outer.inner.float_val,
                        1e-6)
        self.assertEqual(outer.inner.string_val,
                         rebuilt_outer.inner.string_val)
        self.assertEqual(outer.inner.bool_val, rebuilt_outer.inner.bool_val)
        self.assertEqual(outer.inner.list_of_tuples_of_dicts,
                         rebuilt_outer.inner.list_of_tuples_of_dicts)
        self.assertEqual([1, 2], rebuilt_outer.inner.range)  # Rebuilt as list.
        self.assertEqual(outer.empty_list, rebuilt_outer.empty_list)
        self.assertEqual(outer.empty_tuple, rebuilt_outer.empty_tuple)
        self.assertEqual(outer.empty_dict, rebuilt_outer.empty_dict)
        self.assertEqual(outer.enum, rebuilt_outer.enum)
        self.assertEqual(outer.proto, rebuilt_outer.proto)
Beispiel #2
0
 def _ToParamValue(val):
   """Serializes a single Python value to a HyperparamValue proto.

   Args:
     val: A Params, list, tuple, dict, type, tf.DType, str, bool, integer,
       float, or None.

   Returns:
     A hyperparams_pb2.HyperparamValue with the oneof member matching the type
     of `val` populated. None is encoded by leaving the oneof unset.

   Raises:
     AttributeError: If `val` is of any other type.
   """
   param_pb = hyperparams_pb2.HyperparamValue()
   if isinstance(val, Params):
     param_pb.param_val.CopyFrom(_ToParam(val))
   elif isinstance(val, list):
     # SetInParent() marks list_val as set even when `val` is empty. Without
     # it an empty list leaves the oneof unset, which is the encoding of None.
     param_pb.list_val.SetInParent()
     for v in val:
       param_pb.list_val.items.append(_ToParamValue(v))
   elif isinstance(val, tuple):
     # Same as above: keep an empty tuple distinguishable from None.
     param_pb.tuple_val.SetInParent()
     for v in val:
       param_pb.tuple_val.items.append(_ToParamValue(v))
   elif isinstance(val, dict):
     # Same as above: keep an empty dict distinguishable from None.
     param_pb.dict_val.SetInParent()
     for k, v in val.items():
       param_pb.dict_val.items[k].CopyFrom(_ToParamValue(v))
   elif isinstance(val, type):
     param_pb.type_val = inspect.getmodule(val).__name__ + '/' + val.__name__
   elif isinstance(val, tf.DType):
     param_pb.dtype_val = val.name
   elif isinstance(val, str):
     param_pb.string_val = val
   # bool is a subclass of int, so it must be tested before integer types.
   elif isinstance(val, bool):
     param_pb.bool_val = val
   elif isinstance(val, six.integer_types):
     param_pb.int_val = val
   elif isinstance(val, float):
     param_pb.float_val = val
   elif val is None:
     # We represent a NoneType by the absence of any of the oneof.
     pass
   else:
     raise AttributeError('Unsupported type: %s' % type(val))
   return param_pb
Beispiel #3
0
 def _ToParamValue(key: str,
                   val: Any) -> hyperparams_pb2.HyperparamValue:
     """Serializes a single Python value to a HyperparamValue proto.

     Args:
       key: Name identifying `val`. It is forwarded as `prefix` when `val` is
         a nested Params, and extended with `[i]` / `[k]` for the elements of
         containers so that nested values get their own names.
       val: The value to serialize.

     Returns:
       A hyperparams_pb2.HyperparamValue with the oneof member matching the
       type of `val` populated. None leaves the oneof unset, and values of any
       unrecognized type are stored as their repr() in `string_repr_val`.
     """
     param_pb = hyperparams_pb2.HyperparamValue()
     if isinstance(val, Params):
         param_pb.param_val.CopyFrom(_ToParam(val, prefix=key))
     elif isinstance(val, (list, range)):
         # A range is serialized as the explicit list of its elements.
         # SetInParent() marks list_val as present even for an empty
         # sequence, mirroring the dict branch below.
         param_pb.list_val.SetInParent()
         param_pb.list_val.items.extend([
             _ToParamValue(f'{key}[{i}]', v) for i, v in enumerate(val)
         ])
     # Dataclass instances and namedtuples keep their type. This branch must
     # precede the plain tuple check because namedtuples are tuples. Dataclass
     # *classes* are excluded because dataclasses.is_dataclass() is also True
     # for them; they are handled by the `type` branch below.
     elif ((dataclasses.is_dataclass(val) and not isinstance(val, type)) or
           _IsNamedTuple(val)):
         val_cls = type(val)
         items = (val.__dict__.items() if dataclasses.is_dataclass(val) else
                  val._asdict().items())
         param_pb.named_tuple_val.type = (
             inspect.getmodule(val_cls).__name__ + '/' + val_cls.__name__)
         param_pb.named_tuple_val.items.extend(
             [_ToParamValue(f'{key}[{k}]', v) for k, v in items])
     elif isinstance(val, tuple):
         # Keep an empty tuple distinguishable from None.
         param_pb.tuple_val.SetInParent()
         param_pb.tuple_val.items.extend([
             _ToParamValue(f'{key}[{i}]', v) for i, v in enumerate(val)
         ])
     # Only dicts where all keys are str can be stored as dict_val; other dicts
     # fall through to the repr() fallback at the end.
     elif isinstance(val, dict) and all(isinstance(k, str) for k in val):
         param_pb.dict_val.SetInParent()
         for k, v in val.items():
             param_pb.dict_val.items[k].CopyFrom(
                 _ToParamValue(f'{key}[{k}]', v))
     elif isinstance(val, type):
         param_pb.type_val = inspect.getmodule(
             val).__name__ + '/' + val.__name__
     elif isinstance(val, tf.DType):
         param_pb.dtype_val = val.name
     elif isinstance(val, str):
         param_pb.string_val = val
     # bool is a subclass of int, so it must be tested before int.
     elif isinstance(val, bool):
         param_pb.bool_val = val
     elif isinstance(val, int):
         param_pb.int_val = val
     elif isinstance(val, float):
         param_pb.float_val = val
     elif isinstance(val, enum.Enum):
         enum_cls = type(val)
         param_pb.enum_val.type = inspect.getmodule(
             enum_cls).__name__ + '/' + enum_cls.__name__
         param_pb.enum_val.name = val.name
     elif isinstance(val, message.Message):
         proto_cls = type(val)
         param_pb.proto_val.type = inspect.getmodule(
             proto_cls).__name__ + '/' + proto_cls.__name__
         param_pb.proto_val.val = val.SerializeToString()
     elif val is None:
         # We represent a NoneType by the absence of any of the oneof.
         pass
     else:
         param_pb.string_repr_val = repr(val)
     return param_pb
Beispiel #4
0
 def _ToParamValue(val):
     """Serializes a single Python value to a HyperparamValue proto.

     Args:
       val: The value to serialize.

     Returns:
       A hyperparams_pb2.HyperparamValue whose oneof member matches the type
       of `val`. None is encoded by leaving the oneof unset.

     Raises:
       AttributeError: If `val` has a type with no proto encoding.
     """
     param_pb = hyperparams_pb2.HyperparamValue()
     if isinstance(val, Params):
         param_pb.param_val.CopyFrom(_ToParam(val))
     elif isinstance(val, (list, range)):
         # A range is serialized as the explicit list of its elements.
         # SetInParent() marks list_val as present even for an empty
         # sequence, mirroring the dict branch below.
         param_pb.list_val.SetInParent()
         param_pb.list_val.items.extend([_ToParamValue(v) for v in val])
     # dataclasses.is_dataclass() is also True for a dataclass *class*; such
     # classes must fall through to the `type` branch below instead of having
     # their class __dict__ serialized. This branch also has to precede the
     # plain tuple check because namedtuples are tuples.
     elif ((dataclasses.is_dataclass(val) and not isinstance(val, type)) or
           _IsNamedTuple(val)):
         val_cls = type(val)
         vals = (val.__dict__.values() if dataclasses.is_dataclass(val) else
                 val._asdict().values())
         param_pb.named_tuple_val.type = (
             inspect.getmodule(val_cls).__name__ + '/' + val_cls.__name__)
         param_pb.named_tuple_val.items.extend(
             [_ToParamValue(v) for v in vals])
     elif isinstance(val, tuple):
         # Keep an empty tuple distinguishable from None.
         param_pb.tuple_val.SetInParent()
         param_pb.tuple_val.items.extend(
             [_ToParamValue(v) for v in val])
     elif isinstance(val, dict):
         param_pb.dict_val.SetInParent()
         for k, v in val.items():
             param_pb.dict_val.items[k].CopyFrom(_ToParamValue(v))
     elif isinstance(val, type):
         param_pb.type_val = inspect.getmodule(
             val).__name__ + '/' + val.__name__
     elif isinstance(val, tf.DType):
         param_pb.dtype_val = val.name
     elif isinstance(val, str):
         param_pb.string_val = val
     # bool is a subclass of int, so it must be tested before int.
     elif isinstance(val, bool):
         param_pb.bool_val = val
     elif isinstance(val, int):
         param_pb.int_val = val
     elif isinstance(val, float):
         param_pb.float_val = val
     elif isinstance(val, enum.Enum):
         enum_cls = type(val)
         param_pb.enum_val.type = inspect.getmodule(
             enum_cls).__name__ + '/' + enum_cls.__name__
         param_pb.enum_val.name = val.name
     elif isinstance(val, message.Message):
         proto_cls = type(val)
         param_pb.proto_val.type = inspect.getmodule(
             proto_cls).__name__ + '/' + proto_cls.__name__
         param_pb.proto_val.val = val.SerializeToString()
     elif val is None:
         # We represent a NoneType by the absence of any of the oneof.
         pass
     else:
         raise AttributeError('Unsupported type: %s for value %s' %
                              (type(val), val))
     return param_pb
Beispiel #5
0
    def testToFromProto(self):
        """Checks that Params survive a ToProto()/FromProto() round trip.

        Covers scalars, a type, nested Params, nested and empty containers, a
        range, an enum, a proto message, a dataclass, a namedtuple and
        symbolic expressions.
        """
        outer = hyperparams.Params()
        outer.Define('integer_val', 1, '')
        outer.Define('cls_type', type(int), '')
        inner = hyperparams.Params()
        inner.Define('float_val', 2.71, '')
        inner.Define('string_val', 'rosalie et adrien', '')
        inner.Define('bool_val', True, '')
        # `({...},)` is a one-element tuple; without the trailing comma the
        # parentheses are only grouping and the element would be a bare dict,
        # so the tuple serialization path would never be exercised.
        inner.Define('list_of_tuples_of_dicts', [({'string_key': 1729},)],
                     '')
        inner.Define('range', range(1, 3), '')
        outer.Define('inner', inner, '')
        outer.Define('empty_list', [], '')
        outer.Define('empty_tuple', (), '')
        outer.Define('empty_dict', {}, '')
        outer.Define('enum', TestEnum.B, '')
        outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')
        outer.Define('dataclass', TestDataClass(a=[42], b=tf.float32), '')
        outer.Define('namedtuple',
                     tf.io.FixedLenSequenceFeature([42], tf.float32), '')
        outer.Define('symbol_x', symbolic.Symbol('x'), '')
        outer.Define('symbol_2x', outer.symbol_x * 2, '')

        rebuilt_outer = hyperparams.InstantiableParams.FromProto(
            outer.ToProto())

        self.assertNotIn('cls', rebuilt_outer)
        self.assertEqual(outer.integer_val, rebuilt_outer.integer_val)
        self.assertEqual(outer.cls_type, rebuilt_outer.cls_type)
        self.assertNear(outer.inner.float_val, rebuilt_outer.inner.float_val,
                        1e-6)
        self.assertEqual(outer.inner.string_val,
                         rebuilt_outer.inner.string_val)
        self.assertEqual(outer.inner.bool_val, rebuilt_outer.inner.bool_val)
        self.assertEqual(outer.inner.list_of_tuples_of_dicts,
                         rebuilt_outer.inner.list_of_tuples_of_dicts)
        self.assertEqual([1, 2], rebuilt_outer.inner.range)  # Rebuilt as list.
        self.assertEqual(outer.empty_list, rebuilt_outer.empty_list)
        self.assertEqual(outer.empty_tuple, rebuilt_outer.empty_tuple)
        self.assertEqual(outer.empty_dict, rebuilt_outer.empty_dict)
        self.assertEqual(outer.enum, rebuilt_outer.enum)
        self.assertEqual(outer.proto, rebuilt_outer.proto)
        self.assertEqual(outer.dataclass, rebuilt_outer.dataclass)
        self.assertEqual(outer.namedtuple, rebuilt_outer.namedtuple)

        # The symbolic expression must still evaluate once `x` is bound.
        with symbolic.SymbolToValueMap(symbolic.STATIC_VALUES,
                                       {rebuilt_outer.symbol_x: 42}):
            self.assertEqual(symbolic.ToStatic(rebuilt_outer.symbol_2x), 84)
Beispiel #6
0
 def _ToParamValue(val):
     """Encodes one Python value as a hyperparams_pb2.HyperparamValue.

     The branch taken depends on the type of `val`. Exactly one member of the
     proto's oneof is filled in, except for None, which is encoded by leaving
     the oneof empty.

     Raises:
       AttributeError: When `val` has a type with no proto encoding.
     """
     result = hyperparams_pb2.HyperparamValue()

     if isinstance(val, Params):
         result.param_val.CopyFrom(_ToParam(val))
         return result

     if isinstance(val, (list, range)):
         # A range is encoded as the explicit list of its elements.
         # SetInParent() marks list_val as present even for an empty input.
         result.list_val.SetInParent()
         result.list_val.items.extend([_ToParamValue(item) for item in val])
         return result

     if isinstance(val, tuple):
         result.tuple_val.SetInParent()
         result.tuple_val.items.extend([_ToParamValue(item) for item in val])
         return result

     if isinstance(val, dict):
         result.dict_val.SetInParent()
         for name, item in val.items():
             result.dict_val.items[name].CopyFrom(_ToParamValue(item))
         return result

     if isinstance(val, type):
         result.type_val = '%s/%s' % (inspect.getmodule(val).__name__,
                                      val.__name__)
         return result

     if isinstance(val, tf.DType):
         result.dtype_val = val.name
         return result

     if isinstance(val, str):
         result.string_val = val
         return result

     # bool has to be tested before the integer types because bool
     # subclasses int.
     if isinstance(val, bool):
         result.bool_val = val
         return result

     if isinstance(val, six.integer_types):
         result.int_val = val
         return result

     if isinstance(val, float):
         result.float_val = val
         return result

     if isinstance(val, message.Message):
         proto_cls = type(val)
         result.proto_val.SetInParent()
         result.proto_val.type = '%s/%s' % (
             inspect.getmodule(proto_cls).__name__, proto_cls.__name__)
         result.proto_val.val = val.SerializeToString()
         return result

     if val is None:
         # None is represented by leaving every oneof member unset.
         return result

     raise AttributeError('Unsupported type: %s' % type(val))
Beispiel #7
0
  def testToText(self):
    outer = hyperparams.Params()
    outer.Define('foo', 1, '')
    inner = hyperparams.Params()
    inner.Define('bar', 2.71, '')
    inner.Define('baz', 'hello', '')
    outer.Define('inner', inner, '')
    outer.Define('tau', False, '')
    outer.Define('dtype', tf.float32, '')
    outer.Define('dtype2', tf.int32, '')
    outer.Define('seqlen', [10, inner, 30], '')
    outer.Define('tuple', (1, None), '')
    outer.Define('list_of_params', [inner.Copy()], '')
    outer.Define('class', TestClass1, '')
    outer.Define('plain_dict', {'a': 10}, '')
    outer.Define('complex_dict', {'a': 10, 'b': inner}, '')
    outer.Define('complex_dict_escape', {'a': 'abc"\'\ndef'}, '')
    outer.Define('some_class', complex(0, 1), '')
    outer.Define('optional_bool', None, '')
    outer.Define('enum', TestEnum.B, '')
    outer.Define('dataclass', TestDataClass(a=[42], b=tf.float32), '')
    outer.Define('namedtuple', TestNamedTuple([42], tf.float32), '')
    outer.Define('namedtuple2', tf.io.FixedLenSequenceFeature([42], tf.float32),
                 '')
    # Arbitrarily use HyperparameterValue as some example proto.
    outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')

    self.assertEqual(
        '\n' + outer.ToText(), r"""
class : type/__main__/TestClass1
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'hello'}}
complex_dict_escape : {'a': 'abc"\'\ndef'}
dataclass : {'a': [42], 'b': 'float32'}
dtype : float32
dtype2 : int32
enum : TestEnum.B
foo : 1
inner.bar : 2.71
inner.baz : 'hello'
list_of_params[0].bar : 2.71
list_of_params[0].baz : 'hello'
namedtuple : {'a': [42], 'b': 'float32'}
namedtuple2 : {'allow_missing': False, 'default_value': 'NoneType', 'dtype': 'float32', 'shape': [42]}
optional_bool : NoneType
plain_dict : {'a': 10}
proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/int_val: 42
seqlen : [10, {'bar': 2.71, 'baz': 'hello'}, 30]
some_class : complex
tau : False
tuple : (1, 'NoneType')
""")

    outer.FromText("""
        dataclass : {'a': 27, 'b': 'int32'}
        dtype2 : float32
        inner.baz : 'world'
        # foo : 123
        optional_bool : true
        list_of_params[0].bar : 2.72
        seqlen : [1, 2.0, '3', [4]]
        plain_dict : {'x': 0.3}
        class : type/__main__/TestClass2
        tau : true
        tuple : (2, 3)
        enum : TestEnum.A
        # Note dtypes and other non-POD are represented as strings.
        namedtuple : {'a': 27, 'b': 'int32'}
        namedtuple2 : {'allow_missing': True, 'default_value': 'NoneType', 'dtype': 'int32', 'shape': [43]}
        proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/string_val: "a/b"
        """)

    # Note that the 'hello' has turned into 'world'!
    self.assertEqual(
        '\n' + outer.ToText(), r"""
class : type/__main__/TestClass2
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'world'}}
complex_dict_escape : {'a': 'abc"\'\ndef'}
dataclass : {'a': 27, 'b': 'int32'}
dtype : float32
dtype2 : float32
enum : TestEnum.A
foo : 1
inner.bar : 2.71
inner.baz : 'world'
list_of_params[0].bar : 2.72
list_of_params[0].baz : 'hello'
namedtuple : {'a': 27, 'b': 'int32'}
namedtuple2 : {'allow_missing': True, 'default_value': 'NoneType', 'dtype': 'int32', 'shape': [43]}
optional_bool : True
plain_dict : {'x': 0.3}
proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/string_val: "a/b"
seqlen : [1, 2.0, '3', [4]]
some_class : complex
tau : True
tuple : (2, 3)
""")
    self.assertEqual(outer.dataclass.b, tf.int32)
    self.assertEqual(outer.namedtuple.b, tf.int32)
    self.assertEqual(outer.namedtuple2.dtype, tf.int32)
    self.assertIsNone(outer.namedtuple2.default_value, tf.int32)
Beispiel #8
0
    def testToText(self):
        outer = _params.Params()
        outer.Define('foo', 1, '')
        inner = _params.Params()
        inner.Define('bar', 2.71, '')
        inner.Define('baz', 'hello', '')
        outer.Define('inner', inner, '')
        outer.Define('tau', False, '')
        outer.Define('dtype', tf.float32, '')
        outer.Define('dtype2', tf.int32, '')
        outer.Define('seqlen', [10, inner, 30], '')
        outer.Define('tuple', (1, None), '')
        outer.Define('list_of_params', [inner.Copy()], '')
        outer.Define('class', TestClass1, '')
        outer.Define('plain_dict', {'a': 10}, '')
        outer.Define('complex_dict', {'a': 10, 'b': inner}, '')
        outer.Define('complex_dict_escape', {'a': 'abc"\'\ndef'}, '')
        outer.Define('some_class', complex(0, 1), '')
        outer.Define('optional_bool', None, '')
        outer.Define('enum', TestEnum.B, '')
        # Arbitrarily use HyperparameterValue as some example proto.
        outer.Define('proto', hyperparams_pb2.HyperparamValue(int_val=42), '')

        self.assertEqual(
            '\n' + outer.ToText(), r"""
class : type/__main__/TestClass1
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'hello'}}
complex_dict_escape : {'a': 'abc"\'\ndef'}
dtype : float32
dtype2 : int32
enum : TestEnum.B
foo : 1
inner.bar : 2.71
inner.baz : 'hello'
list_of_params[0].bar : 2.71
list_of_params[0].baz : 'hello'
optional_bool : NoneType
plain_dict : {'a': 10}
proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/int_val: 42
seqlen : [10, {'bar': 2.71, 'baz': 'hello'}, 30]
some_class : complex
tau : False
tuple : (1, 'NoneType')
""")

        outer.FromText("""
        dtype2 : float32
        inner.baz : 'world'
        # foo : 123
        optional_bool : true
        list_of_params[0].bar : 2.72
        seqlen : [1, 2.0, '3', [4]]
        plain_dict : {'x': 0.3}
        class : type/__main__/TestClass2
        tau : true
        tuple : (2, 3)
        enum : TestEnum.A
        proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/string_val: "a/b"
        """)

        # Note that the 'hello' has turned into 'world'!
        self.assertEqual(
            '\n' + outer.ToText(), r"""
class : type/__main__/TestClass2
complex_dict : {'a': 10, 'b': {'bar': 2.71, 'baz': 'world'}}
complex_dict_escape : {'a': 'abc"\'\ndef'}
dtype : float32
dtype2 : float32
enum : TestEnum.A
foo : 1
inner.bar : 2.71
inner.baz : 'world'
list_of_params[0].bar : 2.72
list_of_params[0].baz : 'hello'
optional_bool : True
plain_dict : {'x': 0.3}
proto : proto/lingvo.core.hyperparams_pb2/HyperparamValue/string_val: "a/b"
seqlen : [1, 2.0, '3', [4]]
some_class : complex
tau : True
tuple : (2, 3)
""")