def test_duplicate_registration(self):
  """Checks re-registering one class under a new name and a taken name."""

  @registration.register_serializable()
  class Duplicate(base.Trackable):
    pass

  instance = Duplicate()
  self.assertEqual(
      registration.get_registered_class_name(instance), "Custom.Duplicate")

  # A second registration under a different package is accepted.
  register_in_duplicate_pkg = registration.register_serializable(
      package="duplicate")
  register_in_duplicate_pkg(Duplicate)

  # Registrations are consulted newest-first, so the latest name is reported.
  self.assertEqual(
      registration.get_registered_class_name(instance), "duplicate.Duplicate")

  # Every name the class was registered under resolves to the same class.
  for registered_name in ("Custom.Duplicate", "duplicate.Duplicate"):
    self.assertIs(
        registration.get_registered_class(registered_name), Duplicate)

  # Registering under a name already in use must fail. "testing.CustomPackage"
  # is presumably registered by another class in this test module; verify.
  with self.assertRaisesRegex(ValueError, "already been registered"):
    clashing_register = registration.register_serializable(
        package="testing", name="CustomPackage")
    clashing_register(Duplicate)
def _recreate(self, proto, node_id, nodes):
  """Rebuilds the Python object described by a SavedObject protocol buffer.

  Args:
    proto: a SavedObject proto
    node_id: int, the index of this object in the SavedObjectGraph node list.
    nodes: dict mapping int node_ids -> created objects.

  Returns:
    The recreated object, and the set-attribute function for reconnecting the
    trackable children.
  """
  # Look up a class registered under the proto's registered name first; if
  # none is found, fall back to a built-in registration keyed by the proto's
  # "kind" oneof.
  cls = registration.get_registered_class(proto.registered_name)
  if cls is None:
    kind = proto.WhichOneof("kind")
    cls = _BUILT_IN_REGISTRATIONS.get(kind)

  # Resolve each dependency key to the object already created for that node.
  dependencies = {
      key: nodes[dep_id]
      for key, dep_id in self._get_node_dependencies(proto).items()
  }

  if not cls:
    # Nothing registered for this proto: use the default recreation path.
    return self._recreate_default(proto, node_id, dependencies)

  obj = cls._deserialize_from_proto(  # pylint: disable=protected-access
      proto=proto.serialized_user_proto,
      object_proto=proto,
      dependencies=dependencies,
      export_dir=self._export_dir,
      asset_file_def=self._asset_file_def)
  # Trackable children are reattached via the recreated object's class.
  return obj, type(obj)._add_trackable_child  # pylint: disable=protected-access
def _recreate(self, proto, node_id):
  """Rebuilds the Python object described by a SavedObject protocol buffer.

  Args:
    proto: a SavedObject proto
    node_id: int, the index of this object in the SavedObjectGraph node list.

  Returns:
    The recreated object, and the set-attribute function for reconnecting the
    trackable children.
  """
  cls = registration.get_registered_class(proto.registered_name)
  if not cls:
    # No class registered under this name: defer to the default path.
    return self._recreate_default(proto, node_id)

  obj = cls._deserialize_from_proto(  # pylint: disable=protected-access
      proto=proto.serialized_user_proto)
  # Trackable children are reattached via the recreated object's class.
  return obj, type(obj)._add_trackable_child  # pylint: disable=protected-access
def test_get_invalid_name(self):
  """Looking up a name that was never registered yields None."""
  lookup_result = registration.get_registered_class("invalid name")
  self.assertIsNone(lookup_result)
def test_registration(self, expected_cls, expected_name):
  """Round-trips a class through its registered name.

  Args:
    expected_cls: class expected to be registered under `expected_name`.
    expected_name: registered name expected for instances of `expected_cls`.
  """
  instance = expected_cls()
  actual_name = registration.get_registered_class_name(instance)
  self.assertEqual(actual_name, expected_name)

  resolved_cls = registration.get_registered_class(expected_name)
  self.assertIs(resolved_cls, expected_cls)