def test_serialization(self):
  """Round-trips mixed precision policies through serialize/deserialize.

  Three families are covered:
    * policies equivalent to a single dtype, which serialize to the bare
      dtype name;
    * the '_infer' policy, which serializes to None;
    * policies that need the full {'class_name', 'config'} dict form,
      including a locally defined Policy subclass.
  In every case the deserialized policy must print the same as the original.
  """
  # Single-dtype policies: the serialized form is exactly the dtype name.
  single_dtype_names = ('float16', 'float32', 'int8', 'string', 'bool')
  for dtype_name in single_dtype_names:
    original = mp_policy.Policy(dtype_name)
    serialized = mp_policy.serialize(original)
    self.assertEqual(serialized, dtype_name)
    restored = mp_policy.deserialize(serialized)
    self.assertEqual(str(original), str(restored))

  # The "_infer" policy serializes to None and must come back from None.
  infer_policy = mp_policy.Policy('_infer')
  serialized = mp_policy.serialize(infer_policy)
  self.assertIsNone(serialized)
  restored = mp_policy.deserialize(serialized)
  self.assertEqual(str(infer_policy), str(restored))

  class MyPolicy(mp_policy.Policy):
    pass

  # Policies that are not equivalent to a single dtype use the dict form;
  # the custom subclass is resolved through custom_objects on the way back.
  dict_form_policies = (
      mp_policy.Policy('mixed_float16'),
      mp_policy.Policy('mixed_bfloat16'),
      MyPolicy('float32'),
  )
  for original in dict_form_policies:
    serialized = mp_policy.serialize(original)
    expected = {
        'class_name': original.__class__.__name__,
        'config': {'name': original.name},
    }
    self.assertEqual(serialized, expected)
    restored = mp_policy.deserialize(
        serialized, custom_objects={'MyPolicy': MyPolicy})
    self.assertEqual(str(original), str(restored))
def _python_properties_internal(self):
  """Returns dictionary of all python properties.

  The dict starts with a fixed set of attributes read from `self.obj`:
  the registered class name, name, trainable flag, training-arg flag,
  serialized dtype policy, batch input shape, stateful flag and
  must-restore flag. The mapping returned by `get_config(self.obj)` is
  merged over those keys. Optional `input_spec`, `activity_regularizer` and
  `build_input_shape` entries are then added only when the layer has them.
  """
  # TODO(kathywu): Add support for metrics serialization.
  # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
  # the python config serialization has caught up.
  layer = self.obj
  # pylint: disable=protected-access
  metadata = {
      'class_name': generic_utils.get_registered_name(type(layer)),
      'name': layer.name,
      'trainable': layer.trainable,
      'expects_training_arg': layer._expects_training_arg,
      'dtype': policy.serialize(layer._dtype_policy),
      'batch_input_shape': getattr(layer, '_batch_input_shape', None),
      'stateful': layer.stateful,
      'must_restore_from_config': layer._must_restore_from_config,
  }
  # pylint: enable=protected-access

  # Keys produced by get_config() replace any same-named keys set above.
  metadata.update(get_config(self.obj))

  def serialize_spec(spec):
    # Falsy leaves (e.g. None) in the input_spec structure stay None.
    if not spec:
      return None
    return generic_utils.serialize_keras_object(spec)

  input_spec = layer.input_spec
  if input_spec is not None:
    # Layer's input_spec has already been type-checked in the property setter.
    metadata['input_spec'] = tf.nest.map_structure(serialize_spec, input_spec)

  # Only regularizers that expose get_config() are serialized.
  regularizer = layer.activity_regularizer
  if regularizer is not None and hasattr(regularizer, 'get_config'):
    metadata['activity_regularizer'] = (
        generic_utils.serialize_keras_object(regularizer))

  build_input_shape = layer._build_input_shape  # pylint: disable=protected-access
  if build_input_shape is not None:
    metadata['build_input_shape'] = build_input_shape
  return metadata
def _python_properties_internal(self):
    """Returns dictionary of all python properties.

    The dict starts with a fixed set of attributes read from `self.obj`:
    name, trainable flag, training-arg flag, serialized dtype policy, batch
    input shape, stateful flag, must-restore flag and the
    preserve-input-structure flag. The mapping returned by
    `get_serialized(self.obj)` is merged over those keys. Optional
    `input_spec`, `activity_regularizer` and `build_input_shape` entries are
    then added only when the layer has them.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec)
    # once the python config serialization has caught up.
    layer = self.obj
    metadata = {
        "name": layer.name,
        "trainable": layer.trainable,
        "expects_training_arg": layer._expects_training_arg,
        "dtype": policy.serialize(layer._dtype_policy),
        "batch_input_shape": getattr(layer, "_batch_input_shape", None),
        "stateful": layer.stateful,
        "must_restore_from_config": layer._must_restore_from_config,
        "preserve_input_structure_in_config": (
            layer._preserve_input_structure_in_config
        ),
    }
    # Keys produced by get_serialized() replace same-named keys set above.
    metadata.update(get_serialized(self.obj))

    def serialize_spec(spec):
        # Falsy leaves (e.g. None) in the input_spec structure stay None.
        if not spec:
            return None
        return generic_utils.serialize_keras_object(spec)

    input_spec = layer.input_spec
    if input_spec is not None:
        # Layer's input_spec has already been type-checked in the property
        # setter.
        metadata["input_spec"] = tf.nest.map_structure(
            serialize_spec, input_spec
        )

    # Only regularizers that expose get_config() are serialized.
    regularizer = layer.activity_regularizer
    if regularizer is not None and hasattr(regularizer, "get_config"):
        metadata["activity_regularizer"] = (
            generic_utils.serialize_keras_object(regularizer)
        )

    build_input_shape = layer._build_input_shape
    if build_input_shape is not None:
        metadata["build_input_shape"] = build_input_shape
    return metadata
def test_serialization(self):
    """Round-trips mixed precision policies through serialize/deserialize.

    Single-dtype policies must serialize to the dtype name, the "_infer"
    policy to None, and all other policies (including a locally defined
    subclass) to a {"class_name", "config"} dict. Every deserialized policy
    must print the same as its original.
    """

    def check_round_trip(original, serialized, **deserialize_kwargs):
        # Deserialize and compare string forms with the original policy.
        restored = mp_policy.deserialize(serialized, **deserialize_kwargs)
        self.assertEqual(str(original), str(restored))

    # Policies equivalent to a single dtype serialize to that dtype's name.
    for dtype_name in ("float16", "float32", "int8", "string", "bool"):
        original = mp_policy.Policy(dtype_name)
        serialized = mp_policy.serialize(original)
        self.assertEqual(serialized, dtype_name)
        check_round_trip(original, serialized)

    # The "_infer" policy serializes to None.
    infer_policy = mp_policy.Policy("_infer")
    serialized = mp_policy.serialize(infer_policy)
    self.assertIsNone(serialized)
    check_round_trip(infer_policy, serialized)

    class MyPolicy(mp_policy.Policy):
        pass

    # Everything else needs the dict form; MyPolicy is resolved through
    # custom_objects when deserializing.
    dict_form_policies = (
        mp_policy.Policy("mixed_float16"),
        mp_policy.Policy("mixed_bfloat16"),
        MyPolicy("float32"),
    )
    for original in dict_form_policies:
        serialized = mp_policy.serialize(original)
        expected = {
            "class_name": original.__class__.__name__,
            "config": {"name": original.name},
        }
        self.assertEqual(serialized, expected)
        check_round_trip(
            original, serialized, custom_objects={"MyPolicy": MyPolicy}
        )
def test_serialization(self):
  """Checks serialization of Policy and PolicyV1 objects.

  Policy objects are round-tripped through serialize/deserialize:
    * single-dtype policies serialize to the bare dtype name;
    * the '_infer' policy serializes to None;
    * other policies (including a local subclass) serialize to a
      {'class_name', 'config'} dict.
  PolicyV1 objects are only checked on the serialize side. Their config also
  carries a 'loss_scale' entry, which is a FixedLossScale config when a loss
  scale of 2. was given and None otherwise.
  """
  # Single-dtype policies: the serialized form is exactly the dtype name.
  for dtype_name in ('float16', 'float32', 'int8', 'string', 'bool'):
    original = mp_policy.Policy(dtype_name)
    serialized = mp_policy.serialize(original)
    self.assertEqual(serialized, dtype_name)
    restored = mp_policy.deserialize(serialized)
    self.assertEqual(str(original), str(restored))

  # The "_infer" policy serializes to None and must come back from None.
  infer_policy = mp_policy.Policy('_infer')
  serialized = mp_policy.serialize(infer_policy)
  self.assertIsNone(serialized)
  restored = mp_policy.deserialize(serialized)
  self.assertEqual(str(infer_policy), str(restored))

  class MyPolicy(mp_policy.Policy):
    pass

  # Policies that are not equivalent to a single dtype use the dict form.
  dict_form_policies = (
      mp_policy.Policy('mixed_float16'),
      mp_policy.Policy('mixed_bfloat16'),
      MyPolicy('float32'),
  )
  for original in dict_form_policies:
    serialized = mp_policy.serialize(original)
    self.assertEqual(
        serialized, {
            'class_name': original.__class__.__name__,
            'config': {'name': original.name},
        })
    restored = mp_policy.deserialize(
        serialized, custom_objects={'MyPolicy': MyPolicy})
    self.assertEqual(str(original), str(restored))

  # V1 policies that override the loss scale. Each dtype name is paired with
  # loss_scale=2. and then loss_scale=None, in that order.
  v1_policies = [
      mp_policy.PolicyV1(dtype_name, loss_scale=scale)
      for dtype_name in ('float32', 'mixed_float16', 'mixed_bfloat16')
      for scale in (2., None)
  ]
  for v1_policy in v1_policies:
    serialized = mp_policy.serialize(v1_policy)
    expected_loss_scale_config = (
        {'class_name': 'FixedLossScale', 'config': {'loss_scale_value': 2.}}
        if v1_policy.loss_scale else None)
    self.assertEqual(
        serialized, {
            'class_name': v1_policy.__class__.__name__,
            'config': {
                'name': v1_policy.name,
                'loss_scale': expected_loss_scale_config,
            },
        })