Example #1
 def test_returns(self):
     t1 = computation_types.TensorType(tf.int32, [None])
     t2 = computation_types.TensorType(tf.int32, [10])
     t3 = computation_types.TensorType(tf.int32, [10])
     self.assertTrue(type_analysis.are_equivalent_types(t1, t1))
     self.assertTrue(type_analysis.are_equivalent_types(t2, t3))
     self.assertTrue(type_analysis.are_equivalent_types(t3, t2))
     self.assertFalse(type_analysis.are_equivalent_types(t1, t2))
     self.assertFalse(type_analysis.are_equivalent_types(t2, t1))
Example #2
 def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                      expected_kwargs):
     args, kwargs = function_utils.unpack_args_from_tuple(tuple_with_args)
     self.assertEqual(len(args), len(expected_args))
     for idx, arg in enumerate(args):
         self.assertTrue(
             type_analysis.are_equivalent_types(
                 arg, computation_types.to_type(expected_args[idx])))
     self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys()))
     for k, v in kwargs.items():
         self.assertTrue(
             type_analysis.are_equivalent_types(
                 computation_types.to_type(v), expected_kwargs[k]))
Example #3
def reconcile_value_type_with_type_spec(value_type, type_spec):
    """Reconciles a pair of types.

  Args:
    value_type: An instance of `tff.Type` or something convertible to it. Must
      not be `None`.
    type_spec: An instance of `tff.Type`, something convertible to it, or
      `None`.

  Returns:
    Either `value_type` if `type_spec` is `None`, or `type_spec` if `type_spec`
    is not `None` and equivalent to `value_type`.

  Raises:
    TypeError: If arguments are of incompatible types.
  """
    value_type = computation_types.to_type(value_type)
    py_typecheck.check_type(value_type, computation_types.Type)
    if type_spec is None:
        return value_type
    else:
        type_spec = computation_types.to_type(type_spec)
        if type_analysis.are_equivalent_types(value_type, type_spec):
            return type_spec
        else:
            raise TypeError('Expected a value of type {}, found {}.'.format(
                type_spec, value_type))
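The reconciliation contract above is easy to see in a small usage sketch. This
is illustrative only: it assumes the `tf`, `computation_types`, and
`type_analysis` imports from the surrounding module, and `int_type` and
`float_type` are hypothetical names.

int_type = computation_types.to_type(tf.int32)
float_type = computation_types.to_type(tf.float32)
# With no spec, the value's own type comes back.
assert type_analysis.are_equivalent_types(
    reconcile_value_type_with_type_spec(int_type, None), int_type)
# With an equivalent spec (anything `to_type` accepts), the spec is returned.
assert type_analysis.are_equivalent_types(
    reconcile_value_type_with_type_spec(int_type, tf.int32), int_type)
# A mismatching spec raises TypeError.
try:
    reconcile_value_type_with_type_spec(int_type, float_type)
except TypeError:
    pass  # Expected: 'Expected a value of type float32, found int32.'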
Example #4
 async def create_value(self, value, type_spec=None):
     type_spec = computation_types.to_type(type_spec)
     if isinstance(value, computation_impl.ComputationImpl):
         return await self.create_value(
             computation_impl.ComputationImpl.get_proto(value),
             type_utils.reconcile_value_with_type_spec(value, type_spec))
     py_typecheck.check_type(type_spec, computation_types.Type)
     hashable_key = _get_hashable_key(value, type_spec)
     try:
         identifier = self._cache.get(hashable_key)
     except TypeError as err:
         raise RuntimeError(
             'Failed to perform a hash table lookup with a value of Python '
             'type {} and TFF type {}, and payload {}: {}'.format(
                 py_typecheck.type_string(type(value)), type_spec, value,
                 err))
     if isinstance(identifier, CachedValueIdentifier):
         cached_value = self._cache.get(identifier)
         # It may be that the same payload appeared with a mismatching type
         # spec, which may be a legitimate use case if (as it happens) the
         # payload alone does not uniquely determine the type, so we simply opt
         # not to reuse the cached value and fall back on the regular behavior.
         if (cached_value is not None and type_spec is not None
                 and not type_analysis.are_equivalent_types(
                     cached_value.type_signature, type_spec)):
             identifier = None
     else:
         identifier = None
     if identifier is None:
         self._num_values_created = self._num_values_created + 1
         identifier = CachedValueIdentifier(str(self._num_values_created))
         self._cache[hashable_key] = identifier
         target_future = asyncio.ensure_future(
             self._target_executor.create_value(value, type_spec))
         cached_value = None
     if cached_value is None:
         cached_value = CachedValue(identifier, hashable_key, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     try:
         await cached_value.target_future
     except Exception as e:
         # Invalidate the entire cache if the inner executor had an exception.
         # TODO(b/145514490): This is a bit heavy-handed; there may be cases
         # where only the current cache item needs to be invalidated. However,
         # this currently only occurs when an inner RemoteExecutor has its
         # backend go down.
         self._cache = {}
         raise e
     # No type check is necessary here; we have either already verified
     # equivalence via `type_analysis.are_equivalent_types`, or constructed
     # `target_value` explicitly with `type_spec`.
     return cached_value
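The caching logic above keeps two mappings in a single dict: a hashable
(value, type) key maps to a `CachedValueIdentifier`, and that identifier maps
to the `CachedValue` holding the target future. A stripped-down sketch of the
same two-level indirection (toy names only, not the real executor classes):

class TwoLevelCache:
    """Toy sketch of the key -> identifier -> value pattern used above."""

    def __init__(self):
        self._cache = {}
        self._num_created = 0

    def get_or_create(self, key, build_fn):
        identifier = self._cache.get(key)
        if identifier is None:
            # First sighting of this key: mint an identifier, build the value.
            self._num_created += 1
            identifier = 'id/{}'.format(self._num_created)
            self._cache[key] = identifier
            self._cache[identifier] = build_fn()
        return self._cache[identifier]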
Example #5
  def _serialize_deserialize_roundtrip_test(self, type_list):
    """Performs roundtrip serialization/deserialization of computation_types.

    Args:
      type_list: A list of instances of computation_types.Type or things
        convertible to it.
    """
    for t in type_list:
      t1 = computation_types.to_type(t)
      p1 = type_serialization.serialize_type(t1)
      t2 = type_serialization.deserialize_type(p1)
      p2 = type_serialization.serialize_type(t2)
      self.assertEqual(repr(t1), repr(t2))
      self.assertEqual(repr(p1), repr(p2))
      self.assertTrue(type_analysis.are_equivalent_types(t1, t2))
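Written out for a single type, the roundtrip exercised above looks like this
(a sketch reusing the same TFF modules; `tf.int32` is an arbitrary choice):

t1 = computation_types.to_type(tf.int32)
p1 = type_serialization.serialize_type(t1)    # tff.Type -> serialized proto
t2 = type_serialization.deserialize_type(p1)  # proto -> tff.Type
p2 = type_serialization.serialize_type(t2)
# Serialization must be lossless up to type equivalence.
assert type_analysis.are_equivalent_types(t1, t2)
assert repr(p1) == repr(p2)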
Example #6
 def __add__(self, other):
     other = to_value(other, None, self._context_stack)
     if not type_analysis.are_equivalent_types(self.type_signature,
                                               other.type_signature):
         raise TypeError('Cannot add {} and {}.'.format(
             self.type_signature, other.type_signature))
     return ValueImpl(
         building_blocks.Call(
             building_blocks.Intrinsic(
                 intrinsic_defs.GENERIC_PLUS.uri,
                 computation_types.FunctionType(
                     [self.type_signature, self.type_signature],
                     self.type_signature)),
             ValueImpl.get_comp(
                 to_value([self, other], None, self._context_stack))),
         self._context_stack)
Example #7
 def __eq__(self, other):
   """Base class equality checks names and values equal."""
   # TODO(b/130890785): Delegate value-checking to
   # `building_blocks.ComputationBuildingBlock`.
   if self is other:
     return True
   if not isinstance(other, BoundVariableTracker):
     return NotImplemented
   if self.name != other.name:
     return False
   if (isinstance(self.value, building_blocks.ComputationBuildingBlock) and
       isinstance(other.value, building_blocks.ComputationBuildingBlock)):
     return (self.value.compact_representation() ==
             other.value.compact_representation() and
             type_analysis.are_equivalent_types(self.value.type_signature,
                                                other.value.type_signature))
   return self.value is other.value
Example #8
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
    device: An optional device name.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_analysis.are_equivalent_types(type_spec, comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          type_spec, comp_type))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          len(input_tensor_names), len(args)))
    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    if init_op:
      graph_def = tensorflow_utils.add_control_deps_for_init_op(
          graph_def, init_op)

    def _import_fn():
      return tf.import_graph_def(
          graph_merge.uniquify_shared_names(graph_def),
          input_map=dict(list(zip(input_tensor_names, args))),
          return_elements=output_tensor_names)

    if device is not None:
      with tf.device(device):
        return _import_fn()
    else:
      return _import_fn()

  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_conversions.type_to_tf_structure(spec.element)

      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            len(param_fns), len(arg_parts)))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)

    # There is a tf.wrap_function(...) issue, b/144127474, where variables
    # created from tf.import_graph_def(...) inside tf.wrap_function(...) are
    # not destroyed. So we get all the variables from `wrapped_fn` and destroy
    # them manually.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    resources = []
    for op in wrapped_fn.graph.get_operations():
      if op.type == 'VarHandleOp':
        resources += op.outputs
    if resources:
      for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
        tf.raw_ops.DestroyResourceOp(resource=resource)

    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(arg, p, w)

  if device is not None:
    old_fn_to_return = fn_to_return

    # pylint: disable=function-redefined
    def fn_to_return(x):
      with tf.device(device):
        return old_fn_to_return(x)

    # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
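A hedged sketch of how the returned callable might be invoked; `comp` and
`no_arg_comp` are assumed to be pre-built `pb.Computation` instances and are
not constructed here, so the calls are left commented out:

# fn = embed_tensorflow_computation(comp)       # function-typed computation
# result = fn(tf.constant(10))                  # pass eager tensors as the arg
# thunk = embed_tensorflow_computation(no_arg_comp)
# result = thunk()                              # zero-argument case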
Example #9
def import_tensorflow_computation(comp, type_spec=None, name='fn'):
    """Converts a TF computation into an MLIR module that can be compiled by IREE.

  WARNING: This helper function is under construction, and most capabilities are
  not implemented at this stage:

  * The parameter and result of `comp` can only be a single tensor. Named
    tuples, sequences, or functional types are not currently supported.

  * Only TensorFlow code can be imported.

  TODO(b/153499219): Add support for named tuples, sequences, and functions.

  Args:
    comp: An instance of a `pb.Computation` with TensorFlow code to import.
    type_spec: An optional `tff.Type` instance.
    name: An optional `str` name of the (single) function in the IREE module.

  Returns:
    An instance of IREE compiler's `CompilerModule` class with the imported
    function present.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation.
  """
    py_typecheck.check_type(comp, pb.Computation)
    comp_type = type_serialization.deserialize_type(comp.type)
    if type_spec is not None:
        py_typecheck.check_type(type_spec, computation_types.Type)
        if not type_analysis.are_equivalent_types(type_spec, comp_type):
            raise TypeError(
                'Expected a computation of type {}, got {}.'.format(
                    type_spec, comp_type))
    else:
        type_spec = comp_type

    if isinstance(type_spec, computation_types.FunctionType):
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    # TODO(b/153499219): Replace this with a recursive check of the signature
    # after relaxing the type restrictions and introducing nested structures.
    py_typecheck.check_type(result_type, computation_types.TensorType)
    if param_type is not None:
        py_typecheck.check_type(param_type, computation_types.TensorType)

    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)
    if param_type is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    return_elements = input_tensor_names + output_tensor_names
    if init_op:
        graph_def = tensorflow_utils.add_control_deps_for_init_op(
            graph_def, init_op)
        return_elements.append(init_op)

    with tf.Graph().as_default() as graph:
        # TODO(b/153499219): See if we can reintroduce uniquify_shared_names().
        # Right now it causes loader breakage, and it is unclear whether it is
        # still necessary.
        import_results = tf.import_graph_def(graph_def,
                                             input_map={},
                                             return_elements=return_elements,
                                             name='')

    if init_op:
        initializer = import_results[-1]
        import_results.pop()
    else:
        initializer = None

    inputs = import_results[0:len(input_tensor_names)]
    outputs = import_results[len(input_tensor_names):]

    with graph.as_default():
        # TODO(b/153499219): Find a way to reflect the nested parameter and result
        # structure here after relaxing the restrictions.
        if inputs:
            assert len(inputs) < 2
            input_dict = {
                'parameter':
                tf.compat.v1.saved_model.utils.build_tensor_info(inputs[0])
            }
        else:
            input_dict = {}
        assert len(outputs) == 1
        output_dict = {
            'result':
            tf.compat.v1.saved_model.utils.build_tensor_info(outputs[0])
        }
        sig_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
            inputs=input_dict, outputs=output_dict, method_name=name)
        with tempfile.TemporaryDirectory() as model_dir:
            builder = tf.compat.v1.saved_model.Builder(model_dir)
            with tf.compat.v1.Session(graph=graph) as sess:
                builder.add_meta_graph_and_variables(
                    sess, ['unused'],
                    signature_def_map={name: sig_def},
                    legacy_init_op=initializer,
                    strip_default_attrs=True)
                builder.save()
            return iree_compiler.tf_load_signature_def_saved_model(
                model_dir, tags=set(['unused']), exported_names=[name])
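Given the single-tensor restriction spelled out in the docstring, only
signatures of roughly this shape are currently accepted (a sketch; `comp` is
assumed to exist and is not constructed here):

accepted = computation_types.FunctionType(
    computation_types.TensorType(tf.float32),  # single-tensor parameter
    computation_types.TensorType(tf.float32))  # single-tensor result
# module = import_tensorflow_computation(comp, type_spec=accepted, name='fn')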
Example #10
def _wrap(fn, parameter_type, wrapper_fn):
    """Wraps a possibly-polymorphic `fn` in `wrapper_fn`.

  If `parameter_type` is `None` and `fn` takes any arguments (even with default
  values), `fn` is inferred to be polymorphic and won't be passed to
  `wrapper_fn` until invocation time (when concrete parameter types are
  available).

  `wrapper_fn` must accept three positional arguments and one defaulted argument
  `name`:

  * `target_fn`, the Python function to be wrapped.

  * `parameter_type`, the optional type of the computation's
    parameter (an instance of `computation_types.Type`).

  * `unpack`, an argument which will be passed on to
    `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping `target_fn`.
    See that function for details.

  * Optional `name`, the name of the function that is being wrapped (only for
    debugging purposes).

  Args:
    fn: The function or defun to wrap as a computation.
    parameter_type: Optional type of any arguments to `fn`.
    wrapper_fn: The Python callable that performs actual wrapping. The object to
      be returned by this function should be an instance of a
      `ConcreteFunction`.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types. The returned function may still accept multiple arguments
    (it has not yet had `function_utils.wrap_as_zero_or_one_arg_callable`
    applied to it).

  Raises:
    TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
      constructs something that isn't a ConcreteFunction.
  """
    try:
        fn_name = fn.__name__
    except AttributeError:
        fn_name = None
    signature = function_utils.get_signature(fn)
    parameter_type = computation_types.to_type(parameter_type)
    if parameter_type is None and signature.parameters:
        # There is no TFF type specification, and the function/defun declares
        # parameters. Create a polymorphic template.
        def _wrap_polymorphic(parameter_type: computation_types.Type,
                              unpack: Optional[bool]):
            return wrapper_fn(fn, parameter_type, unpack=unpack, name=fn_name)

        polymorphic_fn = function_utils.PolymorphicFunction(_wrap_polymorphic)

        # When applying a decorator, the __doc__ attribute with the documentation
        # in triple-quotes is not automatically transferred from the function on
        # which it was applied to the wrapped object, so we must transfer it here
        # explicitly.
        polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
        return polymorphic_fn

    # Either we have a concrete parameter type, or this is a no-arg function.
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
    py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                            'value returned by the wrapper')
    if not type_analysis.are_equivalent_types(
            concrete_fn.type_signature.parameter, parameter_type):
        raise TypeError(
            'Expected a concrete function that takes parameter {}, got one '
            'that takes {}.'.format(str(parameter_type),
                                    str(concrete_fn.type_signature.parameter)))
    # When applying a decorator, the __doc__ attribute with the documentation
    # in triple-quotes is not automatically transferred from the function on
    # which it was applied to the wrapped object, so we must transfer it here
    # explicitly.
    concrete_fn.__doc__ = getattr(fn, '__doc__', None)
    return concrete_fn
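For clarity, a `wrapper_fn` conforming to the contract described in the
docstring would have roughly this shape. This is a hypothetical skeleton under
the stated contract (the argument order of
`function_utils.wrap_as_zero_or_one_arg_callable` is assumed from the
docstring); a real wrapper traces `fn` and returns a
`function_utils.ConcreteFunction`:

def example_wrapper_fn(target_fn, parameter_type, unpack, name=None):
    """Hypothetical skeleton of a conforming `wrapper_fn`."""
    del name  # Only used for debugging in real wrappers.
    fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    # A real wrapper would now trace/serialize `fn` and build a
    # `function_utils.ConcreteFunction` from the result.
    raise NotImplementedError('Sketch only, not a working wrapper.')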