Beispiel #1
0
 def test_are_equivalent_types(self):
     """Checks `are_equivalent_types` on int32 vectors of unknown vs. fixed length."""
     unknown_len = computation_types.TensorType(tf.int32, [None])
     ten_a = computation_types.TensorType(tf.int32, [10])
     ten_b = computation_types.TensorType(tf.int32, [10])
     # Pairs that must compare as equivalent (checked in both orders).
     for lhs, rhs in ((unknown_len, unknown_len), (ten_a, ten_b),
                      (ten_b, ten_a)):
         self.assertTrue(type_utils.are_equivalent_types(lhs, rhs))
     # A `None` dimension is not equivalent to a fixed one, in either order.
     for lhs, rhs in ((unknown_len, ten_a), (ten_a, unknown_len)):
         self.assertFalse(type_utils.are_equivalent_types(lhs, rhs))
Beispiel #2
0
 def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                      expected_kwargs):
     """Checks that `unpack_args_from_tuple` splits a tuple into args/kwargs."""
     args, kwargs = function_utils.unpack_args_from_tuple(tuple_with_args)
     self.assertEqual(len(args), len(expected_args))
     # Lengths are equal past the assertion above, so `zip` pairs every element.
     for actual_arg, expected_arg in zip(args, expected_args):
         self.assertTrue(
             type_utils.are_equivalent_types(
                 actual_arg, computation_types.to_type(expected_arg)))
     self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys()))
     for kwarg_name, actual_kwarg in kwargs.items():
         self.assertTrue(
             type_utils.are_equivalent_types(
                 computation_types.to_type(actual_kwarg),
                 expected_kwargs[kwarg_name]))
Beispiel #3
0
    def from_proto(cls, computation_proto):
        """Deserializes `computation_proto` into a building-block instance.

        Args:
          computation_proto: An instance of `pb.Computation`.

        Returns:
          An instance of a class implementing `ComputationBuildingBlock` that
          holds the logic deserialized from `computation_proto`.

        Raises:
          NotImplementedError: If `cls._deserializer_dict` has no deserializer
            for the kind of computation held in `computation_proto`.
          ValueError: If the type derived from the deserialized structure is not
            equivalent to the type declared in `computation_proto.type`.
        """
        py_typecheck.check_type(computation_proto, pb.Computation)
        which_oneof = computation_proto.WhichOneof('computation')
        deserialize_fn = cls._deserializer_dict.get(which_oneof)
        if deserialize_fn is None:
            raise NotImplementedError(
                'Deserialization for computations of type {} has not been '
                'implemented yet.'.format(which_oneof))
        result = deserialize_fn(computation_proto)
        declared_type = type_serialization.deserialize_type(
            computation_proto.type)
        derived_type = result.type_signature
        # The structure-derived type must agree with the declared signature.
        if not type_utils.are_equivalent_types(derived_type, declared_type):
            raise ValueError(
                'The type {} derived from the computation structure does not '
                'match the type {} declared in its signature'.format(
                    str(derived_type), str(declared_type)))
        return result
Beispiel #4
0
def construct_binary_operator_with_upcast(type_signature, operator):
  """Builds a lambda that upcasts its second argument and applies `operator`.

  See `apply_binary_operator_with_upcast` for what upcasting means. The result
  is meant to be used as a function body (e.g. for an intrinsic), so it must be
  reducible to TensorFlow. For that reason `type_signature` may only contain
  named tuples and tensors; federated types are not handled here.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`
      with two elements. Either both have equivalent types, or the second can
      be broadcast into the structure of the first. Both may contain only
      tuples and tensors in their type tree.
    operator: Callable defining the binary operator.

  Returns:
    A `computation_building_blocks.Lambda` that takes the pair and, when the
    two element types are not equivalent, broadcasts the second element into
    the structure of the first. It then applies the binary operator.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  arg_ref = computation_building_blocks.Reference('binary_operator_arg',
                                                  type_signature)

  def _broadcast_to_structure(scalar, target_type):
    """Broadcasts tensor-valued `scalar` into the nested `target_type`."""
    if isinstance(target_type, computation_types.NamedTupleType):
      return computation_building_blocks.Tuple([
          (member_name, _broadcast_to_structure(scalar, member_type))
          for member_name, member_type in anonymous_tuple.to_elements(
              target_type)
      ])
    if isinstance(target_type, computation_types.TensorType):
      broadcast_fn = (
          computation_constructing_utils
          .construct_tensorflow_to_broadcast_scalar(
              scalar.type_signature.dtype, target_type.shape))
      return computation_building_blocks.Call(broadcast_fn, scalar)
    # Other types fall through to None; presumably excluded earlier by
    # `_check_generic_operator_type` -- TODO confirm.
    return None

  second_elem = computation_building_blocks.Selection(arg_ref, index=1)
  first_elem = computation_building_blocks.Selection(arg_ref, index=0)
  target_type = first_elem.type_signature

  upcast_second = (
      second_elem if type_utils.are_equivalent_types(
          target_type, second_elem.type_signature) else
      _broadcast_to_structure(second_elem, target_type))

  operator_fn = (
      computation_constructing_utils.construct_tensorflow_binary_operator(
          target_type, operator))
  applied = computation_building_blocks.Call(
      operator_fn,
      computation_building_blocks.Tuple([first_elem, upcast_second]))
  return computation_building_blocks.Lambda(arg_ref.name,
                                            arg_ref.type_signature, applied)
Beispiel #5
0
 async def create_value(self, value, type_spec=None):
     """Embeds `value` in the target executor, reusing a cached entry if possible.

     Args:
       value: The payload to embed. A `computation_impl.ComputationImpl` is first
         converted to its proto, with a type spec reconciled against `value`.
       type_spec: The type of `value`, or something convertible to it.

     Returns:
       The `CachedValue` for `value`, after its target future has completed.

     Raises:
       RuntimeError: If `value` cannot be used as a cache key (the lookup raised
         `TypeError`).
       Any exception raised by the target executor's `create_value`. In that
         case the whole cache is cleared before the exception propagates.
     """
     type_spec = computation_types.to_type(type_spec)
     if isinstance(value, computation_impl.ComputationImpl):
         return await self.create_value(
             computation_impl.ComputationImpl.get_proto(value),
             type_utils.reconcile_value_with_type_spec(value, type_spec))
     py_typecheck.check_type(type_spec, computation_types.Type)
     hashable_key = _get_hashable_key(value, type_spec)
     try:
         identifier = self._cache[hashable_key]
     except KeyError:
         identifier = None
     except TypeError as err:
         raise RuntimeError(
             'Failed to perform a hash table lookup with a value of Python '
             'type {} and TFF type {}, and payload {}: {}'.format(
                 py_typecheck.type_string(type(value)), type_spec, value,
                 err)) from err
     cached_value = None
     if isinstance(identifier, CachedValueIdentifier):
         try:
             cached_value = self._cache[identifier]
         except KeyError:
             cached_value = None
         # Treat this as a cache miss when the identifier has no cached entry.
         # Without this, `target_future` below would be unbound.
         # Also treat it as a miss when the cached entry has a different type.
         # That can legitimately happen when the payload alone does not
         # uniquely determine the type.
         if cached_value is None or (
                 type_spec is not None and
                 not type_utils.are_equivalent_types(
                     cached_value.type_signature, type_spec)):
             identifier = None
     else:
         identifier = None
     if identifier is None:
         # Cache miss: start creating the value in the target executor and
         # register it under a fresh identifier.
         self._num_values_created += 1
         identifier = CachedValueIdentifier(str(self._num_values_created))
         self._cache[hashable_key] = identifier
         target_future = asyncio.ensure_future(
             self._target_executor.create_value(value, type_spec))
         cached_value = CachedValue(identifier, hashable_key, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     try:
         await cached_value.target_future
     except Exception:
         # Invalidate the entire cache if the inner executor had an exception.
         # TODO(b/145514490): This is a bit heavy handed, there maybe caches where
         # only the current cache item needs to be invalidated; however this
         # currently only occurs when an inner RemoteExecutor has the backend go
         # down.
         self._cache = {}
         raise
     # No type check is necessary here. Either `are_equivalent_types` was
     # checked above, or the target value was just created with `type_spec`.
     return cached_value
Beispiel #6
0
 def test_construct_setattr_named_tuple_type_leaves_type_signature_unchanged(
         self):
     """The setattr lambda for element 'a' maps the tuple type to itself."""
     element_specs = [('a', tf.int32), (None, tf.float32), ('b', tf.bool)]
     good_type = computation_types.NamedTupleType(element_specs)
     value_comp = computation_building_blocks.Data('x', tf.int32)
     lam = computation_constructing_utils.construct_named_tuple_setattr_lambda(
         good_type, 'a', value_comp)
     lam_type = lam.type_signature
     self.assertTrue(
         type_utils.are_equivalent_types(lam_type.parameter, lam_type.result))
Beispiel #7
0
  def test_federated_setattr_call_leaves_type_signatures_alone(self, placement):
    """A federated setattr call on element 'a' keeps the federated type."""
    member_type = computation_types.NamedTupleType([('a', tf.int32),
                                                    (None, tf.float32),
                                                    ('b', tf.bool)])
    federated_type = computation_types.FederatedType(member_type, placement)
    federated_comp = computation_building_blocks.Data('federated_comp',
                                                      federated_type)
    value_comp = computation_building_blocks.Data('x', tf.int32)

    setattr_call = (
        computation_constructing_utils.construct_federated_setattr_call(
            federated_comp, 'a', value_comp))
    self.assertTrue(
        type_utils.are_equivalent_types(setattr_call.type_signature,
                                        federated_comp.type_signature))
    def _serialize_deserialize_roundtrip_test(self, type_list):
        """Round-trips each type through serialization and checks stability.

        Both the types and the serialized protos are compared via `repr`. The
        original and round-tripped types must also be equivalent.

        Args:
          type_list: A list of `computation_types.Type` instances or things
            convertible to them.
        """
        for candidate in type_list:
            original_type = computation_types.to_type(candidate)
            first_proto = type_serialization.serialize_type(original_type)
            roundtrip_type = type_serialization.deserialize_type(first_proto)
            second_proto = type_serialization.serialize_type(roundtrip_type)
            self.assertEqual(repr(original_type), repr(roundtrip_type))
            self.assertEqual(repr(first_proto), repr(second_proto))
            self.assertTrue(
                type_utils.are_equivalent_types(original_type, roundtrip_type))
Beispiel #9
0
 def __add__(self, other):
     """Returns a `ValueImpl` that applies the `GENERIC_PLUS` intrinsic.

     Raises:
       TypeError: If `other`, after conversion with `to_value`, has a type not
         equivalent to `self.type_signature`.
     """
     other = to_value(other, None, self._context_stack)
     own_type = self.type_signature
     if not type_utils.are_equivalent_types(own_type, other.type_signature):
         raise TypeError('Cannot add {} and {}.'.format(
             str(own_type), str(other.type_signature)))
     # GENERIC_PLUS is typed here as (T, T) -> T with T = self's type.
     plus_type = computation_types.FunctionType([own_type, own_type], own_type)
     plus_intrinsic = computation_building_blocks.Intrinsic(
         intrinsic_defs.GENERIC_PLUS.uri, plus_type)
     operands = ValueImpl.get_comp(
         to_value([self, other], None, self._context_stack))
     return ValueImpl(
         computation_building_blocks.Call(plus_intrinsic, operands),
         self._context_stack)
Beispiel #10
0
 async def create_value(self, value, type_spec=None):
     """Embeds `value` in the target executor, reusing a cached entry if possible.

     Args:
       value: The payload to embed. A `computation_impl.ComputationImpl` is first
         converted to its proto, with a type spec reconciled against `value`.
       type_spec: The type of `value`, or something convertible to it.

     Returns:
       The `CachedValue` for `value`, after its target future has completed and
       its type has been checked against `type_spec`.

     Raises:
       RuntimeError: If `value` cannot be used as a cache key (the lookup raised
         `TypeError`).
       Whatever `type_utils.check_assignable_from` raises if the target value's
         type is not assignable to `type_spec`.
     """
     type_spec = computation_types.to_type(type_spec)
     if isinstance(value, computation_impl.ComputationImpl):
         return await self.create_value(
             computation_impl.ComputationImpl.get_proto(value),
             type_utils.reconcile_value_with_type_spec(value, type_spec))
     py_typecheck.check_type(type_spec, computation_types.Type)
     hashable_key = _get_hashable_key(value, type_spec)
     try:
         identifier = self._cache[hashable_key]
     except KeyError:
         identifier = None
     except TypeError as err:
         raise RuntimeError(
             'Failed to perform a hash table lookup with a value of Python '
             'type {} and TFF type {}, and payload {}: {}'.format(
                 py_typecheck.type_string(type(value)), type_spec, value,
                 err)) from err
     cached_value = None
     if isinstance(identifier, CachedValueIdentifier):
         try:
             cached_value = self._cache[identifier]
         except KeyError:
             cached_value = None
         # Treat this as a cache miss when the identifier has no cached entry.
         # Previously this dereferenced None and left `target_future` unbound.
         # Also treat it as a miss when the cached entry has a different type.
         # That can legitimately happen when the payload alone does not
         # uniquely determine the type.
         if cached_value is None or (
                 type_spec is not None and
                 not type_utils.are_equivalent_types(
                     cached_value.type_signature, type_spec)):
             identifier = None
     else:
         identifier = None
     if identifier is None:
         # Cache miss: start creating the value in the target executor and
         # register it under a fresh identifier.
         self._num_values_created += 1
         identifier = CachedValueIdentifier(str(self._num_values_created))
         self._cache[hashable_key] = identifier
         target_future = asyncio.ensure_future(
             self._target_executor.create_value(value, type_spec))
         cached_value = CachedValue(identifier, hashable_key, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     target_value = await cached_value.target_future
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
Beispiel #11
0
 def __eq__(self, other):
   """Compares two trackers by name and by bound value.

   Two building-block values are equal when their compact representations
   match and their type signatures are equivalent. Any other values are
   compared by identity. Returns `NotImplemented` when `other` is not a
   `BoundVariableTracker`.
   """
   # TODO(b/130890785): Delegate value-checking to
   # `building_blocks.ComputationBuildingBlock`.
   if self is other:
     return True
   if not isinstance(other, BoundVariableTracker):
     return NotImplemented
   if self.name != other.name:
     return False
   mine, theirs = self.value, other.value
   block_cls = building_blocks.ComputationBuildingBlock
   if not (isinstance(mine, block_cls) and isinstance(theirs, block_cls)):
     return mine is theirs
   same_repr = (mine.compact_representation() ==
                theirs.compact_representation())
   return same_repr and type_utils.are_equivalent_types(
       mine.type_signature, theirs.type_signature)
Beispiel #12
0
 def federated_secure_sum(self, value, bitwidth):
     """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

     `value` must be a CLIENTS-placed structure of integers. `bitwidth` must
     have a type equivalent to the member type of `value`.

     Raises:
       TypeError: If the member type of `value` and the type of `bitwidth` are
         not equivalent.
     """
     ctx = self._context_stack
     value = value_utils.ensure_federated_value(
         value_impl.to_value(value, None, ctx), placements.CLIENTS,
         'value to be summed')
     type_utils.check_is_structure_of_integers(value.type_signature)
     bitwidth = value_impl.to_value(bitwidth, None, ctx)
     member_type = value.type_signature.member
     bitwidth_type = bitwidth.type_signature
     if not type_utils.are_equivalent_types(member_type, bitwidth_type):
         raise TypeError(
             'Expected `federated_secure_sum` parameters `value` and `bitwidth` '
             'to have the same structure. Found `value` of `{}` and `bitwidth` of `{}`'
             .format(member_type, bitwidth_type))
     comp = building_block_factory.create_federated_secure_sum(
         value_impl.ValueImpl.get_comp(value),
         value_impl.ValueImpl.get_comp(bitwidth))
     return value_impl.ValueImpl(comp, ctx)
Beispiel #13
0
def _wrap(func, parameter_type, wrapper_fn):
    """Wraps `func` as a computation with `parameter_type` using `wrapper_fn`.

    This handles only the simple case: a function/defun (always present) and
    either a parameter type or None for "no parameter". The wrapper/decorator
    usage modes are handled by ComputationWrapper.

    When `parameter_type` is falsy and `func` declares any parameters
    (positional, `*args` or `**kwargs`, even with defaults), a polymorphic
    callable is returned. It wraps `func` lazily, once per actual argument
    type. Otherwise `func` is wrapped immediately.

    Args:
      func: The function or defun to wrap as a computation.
      parameter_type: The parameter type accepted by the computation, or None
        if there is no parameter.
      wrapper_fn: The Python callable that performs the actual wrapping. It is
        called with a zero- or one-argument Python callable, the parameter type
        (None for no parameter), and a keyword `name` holding the wrapped
        function's name (for debugging only). It must return a
        `func_utils.ConcreteFunction`.

    Returns:
      Either a `func_utils.ConcreteFunction`, or a
      `func_utils.PolymorphicFunction` that performs the wrapping upon
      invocation based on argument types.

    Raises:
      TypeError: If `wrapper_fn` returns something that is not a
        ConcreteFunction, or one whose parameter type is not equivalent to
        `parameter_type`.
    """
    func_name = getattr(func, '__name__', None)
    argspec = func_utils.get_argspec(func)
    parameter_type = computation_types.to_type(parameter_type)

    wants_polymorphic = not parameter_type and (
        argspec.args or argspec.varargs or argspec.keywords)
    if wants_polymorphic:

        def _wrap_for_type(arg_type):
            return wrapper_fn(
                func_utils.wrap_as_zero_or_one_arg_callable(
                    func, arg_type, unpack=True),
                arg_type,
                name=func_name)

        polymorphic_fn = func_utils.PolymorphicFunction(_wrap_for_type)
        # Decorators do not carry `__doc__` over from the decorated function,
        # so copy it onto the wrapper explicitly.
        polymorphic_fn.__doc__ = getattr(func, '__doc__', None)
        return polymorphic_fn

    single_arg_callable = func_utils.wrap_as_zero_or_one_arg_callable(
        func, parameter_type)
    concrete_fn = wrapper_fn(single_arg_callable, parameter_type,
                             name=func_name)
    py_typecheck.check_type(concrete_fn, func_utils.ConcreteFunction,
                            'value returned by the wrapper')
    actual_param_type = concrete_fn.type_signature.parameter
    if not type_utils.are_equivalent_types(actual_param_type, parameter_type):
        raise TypeError(
            'Expected a concrete function that takes parameter {}, got one '
            'that takes {}.'.format(str(parameter_type),
                                    str(actual_param_type)))
    # Decorators do not carry `__doc__` over from the decorated function, so
    # copy it onto the wrapper explicitly.
    concrete_fn.__doc__ = getattr(func, '__doc__', None)
    return concrete_fn
Beispiel #14
0
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

    Args:
      comp: An instance of `pb.Computation`.
      type_spec: An optional `tff.Type` instance or something convertible to it.
        When given, it must be equivalent to the type declared in `comp.type`.
        When omitted, the declared type is used.
      device: An optional device name. Only `None` is currently supported.

    Returns:
      Either a one-argument or a zero-argument callable that executes the
      computation in eager mode. The one-argument form is returned when the
      type is a function type with a parameter; otherwise the zero-argument
      form is returned.

    Raises:
      NotImplementedError: If `device` is not `None`.
      TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
        TensorFlow computation, or if `type_spec` is not equivalent to the type
        declared in `comp`.
      RuntimeError: Raised by the returned callable if the number of flattened
        arguments does not match the computation's parameter binding.
    """
    # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
    # it deals exclusively with eager mode. Incubate here, and potentially move
    # there, once stable.

    if device is not None:
        raise NotImplementedError(
            'Unable to embed TF code on a specific device.')

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_utils.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    str(type_spec), str(comp_type)))
    else:
        type_spec = comp_type
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    # A non-function type is treated as a no-argument computation whose result
    # is the whole type.
    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    if param_type is not None:
        input_tensor_names = graph_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    output_tensor_names = graph_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        # Imports the serialized graph and feeds `args` positionally into the
        # parameter tensors. Returns the result tensors.
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                str(len(input_tensor_names)), str(len(args))))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        init_names = [init_op] if init_op else []
        returned_elements = tf.import_graph_def(
            graph_merge.uniquify_shared_names(graph_def),
            input_map=dict(zip(input_tensor_names, args)),
            return_elements=output_tensor_names + init_names)
        if init_names:
            # The initializer op is the last returned element. Wrapping outputs
            # in `tf.identity` under a control dependency orders them after it.
            with tf.control_dependencies([returned_elements[-1]]):
                return [tf.identity(x) for x in returned_elements[0:-1]]
        else:
            return returned_elements

    # Build the `wrap_function` input signature from the flattened parameter
    # type. Tensors pass through unchanged. Sequences are fed as scalar
    # `tf.variant` tensors, converted by `to_variant`.
    signature = []
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                signature.append(tf.TensorSpec(spec.shape, spec.dtype))
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                signature.append(tf.TensorSpec([], tf.variant))
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    # Per-result converters. Tensors pass through unchanged. Variant-encoded
    # sequences are decoded with `from_variant`, using the element type's
    # TF structure.
    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_utils.type_to_tf_structure(spec.element)

            # The default argument binds this iteration's `structure`, avoiding
            # late binding of the loop variable.
            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        # Flattens `arg`, converts each leaf with its matching parameter
        # converter, and runs the wrapped graph. Results are repacked into the
        # structure of `result_type`.
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    str(len(param_fns)), str(len(arg_parts))))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)
        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)
    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
Beispiel #15
0
def to_representation_for_type(value, type_spec=None, device=None):
    """Verifies or converts `value` to an eager object matching `type_spec`.

    WARNING: This function is only partially implemented. It does not convert
    other values into data sets. A sequence-typed value must already be a
    `tf.data.Dataset`.

    The output is always an eager tensor, an eager dataset, a representation
    of a TensorFlow computation, or a nested structure of those that matches
    `type_spec`. When `device` is specified, the conversion runs inside a
    `tf.device(device)` scope, placing values there on a best-effort basis.

    TensorFlow computations are represented as zero- or one-argument Python
    callables that accept their entire argument bundle as one Python object.

    Args:
      value: The raw representation of a value to compare against `type_spec`
        and potentially to be converted.
      type_spec: An instance of `tff.Type`. Can be `None` for values that
        derive from `typed_object.TypedObject`.
      device: The optional device to place the value on (for tensor-level
        values).

    Returns:
      Either `value` itself, or a converted version of it.

    Raises:
      TypeError: If `value` is not compatible with `type_spec`.
      ValueError: If no type can be determined for `value`.
    """
    if device is not None:
        py_typecheck.check_type(device, six.string_types)
        with tf.device(device):
            return to_representation_for_type(value,
                                              type_spec=type_spec,
                                              device=None)
    type_spec = computation_types.to_type(type_spec)
    if isinstance(value, typed_object.TypedObject):
        if type_spec is None:
            type_spec = value.type_signature
        elif not type_utils.are_equivalent_types(value.type_signature,
                                                 type_spec):
            raise TypeError(
                'Expected a value of type {}, found {}.'.format(
                    str(type_spec), str(value.type_signature)))
    if type_spec is None:
        raise ValueError(
            'Cannot derive an eager representation for a value of an unknown type.'
        )
    if isinstance(value, EagerValue):
        return value.internal_representation
    if isinstance(value, executor_value_base.ExecutorValue):
        raise TypeError(
            'Cannot accept a value embedded within a non-eager executor.')
    if isinstance(value, computation_base.Computation):
        return to_representation_for_type(
            computation_impl.ComputationImpl.get_proto(value), type_spec,
            device)
    if isinstance(value, pb.Computation):
        return embed_tensorflow_computation(value, type_spec, device)

    if isinstance(type_spec, computation_types.TensorType):
        tensor = value if isinstance(value, tf.Tensor) else tf.constant(
            value, type_spec.dtype, type_spec.shape)
        tensor_type = computation_types.TensorType(tensor.dtype.base_dtype,
                                                   tensor.shape)
        if not type_utils.is_assignable_from(type_spec, tensor_type):
            raise TypeError(
                'The apparent type {} of a tensor {} does not match the expected '
                'type {}.'.format(str(tensor_type), str(tensor),
                                  str(type_spec)))
        return tensor

    if isinstance(type_spec, computation_types.NamedTupleType):
        expected_elems = anonymous_tuple.to_elements(type_spec)
        actual_elems = anonymous_tuple.to_elements(
            anonymous_tuple.from_container(value))
        if len(expected_elems) != len(actual_elems):
            raise TypeError(
                'Expected a {}-element tuple, found {} elements.'.format(
                    str(len(expected_elems)), str(len(actual_elems))))
        converted = []
        for (expected_name, elem_type), (actual_name, elem_value) in zip(
                expected_elems, actual_elems):
            if expected_name != actual_name:
                raise TypeError(
                    'Mismatching element names in type vs. value: {} vs. {}.'.
                    format(expected_name, actual_name))
            converted.append(
                (expected_name,
                 to_representation_for_type(elem_value, elem_type, device)))
        return anonymous_tuple.AnonymousTuple(converted)

    if isinstance(type_spec, computation_types.SequenceType):
        py_typecheck.check_type(value,
                                (tf.data.Dataset, tf.compat.v1.data.Dataset,
                                 tf.compat.v2.data.Dataset))
        dataset_type = computation_types.SequenceType(
            type_utils.tf_dtypes_and_shapes_to_type(
                tf.compat.v1.data.get_output_types(value),
                tf.compat.v1.data.get_output_shapes(value)))
        if not type_utils.are_equivalent_types(dataset_type, type_spec):
            raise TypeError('Expected a value of type {}, found {}.'.format(
                str(type_spec), str(dataset_type)))
        return value

    raise TypeError('Unexpected type {}.'.format(str(type_spec)))
Beispiel #16
0
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

  The serialized graph in `comp.tensorflow.graph_def` is imported inside a
  `tf.compat.v1.wrap_function`. The returned callable feeds its (flattened)
  argument into the graph's parameter tensors. It converts the graph's result
  tensors back into the structure described by the computation's result type.
  Sequence-typed parameters and results are passed through the graph as
  `tf.variant` tensors and converted to/from `tf.data` datasets at the edges.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If given, it must be equivalent to the type serialized in `comp.type`.
    device: An optional device name. If given, graph import and invocation are
      placed under `tf.device(device)`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode. The one-argument form is returned when the
    computation's type is a function type with a parameter. Calling it with
    the wrong number of flattened argument parts raises `RuntimeError`.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation, or if `type_spec` does not match the type of
      `comp`.
  """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_utils.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        # No explicit type requested: trust the type serialized in the proto.
        type_spec = comp_type
    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    # Non-function types are treated as a no-argument computation whose
    # result has that type.
    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    # Names of the graph tensors that the flattened parameter feeds into.
    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    # Names of the graph tensors that make up the flattened result.
    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)

    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        # Traced by `wrap_function`; `args` are the placeholders built from
        # `signature` below, one per flattened parameter element.
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                len(input_tensor_names), len(args)))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            # Make the imported outputs depend on the initializer so that it
            # runs as part of each invocation.
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            # Shared names are uniquified before import (presumably to avoid
            # collisions with other imported graphs' resources — confirm in
            # `graph_merge`).
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def),
                input_map=dict(list(zip(input_tensor_names, args))),
                return_elements=output_tensor_names)

        if device is not None:
            with tf.device(device):
                return _import_fn()
        else:
            return _import_fn()

    # Build the traced signature and the matching per-element converters that
    # turn caller-supplied values into what the graph expects: tensors pass
    # through unchanged, datasets are converted to scalar `tf.variant` tensors.
    signature = []
    param_fns = []
    if param_type is not None:
        for spec in anonymous_tuple.flatten(type_spec.parameter):
            if isinstance(spec, computation_types.TensorType):
                signature.append(tf.TensorSpec(spec.shape, spec.dtype))
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                signature.append(tf.TensorSpec([], tf.variant))
                param_fns.append(tf.data.experimental.to_variant)

    wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

    # Per-element converters for the results: tensors pass through, variant
    # tensors are turned back into datasets with the element structure
    # derived from the sequence's element type.
    result_fns = []
    for spec in anonymous_tuple.flatten(result_type):
        if isinstance(spec, computation_types.TensorType):
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            structure = type_utils.type_to_tf_structure(spec.element)

            # `structure` is bound as a default argument so each converter
            # keeps the structure from its own loop iteration (avoids the
            # late-binding closure pitfall).
            def fn(x, structure=structure):
                return tf.data.experimental.from_variant(x, structure)

            result_fns.append(fn)

    def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
        # Flatten the structured argument, convert each part, and invoke the
        # wrapped graph function.
        param_elements = []
        if arg is not None:
            arg_parts = anonymous_tuple.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            for arg_part, param_fn in zip(arg_parts, param_fns):
                param_elements.append(param_fn(arg_part))
        result_parts = wrapped_fn(*param_elements)

        # There is a tf.wrap_function(...) issue b/144127474 that variables created
        # from tf.import_graph_def(...) inside tf.wrap_function(...) is not
        # destroyed.  So get all the variables from `wrapped_fn` and destroy
        # manually.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        resources = []
        for op in wrapped_fn.graph.get_operations():
            if op.type == 'VarHandleOp':
                resources += op.outputs
        if resources:
            for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
                tf.raw_ops.DestroyResourceOp(resource=resource)

        # Convert each flat result part and repack into the result type's
        # structure.
        result_elements = []
        for result_part, result_fn in zip(result_parts, result_fns):
            result_elements.append(result_fn(result_part))
        return anonymous_tuple.pack_sequence_as(result_type, result_elements)

    # `param_fns` and `wrapped_fn` are bound as lambda defaults so the
    # callable captures the current objects.
    fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
        arg, p, w)

    if device is not None:
        old_fn_to_return = fn_to_return

        # pylint: disable=function-redefined
        def fn_to_return(x):
            with tf.device(device):
                return old_fn_to_return(x)

        # pylint: enable=function-redefined

    # Expose a one-argument callable only when the computation has a
    # parameter; otherwise a zero-argument callable.
    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
def _wrap(fn, parameter_type, wrapper_fn):
    """Wraps a possibly-polymorphic `fn` in `wrapper_fn`.

  If `parameter_type` is `None` and `fn` takes any arguments (even with default
  values), `fn` is inferred to be polymorphic and won't be passed to
  `wrapper_fn` until invocation time (when concrete parameter types are
  available).

  `wrapper_fn` must accept three positional arguments and one defaulted argument
  `name`:

  * `target_fn`, the Python function to be wrapped.

  * `parameter_type`, the optional type of the computation's
    parameter (an instance of `computation_types.Type`).

  * `unpack`, an argument which will be passed on to
    `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping `target_fn`.
    See that function for details.

  * Optional `name`, the name of the function that is being wrapped (only for
    debugging purposes). It is passed on both the polymorphic and the concrete
    path.

  Args:
    fn: The function or defun to wrap as a computation.
    parameter_type: Optional type of any arguments to `fn`.
    wrapper_fn: The Python callable that performs actual wrapping. The object to
      be returned by this function should be an instance of a
      `ConcreteFunction`.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types. The returned function still may accept multiple
    arguments (it has not yet had
    `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

  Raises:
    TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
      constructs something that isn't a ConcreteFunction, or the concrete
      function's parameter type differs from `parameter_type`.
  """
    # Not every callable has `__name__` (e.g. `functools.partial` objects); the
    # name is only used for debugging, so fall back to `None`.
    fn_name = getattr(fn, '__name__', None)
    signature = function_utils.get_signature(fn)
    parameter_type = computation_types.to_type(parameter_type)
    if parameter_type is None and signature.parameters:
        # There is no TFF type specification, and the function/defun declares
        # parameters. Create a polymorphic template that defers wrapping until
        # the concrete parameter type `pt` is known at invocation time.
        polymorphic_fn = function_utils.PolymorphicFunction(
            lambda pt: wrapper_fn(fn, pt, unpack=True, name=fn_name))

        # When applying a decorator, the __doc__ attribute with the documentation
        # in triple-quotes is not automatically transferred from the function on
        # which it was applied to the wrapped object, so we must transfer it here
        # explicitly.
        polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
        return polymorphic_fn

    # Either we have a concrete parameter type, or this is no-arg function.
    # Pass the name here too, consistent with the polymorphic path above.
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None, name=fn_name)
    py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                            'value returned by the wrapper')
    if not type_utils.are_equivalent_types(
            concrete_fn.type_signature.parameter, parameter_type):
        raise TypeError(
            'Expected a concrete function that takes parameter {}, got one '
            'that takes {}.'.format(str(parameter_type),
                                    str(concrete_fn.type_signature.parameter)))
    # When applying a decorator, the __doc__ attribute with the documentation
    # in triple-quotes is not automatically transferred from the function on
    # which it was applied to the wrapped object, so we must transfer it here
    # explicitly.
    concrete_fn.__doc__ = getattr(fn, '__doc__', None)
    return concrete_fn