Beispiel #1
0
  def test_serialization(self):
    """Round-trips mixed precision policies through serialize/deserialize.

    Three groups are checked:
      * policies equivalent to a single dtype, which serialize to the bare
        dtype name;
      * the "_infer" policy, which serializes to None;
      * policies that need a full {'class_name', 'config'} dict (the mixed
        policies and a user-defined Policy subclass).
    Each serialized form must deserialize back to a policy whose str() matches
    the original.
    """
    # Single-dtype policies collapse to their dtype name when serialized.
    single_dtype_names = ('float16', 'float32', 'int8', 'string', 'bool')
    for dtype_name in single_dtype_names:
      original = mp_policy.Policy(dtype_name)
      serialized = mp_policy.serialize(original)
      self.assertEqual(serialized, dtype_name)
      restored = mp_policy.deserialize(serialized)
      self.assertEqual(str(original), str(restored))

    # "_infer" serializes to None and must still deserialize cleanly.
    infer_policy = mp_policy.Policy('_infer')
    serialized = mp_policy.serialize(infer_policy)
    self.assertIsNone(serialized)
    restored = mp_policy.deserialize(serialized)
    self.assertEqual(str(infer_policy), str(restored))

    class MyPolicy(mp_policy.Policy):
      pass

    # Everything else serializes to a class_name/config dict; the custom
    # subclass is resolved at load time through custom_objects.
    non_dtype_policies = [
        mp_policy.Policy('mixed_float16'),
        mp_policy.Policy('mixed_bfloat16'),
        MyPolicy('float32'),
    ]
    for original in non_dtype_policies:
      serialized = mp_policy.serialize(original)
      expected = {
          'class_name': type(original).__name__,
          'config': {'name': original.name},
      }
      self.assertEqual(serialized, expected)
      restored = mp_policy.deserialize(
          serialized, custom_objects={'MyPolicy': MyPolicy})
      self.assertEqual(str(original), str(restored))
Beispiel #2
0
    def _python_properties_internal(self):
        """Builds the metadata dict describing the wrapped layer's Python state.

        Returns:
            A dict built in three steps:
              1. Fixed attributes are collected: the registered class name,
                 name, trainable flag, training-arg flag, serialized dtype
                 policy, batch input shape, stateful flag and
                 must-restore flag.
              2. The result of `get_config(self.obj)` is merged on top.
              3. `input_spec`, `activity_regularizer` and
                 `build_input_shape` are added only when the layer provides
                 them.
        """
        # TODO(kathywu): Add support for metrics serialization.
        # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
        # the python config serialization has caught up.
        layer = self.obj

        def _serialize_spec(spec):
            # Falsy entries in the input_spec structure are recorded as None.
            if not spec:
                return None
            return generic_utils.serialize_keras_object(spec)

        metadata = {
            'class_name': generic_utils.get_registered_name(type(layer)),
            'name': layer.name,
            'trainable': layer.trainable,
            'expects_training_arg': layer._expects_training_arg,  # pylint: disable=protected-access
            'dtype': policy.serialize(layer._dtype_policy),  # pylint: disable=protected-access
            'batch_input_shape': getattr(layer, '_batch_input_shape', None),
            'stateful': layer.stateful,
            'must_restore_from_config': layer._must_restore_from_config,  # pylint: disable=protected-access
        }
        # dict.update: keys returned by get_config overwrite the ones above.
        metadata.update(get_config(layer))

        if layer.input_spec is not None:
            # Layer's input_spec has already been type-checked in the property setter.
            metadata['input_spec'] = tf.nest.map_structure(
                _serialize_spec, layer.input_spec)

        # Only regularizers that expose get_config can be serialized.
        regularizer_is_serializable = (
            layer.activity_regularizer is not None
            and hasattr(layer.activity_regularizer, 'get_config'))
        if regularizer_is_serializable:
            metadata['activity_regularizer'] = (
                generic_utils.serialize_keras_object(
                    layer.activity_regularizer))

        build_input_shape = layer._build_input_shape  # pylint: disable=protected-access
        if build_input_shape is not None:
            metadata['build_input_shape'] = build_input_shape
        return metadata
Beispiel #3
0
    def _python_properties_internal(self):
        """Builds the metadata dict describing the wrapped layer's Python state.

        Returns:
            A dict built in three steps:
              1. Fixed attributes are collected: name, trainable flag,
                 training-arg flag, serialized dtype policy, batch input
                 shape, stateful flag, must-restore flag and
                 preserve-input-structure flag.
              2. The result of `get_serialized(self.obj)` is merged on top.
              3. `input_spec`, `activity_regularizer` and
                 `build_input_shape` are added only when the layer provides
                 them.
        """
        # TODO(kathywu): Add support for metrics serialization.
        # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec)
        # once the python config serialization has caught up.
        layer = self.obj

        def _serialize_spec(spec):
            # Falsy entries in the input_spec structure are recorded as None.
            if not spec:
                return None
            return generic_utils.serialize_keras_object(spec)

        metadata = {
            "name": layer.name,
            "trainable": layer.trainable,
            "expects_training_arg": layer._expects_training_arg,
            "dtype": policy.serialize(layer._dtype_policy),
            "batch_input_shape": getattr(layer, "_batch_input_shape", None),
            "stateful": layer.stateful,
            "must_restore_from_config": layer._must_restore_from_config,
            "preserve_input_structure_in_config": (
                layer._preserve_input_structure_in_config
            ),
        }
        # dict.update: keys from get_serialized overwrite the ones above.
        metadata.update(get_serialized(layer))

        if layer.input_spec is not None:
            # Layer's input_spec has already been type-checked in the property
            # setter.
            metadata["input_spec"] = tf.nest.map_structure(
                _serialize_spec,
                layer.input_spec,
            )

        # Only regularizers that expose get_config can be serialized.
        regularizer_is_serializable = (
            layer.activity_regularizer is not None
            and hasattr(layer.activity_regularizer, "get_config")
        )
        if regularizer_is_serializable:
            metadata["activity_regularizer"] = (
                generic_utils.serialize_keras_object(
                    layer.activity_regularizer
                )
            )

        build_input_shape = layer._build_input_shape
        if build_input_shape is not None:
            metadata["build_input_shape"] = build_input_shape
        return metadata
Beispiel #4
0
    def test_serialization(self):
        """Round-trips mixed precision policies through serialize/deserialize.

        Three groups are checked:
          * policies equivalent to a single dtype, which serialize to the
            bare dtype name;
          * the "_infer" policy, which serializes to None;
          * policies that need a {"class_name", "config"} dict (the mixed
            policies and a user-defined Policy subclass).
        Each serialized form must deserialize back to a policy whose str()
        matches the original.
        """
        # Single-dtype policies collapse to their dtype name when serialized.
        for dtype_name in ("float16", "float32", "int8", "string", "bool"):
            original = mp_policy.Policy(dtype_name)
            serialized = mp_policy.serialize(original)
            self.assertEqual(serialized, dtype_name)
            restored = mp_policy.deserialize(serialized)
            self.assertEqual(str(original), str(restored))

        # "_infer" serializes to None and must still deserialize cleanly.
        infer_policy = mp_policy.Policy("_infer")
        serialized = mp_policy.serialize(infer_policy)
        self.assertIsNone(serialized)
        restored = mp_policy.deserialize(serialized)
        self.assertEqual(str(infer_policy), str(restored))

        class MyPolicy(mp_policy.Policy):
            pass

        # Everything else serializes to a class_name/config dict; the custom
        # subclass is resolved at load time through custom_objects.
        non_dtype_policies = [
            mp_policy.Policy("mixed_float16"),
            mp_policy.Policy("mixed_bfloat16"),
            MyPolicy("float32"),
        ]
        for original in non_dtype_policies:
            serialized = mp_policy.serialize(original)
            expected_config = {
                "class_name": type(original).__name__,
                "config": {"name": original.name},
            }
            self.assertEqual(serialized, expected_config)
            restored = mp_policy.deserialize(
                serialized, custom_objects={"MyPolicy": MyPolicy}
            )
            self.assertEqual(str(original), str(restored))
Beispiel #5
0
    def test_serialization(self):
        """Checks mp_policy.serialize / deserialize for several policy kinds.

        Covers:
          * single-dtype policies, expected to serialize to the bare name and
            round-trip through deserialize();
          * the "_infer" policy, expected to serialize to None;
          * policies expected to serialize to a {'class_name', 'config'} dict
            (mixed policies and a locally defined Policy subclass);
          * PolicyV1 policies, whose expected config additionally carries a
            'loss_scale' entry.
        """
        # Test policies that are equivalent to a single dtype
        for policy_name in 'float16', 'float32', 'int8', 'string', 'bool':
            policy = mp_policy.Policy(policy_name)
            config = mp_policy.serialize(policy)
            self.assertEqual(config, policy_name)
            new_policy = mp_policy.deserialize(config)
            # Round-trip equality is checked via the string representation.
            self.assertEqual(str(policy), str(new_policy))

        # Test "_infer" policy
        policy = mp_policy.Policy('_infer')
        config = mp_policy.serialize(policy)
        self.assertIsNone(config)
        new_policy = mp_policy.deserialize(config)
        self.assertEqual(str(policy), str(new_policy))

        # Local subclass used to check that class_name reflects the subclass
        # and that deserialize() resolves it via custom_objects.
        class MyPolicy(mp_policy.Policy):
            pass

        # Test policies that are not equivalent to a single dtype
        for policy in (mp_policy.Policy('mixed_float16'),
                       mp_policy.Policy('mixed_bfloat16'),
                       MyPolicy('float32')):
            config = mp_policy.serialize(policy)
            self.assertEqual(
                config, {
                    'class_name': policy.__class__.__name__,
                    'config': {
                        'name': policy.name
                    }
                })
            new_policy = mp_policy.deserialize(
                config, custom_objects={'MyPolicy': MyPolicy})
            self.assertEqual(str(policy), str(new_policy))

        # Test V1 policies that override the loss scale
        for policy in (
                mp_policy.PolicyV1('float32', loss_scale=2.),
                mp_policy.PolicyV1('float32', loss_scale=None),
                mp_policy.PolicyV1('mixed_float16', loss_scale=2.),
                mp_policy.PolicyV1('mixed_float16', loss_scale=None),
                mp_policy.PolicyV1('mixed_bfloat16', loss_scale=2.),
                mp_policy.PolicyV1('mixed_bfloat16', loss_scale=None),
        ):
            config = mp_policy.serialize(policy)
            # Expected serialized loss scale: None unless the policy reports a
            # truthy loss_scale. Assumes policy.loss_scale is falsy for the
            # loss_scale=None cases above -- confirm against PolicyV1.
            expected_loss_scale_config = None
            if policy.loss_scale:
                # The only non-None value used above is 2., so the expected
                # config is a FixedLossScale with loss_scale_value 2.
                expected_loss_scale_config = {
                    'class_name': 'FixedLossScale',
                    'config': {
                        'loss_scale_value': 2.
                    }
                }
            self.assertEqual(
                config, {
                    'class_name': policy.__class__.__name__,
                    'config': {
                        'name': policy.name,
                        'loss_scale': expected_loss_scale_config
                    }
                })