def recreate_function(saved_function, concrete_functions):
  """Builds a callable `Function` out of a `SavedFunction` proto.

  The returned object wraps a small Python dispatcher. On each call the
  dispatcher scans the restored `ConcreteFunction`s named by `saved_function`
  and forwards the call to the first one that accepts the arguments.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`. As a
      side effect, the `FunctionSpec` deserialized from `saved_function` is set
      on every `ConcreteFunction` listed in `saved_function.concrete_functions`.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated instead of
  # creating a new `Function` backed by a Python layer to glue things together.
  # The current approach nests functions one level deeper for each
  # serialization cycle.
  structure_coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, restored functions do not behave
  # as methods (they always use the same captured tensors regardless of the
  # object they are bound to), so there is little value in propagating that
  # correctly.
  #
  # Ideally this conversion would happen at serialization time. But since there
  # are SavedModels which have "ismethod" populated and carry an extra argument
  # that they expect to be ignored, it is done at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec, structure_coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    candidate_names = saved_function.concrete_functions
    if not candidate_names:
      raise ValueError(
          "Found zero restored functions for caller function.")

    # Same layout as function.graph.structured_input_signature. By the time we
    # get here the args and kwargs have already been canonicalized.
    call_inputs = (args, kwargs)

    # Pass one only accepts candidates that need no input conversion, which
    # favours a more specific trace over a more expensive one that supports
    # tensors. Pass two allows conversions.
    for permit_conversion in (False, True):
      for candidate_name in candidate_names:
        candidate = concrete_functions[candidate_name]
        if _concrete_function_callable_with(candidate, call_inputs,
                                            permit_conversion):
          return _call_concrete_function(candidate, call_inputs)

    # Nothing matched: describe what was passed and every available option.
    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n * {}".format(
          len(positional), "\n * ".join(str(a) for a in positional))

    option_descriptions = []
    for option_number, candidate_name in enumerate(candidate_names, start=1):
      candidate = concrete_functions[candidate_name]
      positional, keyword = candidate.structured_input_signature
      option_descriptions.append(
          "Option {}:\n {}\n Keyword arguments: {}".format(
              option_number, _pretty_format_positional(positional), keyword))

    raise ValueError(
        "Could not find matching function to call loaded from the SavedModel. "
        "Got:\n {}\n Keyword arguments: {}\n\nExpected "
        "these arguments to match one of the following {} option(s):\n\n{}"
        .format(_pretty_format_positional(args), kwargs,
                len(saved_function.concrete_functions),
                "\n\n".join(option_descriptions)))

  restored_concrete_functions = [
      concrete_functions[name] for name in saved_function.concrete_functions
  ]
  for restored in restored_concrete_functions:
    restored._set_function_spec(function_spec)  # pylint: disable=protected-access

  wrapper = RestoredFunction(restored_function_body,
                             restored_function_body.__name__,
                             function_spec,
                             restored_concrete_functions)

  return tf_decorator.make_decorator(
      restored_function_body,
      wrapper,
      decorator_argspec=function_spec.fullargspec)
def setUp(self):
  """Prepares the per-test fixture.

  Runs the parent class's `setUp` and then stores a fresh
  `nested_structure_coder.StructureCoder` on `self._coder`.
  """
  # Chain to the parent first so that any fixture initialisation it performs
  # is not silently skipped by this override.
  super().setUp()
  self._coder = nested_structure_coder.StructureCoder()
def recreate_function(saved_function, concrete_functions):
  """Builds a callable `Function` out of a `SavedFunction` proto.

  The returned object wraps a Python dispatcher that canonicalizes the call
  arguments with the deserialized `FunctionSpec` and forwards them to the
  first restored `ConcreteFunction` that accepts them.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated instead of
  # creating a new `Function` backed by a Python layer to glue things together.
  # The current approach nests functions one level deeper for each
  # serialization cycle.
  structure_coder = nested_structure_coder.StructureCoder()
  function_spec = _deserialize_function_spec(saved_function.function_spec,
                                             structure_coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function."""
    # TODO(allenl): Functions saved with input_signatures should revive with
    # input_signatures.
    try:
      canonical = function_spec.canonicalize_function_inputs(*args, **kwargs)
    except ValueError as e:
      raise ValueError(
          "Cannot canonicalize input args %r and kwargs %r. Error: %r." %
          (args, kwargs, e))

    # Pass one only accepts candidates that need no input conversion, which
    # favours a more specific trace over a more expensive one that supports
    # tensors. Pass two allows conversions.
    for permit_conversion in (False, True):
      for candidate_name in saved_function.concrete_functions:
        candidate = concrete_functions[candidate_name]
        if _concrete_function_callable_with(candidate, canonical,
                                            permit_conversion):
          return _call_concrete_function(candidate, canonical)

    # Nothing matched: report every signature that was available.
    known_signatures = []
    for candidate_name in saved_function.concrete_functions:
      known_signatures.append(
          concrete_functions[candidate_name].graph.structured_input_signature)
    raise ValueError(
        "Could not find matching function to call for canonicalized inputs %r. "
        "Only existing signatures are %r."
        % (canonical, known_signatures))

  restored_concrete_functions = [
      concrete_functions[name] for name in saved_function.concrete_functions
  ]

  wrapper = RestoredFunction(restored_function_body,
                             restored_function_body.__name__,
                             function_spec,
                             restored_concrete_functions)

  return tf_decorator.make_decorator(
      restored_function_body,
      wrapper,
      decorator_argspec=function_spec.fullargspec)
def setUp(self):
  """Runs the parent fixture setup, then stores a fresh `StructureCoder`."""
  super(NestedStructureTest, self).setUp()
  fresh_coder = nested_structure_coder.StructureCoder()
  self._coder = fresh_coder
def __init__(self,
             name: str,
             sampler: reverb_types.DistributionType,
             remover: reverb_types.DistributionType,
             max_size: int,
             rate_limiter: rate_limiters.RateLimiter,
             max_times_sampled: int = 0,
             extensions: Sequence[TableExtensionBase] = (),
             signature: Optional[reverb_types.SpecNest] = None):
  """Constructor of the Table.

  Validates the arguments, collects the internal extensions, optionally
  serializes `signature`, and creates the underlying `pybind.Table`, which is
  stored on `self.internal_table`.

  Args:
    name: Name of the priority table.
    sampler: The strategy to use when selecting samples.
    remover: The strategy to use when selecting which items to remove.
    max_size: The maximum number of items which the replay is allowed to
      hold. When an item is inserted into an already full priority table the
      `remover` is used for selecting which item to remove before proceeding
      with the new insert.
    rate_limiter: Manages the data flow by limiting the sample and insert
      calls.
    max_times_sampled: Maximum number of times an item can be sampled before
      it is deleted. Any value < 1 is ignored and means there is no limit.
    extensions: Optional sequence of extensions used to add extra features to
      the table.
    signature: Optional nested structure containing `tf.TypeSpec` objects,
      describing the storage schema for this table. Every leaf must be a
      `tensor_spec.TensorSpec`.

  Raises:
    ValueError: If name is empty.
    ValueError: If max_size <= 0.
    ValueError: If `signature` contains a leaf that is not a `TensorSpec`.
  """
  if not name:
    raise ValueError('name must be nonempty')
  if max_size <= 0:
    raise ValueError('max_size (%d) must be a positive integer' % max_size)

  # Merge the c++ extensions contributed by every extension into one flat list.
  internal_extensions = [
      internal
      for extension in extensions
      for internal in extension.build_internal_extensions(name)
  ]

  # A falsy signature is passed to the table as None.
  signature_proto_str = None
  if signature:
    for leaf_spec in tree.flatten(signature):
      if not isinstance(leaf_spec, tensor_spec.TensorSpec):
        raise ValueError(f'Unsupported signature spec: {leaf_spec}')
    encoded_signature = nested_structure_coder.StructureCoder(
    ).encode_structure(signature)
    signature_proto_str = encoded_signature.SerializeToString()

  self.internal_table = pybind.Table(
      name=name,
      sampler=sampler,
      remover=remover,
      max_size=max_size,
      max_times_sampled=max_times_sampled,
      rate_limiter=rate_limiter.internal_limiter,
      extensions=internal_extensions,
      signature=signature_proto_str)
def as_composite(obj):
  """Returns a `CompositeTensor` equivalent to the given object.

  Any `Variable`, `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable`
  reference that the object closes over is converted to a tensor at the moment
  this function is called. The returned object's type is a subclass of both
  `CompositeTensor` and `type(obj)`.

  Because of that snapshotting, take care when using `as_composite()` with
  `tf.Module` objects. If the composite is created eagerly (for example in
  `__init__`), it "fixes" the values of the `DeferredTensor` and `tf.Variable`
  objects it uses:

  ```python
  class M(tf.Module):

    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
          tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)
      self._dct = tfp.experimental.as_composite(self._d)

    @tf.function
    def mean(self):
      return self._dct.mean()

  m = M()
  m.mean()
  >>> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)  # Doesn't update the CompositeTensor distribution.
  m.mean()
  >>> <tf.Tensor: numpy=2.0>
  ```

  If the composite is instead created inside a method call, the Variable and
  DeferredTensor are properly captured and respected by the Module and its
  `SavedModel` (if it is serialized):

  ```python
  class M(tf.Module):

    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
          tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)

    @tf.function
    def d(self):
      return tfp.experimental.as_composite(self._d)

  m = M()
  m.d().mean()
  >>> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)
  m.d().mean()
  >>> <tf.Tensor: numpy=3.0>
  ```

  Note: This method is best-effort. It relies on a heuristic to tell tensor
  parameters from non-tensor parameters, so things might be broken, especially
  for meta-distributions like `TransformedDistribution` or `Independent`. (We
  try to raise NotImplementedError in such cases.) If you'd benefit from better
  coverage, please file an issue on github or send an email to
  `[email protected]`.

  Args:
    obj: A `tfp.distributions.Distribution`.

  Returns:
    obj: A `tfp.distributions.Distribution` that extends `CompositeTensor`.

  Raises:
    NotImplementedError: If a parameter cannot be converted to a tensor, or if
      the type spec of the resulting object cannot be encoded.
  """
  if isinstance(obj, CompositeTensor):
    return obj

  cls = _make_convertible(type(obj))
  kwargs = dict(obj.parameters)

  def mk_err_msg(suffix=''):
    return (
        'Unable to make a CompositeTensor for "{}" of type `{}`. Email '
        '`[email protected]` or file an issue on github if you '
        'would benefit from this working. {}'.format(
            obj, type(obj), suffix))

  def _tensor_or_raise(key, value):
    """Converts `value` to a tensor named `key`, or raises NotImplementedError."""
    try:
      return tf.convert_to_tensor(value, name=key)
    except TypeError as e:
      raise NotImplementedError(
          mk_err_msg(
              '(Unable to convert dependent entry \'{}\' of object '
              '\'{}\': {})'.format(key, obj, str(e))))

  try:
    params_event_ndims = obj._params_event_ndims()  # pylint: disable=protected-access
  except NotImplementedError:
    params_event_ndims = {}

  # First pass: parameters listed by `_params_event_ndims`. The value is read
  # from the attribute of the same name when there is one (use dtype inference
  # from ctor), falling back to the raw constructor argument.
  for key in params_event_ndims:
    if kwargs.get(key) is None:
      continue
    kwargs[key] = _tensor_or_raise(key, getattr(obj, key, kwargs[key]))

  # Second pass: nested distributions become composites themselves, and any
  # remaining reference-like values are converted to tensors.
  for key, value in kwargs.items():
    if isinstance(value, distributions.Distribution):
      kwargs[key] = as_composite(value)
    if tensor_util.is_ref(value):
      kwargs[key] = _tensor_or_raise(key, value)

  result = cls(**kwargs)

  # Encode the type spec up front so an unserializable result is reported here.
  struct_coder = nested_structure_coder.StructureCoder()
  try:
    struct_coder.encode_structure(result._type_spec)  # pylint: disable=protected-access
  except nested_structure_coder.NotEncodableError as e:
    raise NotImplementedError(
        mk_err_msg('(Unable to serialize: {})'.format(str(e))))
  return result
def from_proto(spec_proto):
  """Decodes a struct_pb2.StructuredValue proto into a nested spec.

  Args:
    spec_proto: The proto to decode.

  Returns:
    The result of `StructureCoder.decode_proto` applied to `spec_proto`.
  """
  return nested_structure_coder.StructureCoder().decode_proto(spec_proto)
def recreate_function(saved_function, concrete_functions):
  """Builds a callable `Function` out of a `SavedFunction` proto.

  The returned `RestoredFunction` wraps a Python dispatcher. The dispatcher
  canonicalizes the call arguments with the deserialized `FunctionSpec` and
  invokes the first restored `ConcreteFunction` whose structured input
  signature is compatible with them.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated instead of
  # creating a new `Function` backed by a Python layer to glue things together.
  # The current approach nests functions one level deeper for each
  # serialization cycle.
  structure_coder = nested_structure_coder.StructureCoder()
  function_spec = _deserialize_function_spec(saved_function.function_spec,
                                             structure_coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function."""
    # TODO(allenl): Functions saved with input_signatures should revive with
    # input_signatures.
    try:
      canonical = function_spec.canonicalize_function_inputs(*args, **kwargs)
    except ValueError as e:
      raise ValueError(
          "Cannot canonicalize input args %r and kwargs %r. Error: %r." %
          (args, kwargs, e))

    # Signatures inspected so far; only used for the error message below.
    seen_signatures = []
    for candidate_name in saved_function.concrete_functions:
      candidate = concrete_functions[candidate_name]
      candidate_signature = candidate.graph.structured_input_signature
      seen_signatures.append(candidate_signature)
      if not _inputs_compatible(canonical, candidate_signature):
        continue

      # Only the tensor leaves of the canonicalized inputs are passed on.
      tensor_inputs = [
          leaf for leaf in nest.flatten(canonical) if _is_tensor(leaf)
      ]
      outcome = candidate._call_flat(tensor_inputs)  # pylint: disable=protected-access
      # A bare Operation result is reported to the caller as None.
      return None if isinstance(outcome, ops.Operation) else outcome

    raise AssertionError(
        "Could not find matching function to call for canonicalized inputs %r. "
        "Only existing signatures are %r."
        % (canonical, seen_signatures))

  restored_concrete_functions = [
      concrete_functions[name] for name in saved_function.concrete_functions
  ]

  return RestoredFunction(restored_function_body,
                          restored_function_body.__name__,
                          function_spec,
                          restored_concrete_functions)
def recreate_function(saved_function, concrete_functions):
  """Builds a callable `Function` out of a `SavedFunction` proto.

  The returned object wraps a small Python dispatcher. On each call the
  dispatcher scans the restored `ConcreteFunction`s named by `saved_function`
  and forwards the call to the first one that accepts the arguments.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated instead of
  # creating a new `Function` backed by a Python layer to glue things together.
  # The current approach nests functions one level deeper for each
  # serialization cycle.
  structure_coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, restored functions do not behave
  # as methods (they always use the same captured tensors regardless of the
  # object they are bound to), so there is little value in propagating that
  # correctly.
  #
  # Ideally this conversion would happen at serialization time. But since there
  # are SavedModels which have "ismethod" populated and carry an extra argument
  # that they expect to be ignored, it is done at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec, structure_coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function."""
    # Same layout as function.graph.structured_input_signature. By the time we
    # get here the args and kwargs have already been canonicalized.
    call_inputs = (args, kwargs)

    # Pass one only accepts candidates that need no input conversion, which
    # favours a more specific trace over a more expensive one that supports
    # tensors. Pass two allows conversions.
    for permit_conversion in (False, True):
      for candidate_name in saved_function.concrete_functions:
        candidate = concrete_functions[candidate_name]
        if _concrete_function_callable_with(candidate, call_inputs,
                                            permit_conversion):
          return _call_concrete_function(candidate, call_inputs)

    # Nothing matched: report every signature that was available.
    known_signatures = []
    for candidate_name in saved_function.concrete_functions:
      known_signatures.append(
          concrete_functions[candidate_name].graph.structured_input_signature)
    raise ValueError(
        "Could not find matching function to call for inputs %r. "
        "Only existing signatures are %r."
        % (call_inputs, known_signatures))

  restored_concrete_functions = [
      concrete_functions[name] for name in saved_function.concrete_functions
  ]

  wrapper = RestoredFunction(restored_function_body,
                             restored_function_body.__name__,
                             function_spec,
                             restored_concrete_functions)

  return tf_decorator.make_decorator(
      restored_function_body,
      wrapper,
      decorator_argspec=function_spec.fullargspec)
def as_composite(obj):
  """Returns a `CompositeTensor` equivalent to the given object.

  Any `Variable`, `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable`
  reference that the object closes over is converted to a tensor at the moment
  this function is called. The returned object's type is a subclass of both
  `CompositeTensor` and `type(obj)`.

  Note: This method is best-effort. It relies on a heuristic to tell tensor
  parameters from non-tensor parameters, so things might be broken, especially
  for meta-distributions like `TransformedDistribution` or `Independent`. (We
  try to raise NotImplementedError in such cases.) If you'd benefit from better
  coverage, please file an issue on github or send an email to
  `[email protected]`.

  Args:
    obj: A `tfp.distributions.Distribution`.

  Returns:
    obj: A `tfp.distributions.Distribution` that extends `CompositeTensor`.

  Raises:
    NotImplementedError: If a parameter cannot be converted to a tensor, or if
      the type spec of the resulting object cannot be encoded.
  """
  if isinstance(obj, CompositeTensor):
    return obj

  cls = _make_convertible(type(obj))
  kwargs = dict(obj.parameters)

  def mk_err_msg(suffix=''):
    return (
        'Unable to make a CompositeTensor for "{}" of type `{}`. Email '
        '`[email protected]` or file an issue on github if you '
        'would benefit from this working. {}'.format(obj, type(obj), suffix))

  def _tensor_or_raise(key, value):
    """Converts `value` to a tensor named `key`, or raises NotImplementedError."""
    try:
      return tf.convert_to_tensor(value, name=key)
    except TypeError as e:
      raise NotImplementedError(
          mk_err_msg(
              '(Unable to convert dependent entry \'{}\' of object '
              '\'{}\': {})'.format(key, obj, str(e))))

  try:
    params_event_ndims = obj._params_event_ndims()  # pylint: disable=protected-access
  except NotImplementedError:
    params_event_ndims = {}

  # First pass: parameters listed by `_params_event_ndims`. The value is read
  # from the attribute of the same name when there is one (use dtype inference
  # from ctor), falling back to the raw constructor argument.
  for key in params_event_ndims:
    if kwargs.get(key) is None:
      continue
    kwargs[key] = _tensor_or_raise(key, getattr(obj, key, kwargs[key]))

  # Second pass: nested distributions become composites themselves, and any
  # remaining reference-like values are converted to tensors.
  for key, value in kwargs.items():
    if isinstance(value, distributions.Distribution):
      kwargs[key] = as_composite(value)
    if tensor_util.is_ref(value):
      kwargs[key] = _tensor_or_raise(key, value)

  result = cls(**kwargs)

  # Encode the type spec up front so an unserializable result is reported here.
  struct_coder = nested_structure_coder.StructureCoder()
  try:
    struct_coder.encode_structure(result._type_spec)  # pylint: disable=protected-access
  except nested_structure_coder.NotEncodableError as e:
    raise NotImplementedError(
        mk_err_msg('(Unable to serialize: {})'.format(str(e))))
  return result
def get_input_specs_from_function(func: tf_function.ConcreteFunction):
  """Serializes the positional input specs of a concrete function.

  Args:
    func: The `ConcreteFunction` whose `structured_input_signature` is read.
      Only the first element of that pair is used; the second is discarded.

  Returns:
    The result of `SerializeToString()` on the encoded structure proto.
  """
  positional_specs, _ = func.structured_input_signature
  return nested_structure_coder.StructureCoder().encode_structure(
      positional_specs).SerializeToString()
def get_output_specs_from_function(func: tf_function.ConcreteFunction):
  """Serializes the output specs of a concrete function.

  Each leaf of `func.structured_outputs` is mapped through
  `type_spec.type_spec_from_value`, and the resulting structure is encoded.

  Args:
    func: The `ConcreteFunction` whose `structured_outputs` are described.

  Returns:
    The result of `SerializeToString()` on the encoded structure proto.
  """
  specs_of_outputs = nest.map_structure(type_spec.type_spec_from_value,
                                        func.structured_outputs)
  return nested_structure_coder.StructureCoder().encode_structure(
      specs_of_outputs).SerializeToString()