def test_duplicate_registration(self):
    """Checks registering one class under several names, and name clashes.

    A class may be registered more than once as long as each registration
    uses a distinct name; every such name must resolve back to the class.
    Registering under a name that is already taken must raise ValueError.
    """

    class Duplicate(base.Trackable):
      pass

    # Explicit form of applying the registration decorator to the class.
    Duplicate = registration.register_serializable()(Duplicate)

    instance = Duplicate()
    first_name = registration.get_registered_class_name(instance)
    self.assertEqual(first_name, "Custom.Duplicate")

    # A second registration under a different package name is allowed.
    registration.register_serializable(package="duplicate")(Duplicate)

    # Registrations are checked in reverse order, so the newest name wins.
    latest_name = registration.get_registered_class_name(instance)
    self.assertEqual(latest_name, "duplicate.Duplicate")

    # Each registered name must still map back to the very same class.
    for registered_name in ("Custom.Duplicate", "duplicate.Duplicate"):
      self.assertIs(
          registration.get_registered_class(registered_name), Duplicate)

    # Registering under a name that has already been registered fails.
    # NOTE(review): "testing.CustomPackage" is presumably registered
    # elsewhere in this file -- confirm.
    with self.assertRaisesRegex(ValueError, "already been registered"):
      register_taken_name = registration.register_serializable(
          package="testing", name="CustomPackage")
      register_taken_name(Duplicate)
  def test_predicate(self):
    """Checks that a registration's predicate decides which objects match.

    The first registration only matches instances whose `register_this`
    attribute is truthy. A later registration that matches every instance
    then takes over the name lookup for both objects.
    """

    class Predicate(base.Trackable):

      def __init__(self, register_this):
        self.register_this = register_this

    def only_flagged(obj):
      return isinstance(obj, Predicate) and obj.register_this

    def any_instance(obj):
      return isinstance(obj, Predicate)

    register_flagged = registration.register_serializable(
        name="RegisterThisOnlyTrue", predicate=only_flagged)
    register_flagged(Predicate)

    flagged = Predicate(True)
    unflagged = Predicate(False)

    # Only the flagged instance satisfies the first predicate.
    self.assertEqual(
        registration.get_registered_class_name(flagged),
        "Custom.RegisterThisOnlyTrue")
    self.assertIsNone(registration.get_registered_class_name(unflagged))

    register_all = registration.register_serializable(
        name="RegisterAllPredicate", predicate=any_instance)
    register_all(Predicate)

    # After the catch-all registration, both instances report its name.
    for instance in (flagged, unflagged):
      self.assertEqual(
          registration.get_registered_class_name(instance),
          "Custom.RegisterAllPredicate")
 def test_register_bad_predicate_fails(self):
     """Checks that registering with a non-callable predicate raises TypeError."""
     not_callable = 0
     with self.assertRaisesRegex(TypeError, "must be callable"):
         decorate = registration.register_serializable(predicate=not_callable)
         decorate(RegisteredClass)
 def test_register_non_class_fails(self):
     """Checks that registering an instance instead of a class raises TypeError."""
     instance = RegisteredClass()
     with self.assertRaisesRegex(TypeError, "must be a class"):
         decorate = registration.register_serializable()
         decorate(instance)