def test_serialization(self):
  """Tests that Policy serialize/deserialize round-trips for every form.

  Covers four cases: plain single-dtype policies (serialized to a bare
  string), the special '_infer' policy (serialized to None), policies with
  the default loss scale (serialized to a class_name/config dict), and
  policies with an explicitly overridden loss scale (config dict that also
  carries a nested loss-scale config).
  """
  # Test policies that are equivalent to a single dtype.
  # These serialize to the plain dtype string rather than a config dict.
  for policy_name in 'float16', 'float32', 'int8', 'string', 'bool':
    policy = mp_policy.Policy(policy_name)
    config = mp_policy.serialize(policy)
    self.assertEqual(config, policy_name)
    new_policy = mp_policy.deserialize(config)
    self.assertEqual(str(policy), str(new_policy))

  # Test "_infer" policy, which serializes to None.
  policy = mp_policy.Policy('_infer')
  config = mp_policy.serialize(policy)
  self.assertIsNone(config)
  new_policy = mp_policy.deserialize(config)
  self.assertEqual(str(policy), str(new_policy))

  # Subclass used to check that custom Policy subclasses round-trip via
  # the custom_objects mechanism.
  class MyPolicy(mp_policy.Policy):
    pass

  # Test policies that do not override the loss scale: these serialize to a
  # {'class_name': ..., 'config': {'name': ...}} dict with no loss_scale key.
  for policy in (
      mp_policy.Policy('mixed_float16'),
      mp_policy.Policy('mixed_bfloat16'),
      MyPolicy('float32')
  ):
    config = mp_policy.serialize(policy)
    self.assertEqual(config, {'class_name': policy.__class__.__name__,
                              'config': {'name': policy.name}})
    new_policy = mp_policy.deserialize(
        config, custom_objects={'MyPolicy': MyPolicy})
    self.assertEqual(str(policy), str(new_policy))

  # Test policies that override the loss scale: the serialized config must
  # additionally carry the loss-scale config (None when explicitly disabled).
  for policy in (
      mp_policy.Policy('float32', loss_scale=2.),
      mp_policy.Policy('float32', loss_scale=None),
      mp_policy.Policy('mixed_float16', loss_scale=2.),
      mp_policy.Policy('mixed_float16', loss_scale=None),
      mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
      mp_policy.Policy('mixed_bfloat16', loss_scale=None),
  ):
    config = mp_policy.serialize(policy)
    # loss_scale=2. serializes to a FixedLossScale config; loss_scale=None
    # serializes to None.
    expected_loss_scale_config = None
    if policy.loss_scale:
      expected_loss_scale_config = {
          'class_name': 'FixedLossScale',
          'config': {'loss_scale_value': 2.}
      }
    self.assertEqual(
        config, {
            'class_name': policy.__class__.__name__,
            'config': {
                'name': policy.name,
                'loss_scale': expected_loss_scale_config
            }
        })
    new_policy = mp_policy.deserialize(
        config, custom_objects={'MyPolicy': MyPolicy})
    self.assertEqual(str(policy), str(new_policy))
def _python_properties_internal(self):
  """Returns dictionary of all python properties."""
  # TODO(kathywu): Add support for metrics serialization.
  # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
  # the python config serialization has caught up.
  obj = self.obj
  metadata = {
      'class_name': type(obj).__name__,
      'name': obj.name,
      'trainable': obj.trainable,
      'expects_training_arg': obj._expects_training_arg,  # pylint: disable=protected-access
      'dtype': policy.serialize(obj._dtype_policy),  # pylint: disable=protected-access
      'batch_input_shape': getattr(obj, '_batch_input_shape', None),
      'stateful': obj.stateful,
  }
  # Merge in the object's own config entries.
  metadata.update(get_config(obj))

  if obj.input_spec is not None:
    # Layer's input_spec has already been type-checked in the property setter.
    def _serialize_spec(spec):
      return generic_utils.serialize_keras_object(spec) if spec else None

    metadata['input_spec'] = nest.map_structure(_serialize_spec,
                                                obj.input_spec)

  regularizer = obj.activity_regularizer
  if regularizer is not None and hasattr(regularizer, 'get_config'):
    metadata['activity_regularizer'] = (
        generic_utils.serialize_keras_object(regularizer))

  if obj._build_input_shape is not None:  # pylint: disable=protected-access
    metadata['build_input_shape'] = obj._build_input_shape  # pylint: disable=protected-access
  return metadata
def _python_properties_internal(self):
  """Returns dictionary of all python properties."""
  # TODO(kathywu): Add support for metrics serialization.
  # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
  # the python config serialization has caught up.
  obj = self.obj
  metadata = {
      'class_name': type(obj).__name__,
      'name': obj.name,
      'trainable': obj.trainable,
      'expects_training_arg': obj._expects_training_arg,  # pylint: disable=protected-access
      'dtype': policy.serialize(obj._dtype_policy),  # pylint: disable=protected-access
      'batch_input_shape': getattr(obj, '_batch_input_shape', None),
  }

  with generic_utils.skip_failed_serialization():
    # Store the config dictionary, which may be used when reviving the object.
    # When loading, the program will attempt to revive the object from config,
    # and if that fails, the object will be revived from the SavedModel.
    config = generic_utils.serialize_keras_object(obj)['config']
  if config is not None:
    metadata['config'] = config

  if obj.input_spec is not None:
    # Layer's input_spec has already been type-checked in the property setter.
    def _serialize_spec(spec):
      return generic_utils.serialize_keras_object(spec) if spec else None

    metadata['input_spec'] = nest.map_structure(_serialize_spec,
                                                obj.input_spec)

  regularizer = obj.activity_regularizer
  if regularizer is not None and hasattr(regularizer, 'get_config'):
    metadata['activity_regularizer'] = (
        generic_utils.serialize_keras_object(regularizer))
  return metadata
def _python_properties_internal(self):
  """Returns dictionary of all python properties."""
  # TODO(kathywu): Add support for metrics serialization.
  # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
  # the python config serialization has caught up.
  obj = self.obj
  metadata = {
      'class_name': type(obj).__name__,
      'name': obj.name,
      'trainable': obj.trainable,
      'expects_training_arg': obj._expects_training_arg,  # pylint: disable=protected-access
      'dtype': policy.serialize(obj._dtype_policy),  # pylint: disable=protected-access
      'batch_input_shape': getattr(obj, '_batch_input_shape', None),
  }

  try:
    # Store the config dictionary, which is only used by the revived object
    # to return the original config when revived_obj.get_config() is called.
    # It is not important for recreating the revived object.
    metadata['config'] = obj.get_config()
  except NotImplementedError:
    # in the case of a subclassed model, the get_config() method will throw
    # a NotImplementedError.
    pass

  if obj.input_spec is not None:
    # Layer's input_spec has already been type-checked in the property setter.
    def _serialize_spec(spec):
      if spec is None:
        return None
      return serialize_keras_object(spec)

    metadata['input_spec'] = nest.map_structure(_serialize_spec,
                                                obj.input_spec)

  regularizer = obj.activity_regularizer
  if regularizer is not None and hasattr(regularizer, 'get_config'):
    metadata['activity_regularizer'] = serialize_keras_object(regularizer)
  return metadata