Ejemplo n.º 1
0
  def test_serialization(self):
    """Checks that policies survive a serialize/deserialize round trip.

    Three groups are exercised: policies named after a single dtype (asserted
    to serialize to the bare name), the '_infer' policy (asserted to serialize
    to None), and the mixed-precision and subclassed policies (asserted to
    serialize to a dict holding the class name and the policy name).
    """

    def check_round_trip(original, serialized, **deserialize_kwargs):
      # Rebuild a policy from its serialized form and compare string forms.
      restored = mp_policy.deserialize(serialized, **deserialize_kwargs)
      self.assertEqual(str(original), str(restored))

    # Single-dtype policies: the serialized form must equal the policy name.
    for dtype_name in ('float16', 'float32', 'int8', 'string', 'bool'):
      dtype_policy = mp_policy.Policy(dtype_name)
      serialized = mp_policy.serialize(dtype_policy)
      self.assertEqual(serialized, dtype_name)
      check_round_trip(dtype_policy, serialized)

    # The "_infer" policy must serialize to None and still round-trip.
    infer_policy = mp_policy.Policy('_infer')
    serialized = mp_policy.serialize(infer_policy)
    self.assertIsNone(serialized)
    check_round_trip(infer_policy, serialized)

    class MyPolicy(mp_policy.Policy):
      pass

    # Remaining policies must serialize to a {'class_name', 'config'} dict.
    # The subclass is handed to deserialize through custom_objects.
    dict_serialized_policies = (
        mp_policy.Policy('mixed_float16'),
        mp_policy.Policy('mixed_bfloat16'),
        MyPolicy('float32'),
    )
    for original in dict_serialized_policies:
      serialized = mp_policy.serialize(original)
      expected = {
          'class_name': original.__class__.__name__,
          'config': {'name': original.name},
      }
      self.assertEqual(serialized, expected)
      check_round_trip(
          original, serialized, custom_objects={'MyPolicy': MyPolicy})
Ejemplo n.º 2
0
    def test_serialization(self):
        """Verifies serialize/deserialize round trips for several policies.

        Asserts the serialized form for each group (a plain name string for
        single-dtype policies, None for "_infer", and a class-name/config
        dict for the rest), then asserts that the deserialized policy has the
        same string representation as the one it came from.
        """

        def assert_restores(source_policy, payload, **extra):
            # Deserialize the payload and compare str() of both policies.
            rebuilt = mp_policy.deserialize(payload, **extra)
            self.assertEqual(str(source_policy), str(rebuilt))

        # Policies whose name is a single dtype serialize to that name.
        single_dtype_names = ("float16", "float32", "int8", "string", "bool")
        for name in single_dtype_names:
            source_policy = mp_policy.Policy(name)
            payload = mp_policy.serialize(source_policy)
            self.assertEqual(payload, name)
            assert_restores(source_policy, payload)

        # The "_infer" policy serializes to None.
        inferred = mp_policy.Policy("_infer")
        payload = mp_policy.serialize(inferred)
        self.assertIsNone(payload)
        assert_restores(inferred, payload)

        class MyPolicy(mp_policy.Policy):
            pass

        # Mixed-precision and subclassed policies serialize to a dict that
        # names the class and carries the policy name in its config.
        dict_policies = (
            mp_policy.Policy("mixed_float16"),
            mp_policy.Policy("mixed_bfloat16"),
            MyPolicy("float32"),
        )
        for source_policy in dict_policies:
            payload = mp_policy.serialize(source_policy)
            expected_payload = {
                "class_name": source_policy.__class__.__name__,
                "config": {"name": source_policy.name},
            }
            self.assertEqual(payload, expected_payload)
            assert_restores(
                source_policy, payload, custom_objects={"MyPolicy": MyPolicy}
            )
Ejemplo n.º 3
0
    def test_serialization(self):
        """Exercises policy serialization, including V1 loss-scale configs.

        The first three sections assert the serialized form of a policy and
        that deserializing it yields a policy with the same string
        representation. The last section only asserts the serialized config
        of `PolicyV1` instances built with an explicit `loss_scale`.
        """

        def verify_round_trip(source, serialized, **deserialize_kwargs):
            # Deserialize and compare the string forms of both policies.
            restored = mp_policy.deserialize(serialized, **deserialize_kwargs)
            self.assertEqual(str(source), str(restored))

        # Single-dtype policies: serialized form equals the policy name.
        for dtype_name in ('float16', 'float32', 'int8', 'string', 'bool'):
            source = mp_policy.Policy(dtype_name)
            serialized = mp_policy.serialize(source)
            self.assertEqual(serialized, dtype_name)
            verify_round_trip(source, serialized)

        # "_infer" policy: serialized form is None.
        source = mp_policy.Policy('_infer')
        serialized = mp_policy.serialize(source)
        self.assertIsNone(serialized)
        verify_round_trip(source, serialized)

        class MyPolicy(mp_policy.Policy):
            pass

        # Policies serialized as a dict of class name plus config.
        dict_policies = (
            mp_policy.Policy('mixed_float16'),
            mp_policy.Policy('mixed_bfloat16'),
            MyPolicy('float32'),
        )
        for source in dict_policies:
            serialized = mp_policy.serialize(source)
            expected = {
                'class_name': source.__class__.__name__,
                'config': {'name': source.name},
            }
            self.assertEqual(serialized, expected)
            verify_round_trip(
                source, serialized, custom_objects={'MyPolicy': MyPolicy})

        # V1 policies built with an explicit loss scale: a truthy loss scale
        # is expected to appear as a FixedLossScale config with value 2.0,
        # otherwise the 'loss_scale' entry is expected to be None.
        v1_policies = tuple(
            mp_policy.PolicyV1(name, loss_scale=scale)
            for name in ('float32', 'mixed_float16', 'mixed_bfloat16')
            for scale in (2., None)
        )
        for policy in v1_policies:
            config = mp_policy.serialize(policy)
            expected_loss_scale_config = {
                'class_name': 'FixedLossScale',
                'config': {'loss_scale_value': 2.},
            } if policy.loss_scale else None
            expected_config = {
                'class_name': policy.__class__.__name__,
                'config': {
                    'name': policy.name,
                    'loss_scale': expected_loss_scale_config,
                },
            }
            self.assertEqual(config, expected_config)