def recreate_function(saved_function, concrete_functions):
    """Builds a restored `Function` from a `SavedFunction` proto.

    Args:
      saved_function: `SavedFunction` proto naming the concrete traces that
        back the function and carrying its serialized `FunctionSpec`.
      concrete_functions: map from function name to `ConcreteFunction`.
        Side effect: every `ConcreteFunction` named in `saved_function` has
        its `FunctionSpec` set to the one deserialized from `saved_function`.

    Returns:
      A `Function`: the `RestoredFunction` built here, decorated through
      `tf_decorator.make_decorator` with `function_spec.fullargspec` as its
      argspec.
    """
    # TODO(andresp): Construct a `Function` with the cache populated
    # instead of creating a new `Function` backed by a Python layer to
    # glue things together. Current approach is nesting functions deeper for each
    # serialization cycle.

    # The spec is deserialized as a non-method on purpose. make_decorator
    # offers no control over "ismethod", and restored functions always use the
    # same captured tensors whatever object they are bound to, so method
    # semantics add little. Some existing SavedModels have "ismethod" set and
    # carry an extra argument that must be ignored, which is why the
    # conversion happens here at load time instead of at save time.
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_function.function_spec, nested_structure_coder.StructureCoder())

    def restored_function_body(*args, **kwargs):
        """Calls a restored function or raises an error if no matching function."""
        names = saved_function.concrete_functions
        if not names:
            raise ValueError(
                "Found zero restored functions for caller function.")
        # Same layout as function.graph.structured_input_signature; args and
        # kwargs have already been canonicalized by the time we get here.
        inputs = (args, kwargs)

        # Pass 1 refuses input conversion so a more specific trace beats a
        # more general (tensor-accepting) one; pass 2 allows conversion.
        for allow_conversion in (False, True):
            for name in names:
                candidate = concrete_functions[name]
                if _concrete_function_callable_with(candidate, inputs,
                                                    allow_conversion):
                    return _call_concrete_function(candidate, inputs)

        def _pretty_format_positional(positional):
            count = len(positional)
            bullets = "\n    * ".join(str(a) for a in positional)
            return "Positional arguments ({} total):\n    * {}".format(
                count, bullets)

        options = []
        for option_number, name in enumerate(names, start=1):
            positional, keyword = (
                concrete_functions[name].structured_input_signature)
            options.append(
                "Option {}:\n  {}\n  Keyword arguments: {}".format(
                    option_number, _pretty_format_positional(positional),
                    keyword))
        raise ValueError(
            "Could not find matching function to call loaded from the SavedModel. "
            "Got:\n  {}\n  Keyword arguments: {}\n\nExpected "
            "these arguments to match one of the following {} option(s):\n\n{}"
            .format(_pretty_format_positional(args), kwargs, len(names),
                    "\n\n".join(options)))

    # Resolve every trace first, then attach the shared spec to each one.
    concrete_function_objects = [
        concrete_functions[name] for name in saved_function.concrete_functions
    ]
    for concrete in concrete_function_objects:
        concrete._set_function_spec(function_spec)  # pylint: disable=protected-access

    restored = RestoredFunction(restored_function_body,
                                restored_function_body.__name__,
                                function_spec, concrete_function_objects)
    return tf_decorator.make_decorator(
        restored_function_body, restored,
        decorator_argspec=function_spec.fullargspec)
Exemple #2
0
 def setUp(self):
   """Runs the base-class setUp, then creates a fresh StructureCoder per test."""
   # The base test case's setUp must run so its per-test fixture state is
   # initialized (the sibling NestedStructureTest.setUp does the same).
   super().setUp()
   self._coder = nested_structure_coder.StructureCoder()
def recreate_function(saved_function, concrete_functions):
    """Creates a `Function` from a `SavedFunction`.

    Args:
      saved_function: `SavedFunction` proto.
      concrete_functions: map from function name to `ConcreteFunction`.

    Returns:
      A `Function`: the `RestoredFunction` built here, decorated through
      `tf_decorator.make_decorator` with `function_spec.fullargspec` as its
      argspec. Calling it raises `ValueError` if the inputs cannot be
      canonicalized by the deserialized `FunctionSpec`, or if no concrete
      function accepts the canonicalized inputs.
    """
    # TODO(andresp): Construct a `Function` with the cache populated
    # instead of creating a new `Function` backed by a Python layer to
    # glue things together. Current approach is nesting functions deeper for each
    # serialization cycle.

    coder = nested_structure_coder.StructureCoder()
    function_spec = _deserialize_function_spec(saved_function.function_spec,
                                               coder)

    def restored_function_body(*args, **kwargs):
        """Calls a restored function."""
        # TODO(allenl): Functions saved with input_signatures should revive with
        # input_signatures.
        try:
            canonicalized_inputs = function_spec.canonicalize_function_inputs(
                *args, **kwargs)
        except ValueError as e:
            # Chain the original error so its traceback is kept as the cause.
            raise ValueError(
                "Cannot canonicalize input args %r and kwargs %r. Error: %r." %
                (args, kwargs, e)) from e

        # First try to find a concrete function that can be called without input
        # conversions. This allows one to pick a more specific trace in case there
        # was also a more expensive one that supported tensors.
        for allow_conversion in [False, True]:
            for function_name in saved_function.concrete_functions:
                function = concrete_functions[function_name]
                if _concrete_function_callable_with(function,
                                                    canonicalized_inputs,
                                                    allow_conversion):
                    return _call_concrete_function(function,
                                                   canonicalized_inputs)

        available_signatures = [
            concrete_functions[function_name].graph.structured_input_signature
            for function_name in saved_function.concrete_functions
        ]
        raise ValueError(
            "Could not find matching function to call for canonicalized inputs %r. "
            "Only existing signatures are %r." %
            (canonicalized_inputs, available_signatures))

    concrete_function_objects = [
        concrete_functions[concrete_function_name]
        for concrete_function_name in saved_function.concrete_functions
    ]

    restored_function = RestoredFunction(restored_function_body,
                                         restored_function_body.__name__,
                                         function_spec,
                                         concrete_function_objects)

    return tf_decorator.make_decorator(
        restored_function_body,
        restored_function,
        decorator_argspec=function_spec.fullargspec)
Exemple #4
0
 def setUp(self):
     """Runs the base-class setUp, then gives each test its own StructureCoder."""
     parent = super(NestedStructureTest, self)
     parent.setUp()
     coder = nested_structure_coder.StructureCoder()
     self._coder = coder
Exemple #5
0
    def __init__(self,
                 name: str,
                 sampler: reverb_types.DistributionType,
                 remover: reverb_types.DistributionType,
                 max_size: int,
                 rate_limiter: rate_limiters.RateLimiter,
                 max_times_sampled: int = 0,
                 extensions: Sequence[TableExtensionBase] = (),
                 signature: Optional[reverb_types.SpecNest] = None):
        """Constructor of the Table.

        Validates the arguments, collects the C++ extensions contributed by
        each Python extension, optionally serializes the signature, and builds
        the underlying `pybind.Table` stored as `self.internal_table`.

        Args:
          name: Name of the priority table. Must be non-empty.
          sampler: The strategy to use when selecting samples.
          remover: The strategy to use when selecting which items to remove.
          max_size: The maximum number of items which the replay is allowed to
            hold. When an item is inserted into an already full priority table
            the `remover` is used for selecting which item to remove before
            proceeding with the new insert.
          rate_limiter: Manages the data flow by limiting the sample and insert
            calls. Its `internal_limiter` is handed to the C++ table.
          max_times_sampled: Maximum number of times an item can be sampled
            before it is deleted. Any value < 1 is ignored and means there is
            no limit.
          extensions: Optional sequence of extensions used to add extra
            features to the table. Each one's `build_internal_extensions(name)`
            result is concatenated, in order.
          signature: Optional nested structure containing `tf.TypeSpec`
            objects, describing the storage schema for this table. Every leaf
            must be a `TensorSpec`. A falsy value means no signature.

        Raises:
          ValueError: If name is empty.
          ValueError: If max_size <= 0.
          ValueError: If a leaf of `signature` is not a `TensorSpec`.
        """
        if not name:
            raise ValueError('name must be nonempty')
        if max_size <= 0:
            raise ValueError('max_size (%d) must be a positive integer' %
                             max_size)

        # Flatten the C++ extensions of every Python extension into one list,
        # keeping their original order.
        internal_extensions = [
            internal
            for extension in extensions
            for internal in extension.build_internal_extensions(name)
        ]

        signature_proto_str = None
        if signature:
            for spec in tree.flatten(signature):
                if not isinstance(spec, tensor_spec.TensorSpec):
                    raise ValueError(f'Unsupported signature spec: {spec}')
            coder = nested_structure_coder.StructureCoder()
            encoded = coder.encode_structure(signature)
            signature_proto_str = encoded.SerializeToString()

        self.internal_table = pybind.Table(
            name=name,
            sampler=sampler,
            remover=remover,
            max_size=max_size,
            max_times_sampled=max_times_sampled,
            rate_limiter=rate_limiter.internal_limiter,
            extensions=internal_extensions,
            signature=signature_proto_str,
        )
Exemple #6
0
def as_composite(obj):
    """Returns a `CompositeTensor` equivalent to the given object.

    Note that the returned object will have any `Variable`,
    `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable` references it
    closes over converted to tensors at the time this function is called. The
    type of the returned object will be a subclass of both `CompositeTensor` and
    `type(obj)`.  For this reason, one should be careful about using
    `as_composite()`, especially for `tf.Module` objects.

    For example, when the composite tensor is created even as part of a
    `tf.Module`, it "fixes" the values of the `DeferredTensor` and `tf.Variable`
    objects it uses:

    ```python
    class M(tf.Module):
      def __init__(self):
        self._v = tf.Variable(1.)
        self._d = tfp.distributions.Normal(
          tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)
        self._dct = tfp.experimental.as_composite(self._d)

      @tf.function
      def mean(self):
        return self._dct.mean()

    m = M()
    m.mean()
    # ==> <tf.Tensor: numpy=2.0>
    m._v.assign(2.)  # Doesn't update the CompositeTensor distribution.
    m.mean()
    # ==> <tf.Tensor: numpy=2.0>
    ```

    If, however, the creation of the composite is deferred to a method
    call, then the Variable and DeferredTensor will be properly captured
    and respected by the Module and its `SavedModel` (if it is serialized).

    ```python
    class M(tf.Module):
      def __init__(self):
        self._v = tf.Variable(1.)
        self._d = tfp.distributions.Normal(
          tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)

      @tf.function
      def d(self):
        return tfp.experimental.as_composite(self._d)

    m = M()
    m.d().mean()
    # ==> <tf.Tensor: numpy=2.0>
    m._v.assign(2.)
    m.d().mean()
    # ==> <tf.Tensor: numpy=3.0>
    ```

    Note: This method is best-effort and based on a heuristic for what the
    tensor parameters are and what the non-tensor parameters are. Things might be
    broken, especially for meta-distributions like `TransformedDistribution` or
    `Independent`. (We try to raise NotImplementedError in such cases.) If you'd
    benefit from better coverage, please file an issue on github or send an email
    to `[email protected]`.

    Args:
      obj: A `tfp.distributions.Distribution`. An object that is already a
        `CompositeTensor` is returned unchanged.

    Returns:
      A `tfp.distributions.Distribution` that extends `CompositeTensor`.

    Raises:
      NotImplementedError: If `tf.convert_to_tensor` raises `TypeError` for a
        parameter that must become a tensor, or if the resulting object's
        `_type_spec` cannot be encoded by `StructureCoder`. The underlying
        error is chained as the cause.
    """
    if isinstance(obj, CompositeTensor):
        return obj
    cls = _make_convertible(type(obj))
    kwargs = dict(obj.parameters)

    def mk_err_msg(suffix=''):
        return (
            'Unable to make a CompositeTensor for "{}" of type `{}`. Email '
            '`[email protected]` or file an issue on github if you '
            'would benefit from this working. {}'.format(
                obj, type(obj), suffix))

    def _to_tensor(k, v):
        """Converts parameter `v` (named `k`), mapping TypeError to NotImplementedError."""
        try:
            return tf.convert_to_tensor(v, name=k)
        except TypeError as e:
            raise NotImplementedError(
                mk_err_msg(
                    '(Unable to convert dependent entry \'{}\' of object '
                    '\'{}\': {})'.format(k, obj, str(e)))) from e

    try:
        params_event_ndims = obj._params_event_ndims()  # pylint: disable=protected-access
    except NotImplementedError:
        params_event_ndims = {}
    for k in params_event_ndims:
        # Use dtype inference from ctor: convert the attribute the constructor
        # stored (falling back to the raw parameter) rather than the raw value.
        if k in kwargs and kwargs[k] is not None:
            kwargs[k] = _to_tensor(k, getattr(obj, k, kwargs[k]))
    for k, v in kwargs.items():
        # Replacing values (not keys) while iterating the dict is safe.
        if isinstance(v, distributions.Distribution):
            kwargs[k] = as_composite(v)
        if tensor_util.is_ref(v):
            kwargs[k] = _to_tensor(k, v)
    result = cls(**kwargs)
    struct_coder = nested_structure_coder.StructureCoder()
    try:
        struct_coder.encode_structure(result._type_spec)  # pylint: disable=protected-access
    except nested_structure_coder.NotEncodableError as e:
        raise NotImplementedError(
            mk_err_msg('(Unable to serialize: {})'.format(str(e)))) from e
    return result
Exemple #7
0
def from_proto(spec_proto):
    """Decodes a `struct_pb2.StructuredValue` proto back into a nested spec.

    Args:
      spec_proto: A `struct_pb2.StructuredValue` proto.

    Returns:
      Whatever `StructureCoder.decode_proto` reconstructs from `spec_proto`.
    """
    return nested_structure_coder.StructureCoder().decode_proto(spec_proto)
def recreate_function(saved_function, concrete_functions):
    """Creates a `Function` from a `SavedFunction`.

    Args:
      saved_function: `SavedFunction` proto.
      concrete_functions: map from function name to `ConcreteFunction`.

    Returns:
      A `Function` (a `RestoredFunction` wrapping the concrete functions named
      in `saved_function`). Calling it raises `ValueError` if the inputs
      cannot be canonicalized, and `AssertionError` if no concrete function
      has a compatible input signature.
    """
    # TODO(andresp): Construct a `Function` with the cache populated
    # instead of creating a new `Function` backed by a Python layer to
    # glue things together. Current approach is nesting functions deeper for each
    # serialization cycle.

    coder = nested_structure_coder.StructureCoder()
    function_spec = _deserialize_function_spec(saved_function.function_spec,
                                               coder)

    def restored_function_body(*args, **kwargs):
        """Calls a restored function."""
        # TODO(allenl): Functions saved with input_signatures should revive with
        # input_signatures.
        try:
            canonicalized_inputs = function_spec.canonicalize_function_inputs(
                *args, **kwargs)
        except ValueError as e:
            # Chain the original error so its traceback is kept as the cause.
            raise ValueError(
                "Cannot canonicalize input args %r and kwargs %r. Error: %r." %
                (args, kwargs, e)) from e

        debug_considered_signatures = []
        for concrete_function_name in saved_function.concrete_functions:
            function_obj = concrete_functions[concrete_function_name]
            canonicalized_original_inputs = (
                function_obj.graph.structured_input_signature)
            debug_considered_signatures.append(canonicalized_original_inputs)

            if _inputs_compatible(canonicalized_inputs,
                                  canonicalized_original_inputs):
                # The flat call takes only tensor inputs; non-tensor values
                # are dropped after flattening.
                flattened_inputs = nest.flatten(canonicalized_inputs)
                filtered_inputs = [
                    t for t in flattened_inputs if _is_tensor(t)
                ]

                result = function_obj._call_flat(filtered_inputs)  # pylint: disable=protected-access
                # A function whose output is a bare Operation returns None.
                if isinstance(result, ops.Operation):
                    return None
                return result

        # NOTE(review): a mismatch is a caller error and other versions raise
        # ValueError here; AssertionError is kept so existing callers that
        # catch it keep working.
        raise AssertionError(
            "Could not find matching function to call for canonicalized inputs %r. "
            "Only existing signatures are %r." %
            (canonicalized_inputs, debug_considered_signatures))

    concrete_function_objects = [
        concrete_functions[concrete_function_name]
        for concrete_function_name in saved_function.concrete_functions
    ]

    return RestoredFunction(restored_function_body,
                            restored_function_body.__name__, function_spec,
                            concrete_function_objects)
Exemple #9
0
def recreate_function(saved_function, concrete_functions):
    """Builds a restored `Function` from a `SavedFunction` proto.

    Args:
      saved_function: `SavedFunction` proto naming the concrete traces that
        back the function and carrying its serialized `FunctionSpec`.
      concrete_functions: map from function name to `ConcreteFunction`.

    Returns:
      A `Function`: the `RestoredFunction` built here, decorated through
      `tf_decorator.make_decorator` with `function_spec.fullargspec` as its
      argspec. Calling it raises `ValueError` when no concrete function
      accepts the given inputs.
    """
    # TODO(andresp): Construct a `Function` with the cache populated
    # instead of creating a new `Function` backed by a Python layer to
    # glue things together. Current approach is nesting functions deeper for each
    # serialization cycle.

    # The spec is deserialized as a non-method on purpose. make_decorator
    # offers no control over "ismethod", and restored functions always use the
    # same captured tensors whatever object they are bound to, so method
    # semantics add little. Some existing SavedModels have "ismethod" set and
    # carry an extra argument that must be ignored, which is why the
    # conversion happens here at load time instead of at save time.
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_function.function_spec, nested_structure_coder.StructureCoder())

    def restored_function_body(*args, **kwargs):
        """Calls a restored function."""
        # Same layout as function.graph.structured_input_signature; args and
        # kwargs have already been canonicalized by the time we get here.
        inputs = (args, kwargs)
        names = saved_function.concrete_functions

        # Pass 1 refuses input conversion so a more specific trace beats a
        # more general (tensor-accepting) one; pass 2 allows conversion.
        for allow_conversion in (False, True):
            for name in names:
                candidate = concrete_functions[name]
                if _concrete_function_callable_with(candidate, inputs,
                                                    allow_conversion):
                    return _call_concrete_function(candidate, inputs)

        available_signatures = [
            concrete_functions[name].graph.structured_input_signature
            for name in names
        ]
        raise ValueError(
            "Could not find matching function to call for inputs %r. "
            "Only existing signatures are %r." %
            (inputs, available_signatures))

    concrete_function_objects = [
        concrete_functions[name] for name in saved_function.concrete_functions
    ]
    restored = RestoredFunction(restored_function_body,
                                restored_function_body.__name__,
                                function_spec, concrete_function_objects)
    return tf_decorator.make_decorator(
        restored_function_body, restored,
        decorator_argspec=function_spec.fullargspec)
Exemple #10
0
def as_composite(obj):
  """Returns a `CompositeTensor` equivalent to the given object.

  Note that the returned object will have any `Variable`,
  `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable` references it
  closes over converted to tensors at the time this function is called. The
  type of the returned object will be a subclass of both `CompositeTensor` and
  `type(obj)`.

  Note: This method is best-effort and based on a heuristic for what the
  tensor parameters are and what the non-tensor parameters are. Things might be
  broken, especially for meta-distributions like `TransformedDistribution` or
  `Independent`. (We try to raise NotImplementedError in such cases.) If you'd
  benefit from better coverage, please file an issue on github or send an email
  to `[email protected]`.

  Args:
    obj: A `tfp.distributions.Distribution`. An object that is already a
      `CompositeTensor` is returned unchanged.

  Returns:
    A `tfp.distributions.Distribution` that extends `CompositeTensor`.

  Raises:
    NotImplementedError: If `tf.convert_to_tensor` raises `TypeError` for a
      parameter that must become a tensor, or if the resulting object's
      `_type_spec` cannot be encoded by `StructureCoder`. The underlying error
      is chained as the cause.
  """
  if isinstance(obj, CompositeTensor):
    return obj
  cls = _make_convertible(type(obj))
  kwargs = dict(obj.parameters)

  def mk_err_msg(suffix=''):
    return (
        'Unable to make a CompositeTensor for "{}" of type `{}`. Email '
        '`[email protected]` or file an issue on github if you '
        'would benefit from this working. {}'.format(obj, type(obj), suffix))

  def _to_tensor(k, v):
    """Converts parameter `v` (named `k`), mapping TypeError to NotImplementedError."""
    try:
      return tf.convert_to_tensor(v, name=k)
    except TypeError as e:
      raise NotImplementedError(
          mk_err_msg(
              '(Unable to convert dependent entry \'{}\' of object '
              '\'{}\': {})'.format(k, obj, str(e)))) from e

  try:
    params_event_ndims = obj._params_event_ndims()  # pylint: disable=protected-access
  except NotImplementedError:
    params_event_ndims = {}
  for k in params_event_ndims:
    # Use dtype inference from ctor: convert the attribute the constructor
    # stored (falling back to the raw parameter) rather than the raw value.
    if k in kwargs and kwargs[k] is not None:
      kwargs[k] = _to_tensor(k, getattr(obj, k, kwargs[k]))
  for k, v in kwargs.items():
    # Replacing values (not keys) while iterating the dict is safe.
    if isinstance(v, distributions.Distribution):
      kwargs[k] = as_composite(v)
    if tensor_util.is_ref(v):
      kwargs[k] = _to_tensor(k, v)
  result = cls(**kwargs)
  struct_coder = nested_structure_coder.StructureCoder()
  try:
    struct_coder.encode_structure(result._type_spec)  # pylint: disable=protected-access
  except nested_structure_coder.NotEncodableError as e:
    raise NotImplementedError(
        mk_err_msg('(Unable to serialize: {})'.format(str(e)))) from e
  return result
Exemple #11
0
def get_input_specs_from_function(func: tf_function.ConcreteFunction):
    """Serializes the positional input specs of a concrete function.

    Only the positional half of `func.structured_input_signature` is encoded;
    the keyword half is discarded.

    Args:
      func: A `ConcreteFunction`.

    Returns:
      The `SerializeToString()` result of the proto that
      `StructureCoder.encode_structure` produces for the positional specs.
    """
    positional_specs, _unused_keyword_specs = func.structured_input_signature
    coder = nested_structure_coder.StructureCoder()
    encoded = coder.encode_structure(positional_specs)
    return encoded.SerializeToString()
Exemple #12
0
def get_output_specs_from_function(func: tf_function.ConcreteFunction):
    """Serializes the `TypeSpec`s of a concrete function's structured outputs.

    Args:
      func: A `ConcreteFunction`.

    Returns:
      The `SerializeToString()` result of the proto that
      `StructureCoder.encode_structure` produces after mapping
      `type_spec.type_spec_from_value` over `func.structured_outputs`.
    """
    output_specs = nest.map_structure(
        type_spec.type_spec_from_value, func.structured_outputs)
    encoded = nested_structure_coder.StructureCoder().encode_structure(
        output_specs)
    return encoded.SerializeToString()