Esempio n. 1
0
 def testRegistry(self):
   """Checks that registered TypeSpec classes and names map to each other."""
   registered = (
       ("tf.TwoCompositesSpec", TwoCompositesSpec),
       ("tf.TwoTensorsSpec", TwoTensorsSpec),
   )
   # Class -> registered name.
   for expected_name, spec_cls in registered:
     self.assertEqual(expected_name, type_spec.get_name(spec_cls))
   # Registered name -> class.
   for spec_name, expected_cls in registered:
     self.assertEqual(expected_cls, type_spec.lookup(spec_name))
Esempio n. 2
0
    def testRegistryGetNameErrors(self):
        """Checks the errors raised by `type_spec.get_name`.

        A non-class argument must raise TypeError, and a subclass of a
        registered spec that is not itself registered must raise ValueError.
        """
        import re  # Only needed to escape the expected error message.

        with self.assertRaises(TypeError):
            type_spec.get_name(None)

        class Foo(TwoCompositesSpec):
            pass

        # The expected text is built from Foo's actual module rather than a
        # hard-coded "__main__". The module is only "__main__" when this file
        # is executed directly; under other test runners it is the test
        # module's import name.
        # NOTE(review): assumes get_name formats the message as
        # "<module>.<class name>" -- matches the original "__main__.Foo" text.
        # The message is escaped so the "." characters match literally.
        expected_message = re.escape(
            "TypeSpec {}.Foo has not been registered.".format(Foo.__module__))
        with self.assertRaisesRegex(ValueError, expected_message):
            type_spec.get_name(Foo)
Esempio n. 3
0
    def can_encode(self, pyobj):
        """Returns True if `pyobj` can be encoded as a TypeSpec.

        An object is encodable when either:
        - its exact type has an entry in `TYPE_SPEC_CLASS_TO_PROTO`, or
        - it is a `type_spec.TypeSpec` whose class is registered, i.e.
          `type_spec.get_name` does not raise ValueError for it.

        Args:
          pyobj: The object to check.

        Returns:
          True if `pyobj` can be encoded, False otherwise (never None).
        """
        # Exact-type lookup on purpose: a subclass of a natively supported spec
        # is not covered by its parent's proto entry.
        # pylint: disable=unidiomatic-typecheck
        if type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO:
            return True

        # Check if it's a registered type.
        if isinstance(pyobj, type_spec.TypeSpec):
            try:
                type_spec.get_name(type(pyobj))
                return True
            except ValueError:
                return False

        # Anything that is not a TypeSpec cannot be encoded by this codec.
        return False
  def do_encode(self, type_spec_value, encode_fn):
    """Encodes a `tf.TypeSpec` into a `StructuredValue` proto.

    Args:
      type_spec_value: The TypeSpec instance to encode.
      encode_fn: Callable applied to the spec's `_serialize()` state. Its
        result is stored as the proto's `type_state`.

    Returns:
      A `struct_pb2.StructuredValue` whose `type_spec_value` field holds a
      `TypeSpecProto` describing `type_spec_value`.
    """
    spec_cls = type(type_spec_value)
    proto_kind = self.TYPE_SPEC_CLASS_TO_PROTO.get(spec_cls)

    if proto_kind is not None:
      # Natively supported spec: recorded under its plain Python class name.
      class_name = spec_cls.__name__
    else:
      # Otherwise the spec is recorded under its registered name.
      # type_spec.get_name raises ValueError for unregistered classes.
      class_name = type_spec.get_name(spec_cls)
      if not isinstance(type_spec_value, extension_type.ExtensionTypeSpec):
        proto_kind = struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC
        # Saving registered (non-extension) TypeSpecs is experimental, so
        # warn the caller about the loading requirement.
        warnings.warn("Encoding a StructuredValue with type %s; loading this "
                      "StructuredValue will require that this type be "
                      "imported and registered." % class_name)
      else:
        proto_kind = struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC

    state = type_spec_value._serialize()  # pylint: disable=protected-access
    component_count = len(
        nest.flatten(
            type_spec_value._component_specs,  # pylint: disable=protected-access
            expand_composites=True))
    encoded = struct_pb2.StructuredValue()
    spec_proto = struct_pb2.TypeSpecProto(
        type_spec_class=proto_kind,
        type_state=encode_fn(state),
        type_spec_class_name=class_name,
        num_flat_components=component_count)
    encoded.type_spec_value.CopyFrom(spec_proto)
    return encoded
Esempio n. 5
0
def get_json_type(obj):
    """Serializes any object to a JSON-serializable structure.

    Args:
        obj: the object to serialize

    Returns:
        JSON-serializable structure representing `obj`.

    Raises:
        TypeError: if `obj` cannot be serialized.
        ValueError: if `obj` is a `tf.TypeSpec` whose class has not been
            registered.
    """
    # if obj is a serializable Keras class instance
    # e.g. optimizer, layer
    if hasattr(obj, 'get_config'):
        return {
            'class_name': obj.__class__.__name__,
            'config': obj.get_config()
        }

    # if obj is any numpy type
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return obj.item()

    # misc functions (e.g. loss function)
    if callable(obj):
        return obj.__name__

    # if obj is a python 'type'
    # NOTE: classes are callable, so in practice they are already handled by
    # the `callable` branch above.
    if type(obj).__name__ == type.__name__:
        return obj.__name__

    if isinstance(obj, tf.compat.v1.Dimension):
        return obj.value

    if isinstance(obj, tf.TensorShape):
        return obj.as_list()

    if isinstance(obj, tf.DType):
        return obj.name

    if isinstance(obj, collections_abc.Mapping):
        return dict(obj)

    if obj is Ellipsis:
        return {'class_name': '__ellipsis__'}

    if isinstance(obj, wrapt.ObjectProxy):
        return obj.__wrapped__

    if isinstance(obj, tf.TypeSpec):
        # Only the registry lookup is guarded. A ValueError raised by
        # `_serialize()` itself is not an "unregistered class" problem and
        # propagates unchanged.
        try:
            type_spec_name = type_spec.get_name(type(obj))
        except ValueError as err:
            raise ValueError(
                'Unable to serialize {} to JSON, because the TypeSpec '
                'class {} has not been registered.'.format(obj, type(obj))
            ) from err
        return {
            'class_name': 'TypeSpec',
            'type_spec': type_spec_name,
            'serialized': obj._serialize()  # pylint: disable=protected-access
        }

    # A single formatted message (not a tuple of args) so str(error) is
    # readable.
    raise TypeError('Not JSON Serializable: {!r}'.format(obj))
Esempio n. 6
0
def get_json_type(obj):
    """Serializes any object to a JSON-serializable structure.

    Args:
        obj: the object to serialize

    Returns:
        JSON-serializable structure representing `obj`.

    Raises:
        TypeError: if `obj` cannot be serialized.
        ValueError: if `obj` is a `tf.TypeSpec` whose class has not been
            registered.
    """
    # if obj is a serializable Keras class instance
    # e.g. optimizer, layer
    if hasattr(obj, 'get_config'):
        serialized = generic_utils.serialize_keras_object(obj)
        serialized['__passive_serialization__'] = True
        return serialized

    # if obj is any numpy type
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return obj.item()

    # misc functions (e.g. loss function)
    if callable(obj):
        return obj.__name__

    # if obj is a python 'type'
    # NOTE: classes are callable, so in practice they are already handled by
    # the `callable` branch above.
    if type(obj).__name__ == type.__name__:
        return obj.__name__

    if isinstance(obj, tf.compat.v1.Dimension):
        return obj.value

    if isinstance(obj, tf.TensorShape):
        return obj.as_list()

    if isinstance(obj, tf.DType):
        return obj.name

    if isinstance(obj, collections.abc.Mapping):
        return dict(obj)

    if obj is Ellipsis:
        return {'class_name': '__ellipsis__'}

    if isinstance(obj, wrapt.ObjectProxy):
        return obj.__wrapped__

    if isinstance(obj, tf.TypeSpec):
        # Only the registry lookup is guarded. A ValueError raised by
        # `_serialize()` itself is not an "unregistered class" problem and
        # propagates unchanged.
        try:
            type_spec_name = type_spec.get_name(type(obj))
        except ValueError as err:
            raise ValueError(
                f'Unable to serialize {obj} to JSON, because the TypeSpec '
                f'class {type(obj)} has not been registered.') from err
        return {
            'class_name': 'TypeSpec',
            'type_spec': type_spec_name,
            'serialized': obj._serialize()  # pylint: disable=protected-access
        }

    if isinstance(obj, tf.__internal__.CompositeTensor):
        # Stored as the (recursively serialized) spec plus one
        # (dtype name, nested-list value) pair per flattened component tensor.
        # NOTE(review): `.numpy()` presumably requires eager tensors -- confirm.
        spec = tf.type_spec_from_value(obj)
        tensors = [
            (tensor.dtype.name, tensor.numpy().tolist())
            for tensor in tf.nest.flatten(obj, expand_composites=True)
        ]
        return {
            'class_name': 'CompositeTensor',
            'spec': get_json_type(spec),
            'tensors': tensors
        }

    if isinstance(obj, enum.Enum):
        return obj.value

    raise TypeError(
        f'Unable to serialize {obj} to JSON. Unrecognized type {type(obj)}.')