def _python_properties_internal(self):
        """Returns a dictionary of the Python-level properties of `self.obj`.

        The base entries (name, trainable, expects_training_arg, dtype,
        batch_input_shape, stateful, must_restore_from_config) are always
        present; `dtype` is produced by `policy.serialize` on the layer's
        dtype policy, and `batch_input_shape` defaults to None when the layer
        has no `_batch_input_shape` attribute.

        The mapping returned by `get_serialized(self.obj)` is then merged in
        with `dict.update`, so it overrides any base key it shares.

        Finally, these optional entries are added only when available:
          * `input_spec`: when the layer's input_spec is not None.
          * `activity_regularizer`: when the regularizer is not None and has
            a `get_config` attribute.
          * `build_input_shape`: when `_build_input_shape` is not None.
        """
        # TODO(kathywu): Add support for metrics serialization.
        # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
        # the python config serialization has caught up.
        layer = self.obj

        # pylint: disable=protected-access
        metadata = {
            'name': layer.name,
            'trainable': layer.trainable,
            'expects_training_arg': layer._expects_training_arg,
            'dtype': policy.serialize(layer._dtype_policy),
            'batch_input_shape': getattr(layer, '_batch_input_shape', None),
            'stateful': layer.stateful,
            'must_restore_from_config': layer._must_restore_from_config,
        }
        # pylint: enable=protected-access

        metadata.update(get_serialized(layer))

        input_spec = layer.input_spec
        if input_spec is not None:
            def _serialize_spec_entry(entry):
                # Falsy entries (e.g. None) are kept as None rather than
                # being passed to the serializer.
                if not entry:
                    return None
                return generic_utils.serialize_keras_object(entry)

            # Layer's input_spec has already been type-checked in the property setter.
            metadata['input_spec'] = nest.map_structure(
                _serialize_spec_entry, input_spec)

        regularizer = layer.activity_regularizer
        if regularizer is not None and hasattr(regularizer, 'get_config'):
            metadata['activity_regularizer'] = (
                generic_utils.serialize_keras_object(regularizer))

        build_shape = layer._build_input_shape  # pylint: disable=protected-access
        if build_shape is not None:
            metadata['build_input_shape'] = build_shape

        return metadata
Example #2
0
    def test_serialization(self):
        """Checks mp_policy.serialize/deserialize for several kinds of policy.

        Covers policies equivalent to a single dtype (expected to serialize to
        the bare name), the "_infer" policy (expected to serialize to None),
        policies expected to serialize to a class_name/config dict (including
        a locally defined subclass), and PolicyV1 policies whose config is also
        expected to carry a loss_scale entry.
        """
        # Policies equivalent to a single dtype serialize to their name alone
        # and survive a round trip through deserialize.
        for dtype_name in ('float16', 'float32', 'int8', 'string', 'bool'):
            original = mp_policy.Policy(dtype_name)
            serialized = mp_policy.serialize(original)
            self.assertEqual(serialized, dtype_name)
            restored = mp_policy.deserialize(serialized)
            self.assertEqual(str(original), str(restored))

        # The "_infer" policy serializes to None and still round-trips.
        infer_policy = mp_policy.Policy('_infer')
        infer_config = mp_policy.serialize(infer_policy)
        self.assertIsNone(infer_config)
        restored_infer = mp_policy.deserialize(infer_config)
        self.assertEqual(str(infer_policy), str(restored_infer))

        class MyPolicy(mp_policy.Policy):
            pass

        # Policies that are not equivalent to a single dtype serialize to a
        # {'class_name': ..., 'config': {'name': ...}} dict.
        dict_serialized_policies = [
            mp_policy.Policy('mixed_float16'),
            mp_policy.Policy('mixed_bfloat16'),
            MyPolicy('float32'),
        ]
        for original in dict_serialized_policies:
            serialized = mp_policy.serialize(original)
            expected_config = {
                'class_name': type(original).__name__,
                'config': {'name': original.name},
            }
            self.assertEqual(serialized, expected_config)
            # custom_objects maps the name 'MyPolicy' to the local subclass.
            restored = mp_policy.deserialize(
                serialized, custom_objects={'MyPolicy': MyPolicy})
            self.assertEqual(str(original), str(restored))

        # Test V1 policies that override the loss scale: each dtype is paired
        # with loss_scale=2. and loss_scale=None, in that order.
        v1_policies = [
            mp_policy.PolicyV1(dtype_name, loss_scale=scale)
            for dtype_name in ('float32', 'mixed_float16', 'mixed_bfloat16')
            for scale in (2., None)
        ]
        for policy in v1_policies:
            config = mp_policy.serialize(policy)
            # The expected loss_scale config is None whenever
            # policy.loss_scale is falsy.
            if policy.loss_scale:
                expected_loss_scale_config = {
                    'class_name': 'FixedLossScale',
                    'config': {'loss_scale_value': 2.},
                }
            else:
                expected_loss_scale_config = None
            self.assertEqual(
                config,
                {
                    'class_name': type(policy).__name__,
                    'config': {
                        'name': policy.name,
                        'loss_scale': expected_loss_scale_config,
                    },
                })