def test_duplicate_registration(self):
    """Checks registering one class under more than one name.

    The class is first registered with the default decorator arguments, and
    its name is asserted to be "Custom.Duplicate". It is then registered again
    under the "duplicate" package. After that, name lookup for an instance
    reports the newer name, and both names resolve back to the class.
    Finally, registering under package="testing", name="CustomPackage" is
    expected to raise a ValueError mentioning "already been registered".
    """

    @registration.register_serializable()
    class Duplicate(base.Trackable):
      pass

    instance = Duplicate()
    self.assertEqual(
        registration.get_registered_class_name(instance), "Custom.Duplicate")

    # A second registration under a different package name is accepted.
    registration.register_serializable(package="duplicate")(Duplicate)
    # Lookups consult registrations in reverse order, so the newest name wins.
    self.assertEqual(
        registration.get_registered_class_name(instance),
        "duplicate.Duplicate")

    # Every registered name must map back to the very same class object.
    for registered_name in ("Custom.Duplicate", "duplicate.Duplicate"):
      self.assertIs(
          registration.get_registered_class(registered_name), Duplicate)

    # Reusing a name that is already taken must be rejected.
    with self.assertRaisesRegex(ValueError, "already been registered"):
      registration.register_serializable(
          package="testing", name="CustomPackage")(Duplicate)
Example No. 2
0
  def _recreate(self, proto, node_id, nodes):
    """Rebuilds a Python object from a SavedObject protocol buffer.

    The class used for deserialization is resolved in two steps:
    1. A class registered under `proto.registered_name`.
    2. Otherwise, an entry of `_BUILT_IN_REGISTRATIONS` keyed by the name of
       the proto's populated `kind` oneof field.

    If neither step yields a class, `_recreate_default` does the work.

    Args:
      proto: a SavedObject proto
      node_id: int, the index of this object in the SavedObjectGraph node list.
      nodes: dict mapping int node_ids -> created objects.

    Returns:
      A tuple of the recreated object and the set-attribute function used to
      reconnect its trackable children.
    """
    cls = registration.get_registered_class(proto.registered_name)
    if cls is None:
      cls = _BUILT_IN_REGISTRATIONS.get(proto.WhichOneof("kind"))

    # Replace each dependency's node id with the object already created for it.
    dependencies = {
        key: nodes[dep_node_id]
        for key, dep_node_id in self._get_node_dependencies(proto).items()
    }

    if not cls:
      return self._recreate_default(proto, node_id, dependencies)

    obj = cls._deserialize_from_proto(  # pylint: disable=protected-access
        proto=proto.serialized_user_proto,
        object_proto=proto,
        dependencies=dependencies,
        export_dir=self._export_dir,
        asset_file_def=self._asset_file_def)
    add_child = type(obj)._add_trackable_child  # pylint: disable=protected-access
    return obj, add_child
Example No. 3
0
  def _recreate(self, proto, node_id):
    """Rebuilds a Python object from a SavedObject protocol buffer.

    When `proto.registered_name` resolves to a registered class, that class
    deserializes the object from `proto.serialized_user_proto`. Otherwise
    the work is delegated to `_recreate_default`.

    Args:
      proto: a SavedObject proto
      node_id: int, the index of this object in the SavedObjectGraph node list.

    Returns:
      A tuple of the recreated object and the set-attribute function used to
      reconnect its trackable children.
    """
    cls = registration.get_registered_class(proto.registered_name)
    if not cls:
      return self._recreate_default(proto, node_id)

    obj = cls._deserialize_from_proto(  # pylint: disable=protected-access
        proto=proto.serialized_user_proto)
    add_child = type(obj)._add_trackable_child  # pylint: disable=protected-access
    return obj, add_child
 def test_get_invalid_name(self):
     """Looking up the name "invalid name" must yield None."""
     resolved = registration.get_registered_class("invalid name")
     self.assertIsNone(resolved)
 def test_registration(self, expected_cls, expected_name):
     """Checks that `expected_cls` round-trips through its registered name.

     An instance of `expected_cls` must report `expected_name` as its
     registered class name. Looking up `expected_name` must return
     `expected_cls` itself.
     """
     instance = expected_cls()
     reported_name = registration.get_registered_class_name(instance)
     self.assertEqual(reported_name, expected_name)
     resolved_cls = registration.get_registered_class(expected_name)
     self.assertIs(resolved_cls, expected_cls)