def _fn_to_return(arg, param_fns, wrapped_fn):
    """Invokes `wrapped_fn` on `arg`, destroys its variables, packs the result.

    `result_fns` and `result_type` are not parameters of this function; they are
    resolved from a scope that is not visible here.

    Args:
      arg: `None` to call `wrapped_fn` with no arguments; otherwise a value that
        `structure.flatten` splits into exactly one part per `param_fns` entry.
      param_fns: Callables applied position-wise to the flattened parts of `arg`
        before they are passed to `wrapped_fn`.
      wrapped_fn: The function to invoke. Its `graph` is scanned for
        `VarHandleOp` ops whose output handles are destroyed after the call.

    Returns:
      `structure.pack_sequence_as(result_type, ...)` applied to the outputs of
      `result_fns` on the parts returned by `wrapped_fn`.

    Raises:
      RuntimeError: If the number of flattened parts of `arg` differs from
        `len(param_fns)`.
    """
    if arg is None:
        param_elements = []
    else:
        arg_parts = structure.flatten(arg)
        if len(arg_parts) != len(param_fns):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                len(param_fns), len(arg_parts)))
        param_elements = [
            convert(part) for part, convert in zip(arg_parts, param_fns)
        ]
    result_parts = wrapped_fn(*param_elements)

    # Workaround for b/144127474: variables created by tf.import_graph_def(...)
    # inside tf.wrap_function(...) are not destroyed automatically, so every
    # VarHandleOp output in `wrapped_fn`'s graph is fetched and destroyed here.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    variable_handles = [
        output
        for op in wrapped_fn.graph.get_operations()
        if op.type == 'VarHandleOp'
        for output in op.outputs
    ]
    if variable_handles:
        fetch_handles = wrapped_fn.prune(feeds={}, fetches=variable_handles)
        for handle in fetch_handles():
            tf.raw_ops.DestroyResourceOp(resource=handle)

    result_elements = [
        finalize(part) for part, finalize in zip(result_parts, result_fns)
    ]
    return structure.pack_sequence_as(result_type, result_elements)
示例#2
0
 def _train_on_one_batch(model, batch):
     """Applies one gradient update of `loss_fn` to `model` using `batch`.

     `loss_fn`, `_apply_update` and `model_type` are not parameters; they are
     resolved from an enclosing scope that is not visible here.

     Args:
       model: The model parameters, flattened via
         `structure.from_container(..., recursive=True)`.
       batch: Passed as the second argument to `loss_fn`.

     Returns:
       The updated parameters, repacked into `model_type` and converted to a
       Python container by `type_conversions.type_to_py_container`.
     """
     params = structure.flatten(
         structure.from_container(model, recursive=True))
     # `jax.grad` is the public entry point; `jax.api` was an internal module
     # that later JAX releases removed. The gradient has the same pytree
     # structure as `model`, so it is converted recursively too. That keeps
     # `grads` aligned one-to-one with `params` even for nested models,
     # instead of `zip` silently truncating.
     grads = structure.flatten(
         structure.from_container(
             jax.grad(loss_fn)(model, batch), recursive=True))
     updated_params = [_apply_update(p, g) for p, g in zip(params, grads)]
     trained_model = structure.pack_sequence_as(model_type, updated_params)
     return type_conversions.type_to_py_container(trained_model, model_type)
示例#3
0
 async def _comp(x):
     """Recursively computes `x`, descending through `structure.Struct`s.

     Args:
       x: A `_Sequence`, an `executor_value_base.ExecutorValue`, or a
         `structure.Struct` whose flattened leaves are such values.

     Returns:
       The awaited `compute()` result for a leaf value, or a Struct with the
       same shape as `x` holding the computed leaves.

     Raises:
       NotImplementedError: If `x` is none of the supported kinds.
     """
     leaf_kinds = (_Sequence, executor_value_base.ExecutorValue)
     if isinstance(x, leaf_kinds):
         return await x.compute()
     if not isinstance(x, structure.Struct):
         raise NotImplementedError(
             'Unable to compute a value of type {}.'.format(
                 py_typecheck.type_string(type(x))))
     # Compute all leaves concurrently, then rebuild the original shape.
     computed_leaves = await asyncio.gather(
         *[_comp(member) for member in structure.flatten(x)])
     return structure.pack_sequence_as(x, computed_leaves)
示例#4
0
 def test_pack_sequence_as_fails_non_struct(self):
   # The plain dict stored under 'b' is expected to make packing fail.
   template = structure.Struct([
       ('a', 10),
       ('b', {'d': 20}),
       ('c', 30),
   ])
   flat_values = [10, 20, 30]
   with self.assertRaisesRegex(TypeError, 'Cannot pack sequence'):
     structure.pack_sequence_as(template, flat_values)
示例#5
0
    def __call__(self, *args, **kwargs):
        """Invokes this callable with the given set of arguments.

        Args:
          *args: Positional arguments; at most one is accepted, and exactly one
            is required when the type signature declares a parameter.
          **kwargs: Keyword arguments; none are accepted.

        Returns:
          The result of the call: a normalized single tensor for a
          `TensorType` result, otherwise a `structure.Struct` packed according
          to the `StructType` result.

        Raises:
          ValueError: if the arguments or the result are incorrect.
        """
        if kwargs:
            raise ValueError('Not expecting keyword arguments.')
        if len(args) > 1:
            raise ValueError(
                'Not expecting more than one positional argument.')

        param_type = self.type_signature.parameter
        if param_type is None:
            if args:
                raise ValueError('Not expecting any arguments.')
            flat_py_args = []
        elif not args:
            raise ValueError('Positional argument missing.')
        elif isinstance(param_type, computation_types.TensorType):
            flat_py_args = [args[0]]
        else:
            py_typecheck.check_type(param_type, computation_types.StructType)
            py_typecheck.check_type(args[0], structure.Struct)
            flat_py_args = structure.flatten(args[0])

        # Permute the flat arguments by `_inverted_parameter_tensor_indexes`
        # (presumably into the executable's parameter order) before executing.
        ordered_args = [
            flat_py_args[i] for i in self._inverted_parameter_tensor_indexes
        ]
        raw_outputs = xla_client.execute_with_python_values(
            self._executable, ordered_args, self._backend)
        py_typecheck.check_type(raw_outputs, list)
        outputs = [raw_outputs[i] for i in self._result_tensor_indexes]

        result_type = self.type_signature.result
        if not isinstance(result_type, computation_types.TensorType):
            py_typecheck.check_type(result_type, computation_types.StructType)
            return structure.pack_sequence_as(result_type, outputs)
        if len(outputs) != 1:
            raise ValueError('Expected one result, found {}.'.format(
                len(outputs)))
        return normalize_tensor_representation(outputs[0], result_type)
示例#6
0
 def test_flatten_and_pack_sequence_as(self):
   nested = structure.Struct.named(
       a=10,
       b=structure.Struct.named(
           x=structure.Struct.named(p=40),
           y=30,
           z=structure.Struct.named(q=50, r=60),
       ),
       c=20,
   )
   flat = structure.flatten(nested)
   # Leaves come out depth-first, in field order.
   self.assertEqual(flat, [10, 40, 30, 50, 60, 20])
   repacked = structure.pack_sequence_as(nested, flat)
   self.assertEqual(
       str(repacked), '<a=10,b=<x=<p=40>,y=30,z=<q=50,r=60>>,c=20>')
示例#7
0
 def test_flatten_and_pack_sequence_as(self):
   inner = structure.Struct([
       ('x', structure.Struct([('p', 40)])),
       ('y', 30),
       ('z', structure.Struct([('q', 50), ('r', 60)])),
   ])
   nested = structure.Struct([('a', 10), ('b', inner), ('c', 20)])
   flat = structure.flatten(nested)
   # Leaves come out depth-first, in field order.
   self.assertEqual(flat, [10, 40, 30, 50, 60, 20])
   repacked = structure.pack_sequence_as(nested, flat)
   self.assertEqual(
       str(repacked), '<a=10,b=<x=<p=40>,y=30,z=<q=50,r=60>>,c=20>')
示例#8
0
    async def create_value(self, value, type_spec=None):
        """Creates a value in this executor.

        The following kinds of `value` are supported as the input:

        * An instance of a TFF computation proto that represents a `data`
          building block, to be handled natively by this executor.

        * Anything that is supported by the target executor (as a pass-through).

        * A nested structure of any of the above.

        Args:
          value: The input for which to create a value.
          type_spec: An optional TFF type of `value`.

        Returns:
          A value embedded in the target executor.
        """
        if isinstance(value, pb.Computation):
            if value.WhichOneof('computation') != 'data':
                return await self._target_executor.create_value(
                    value, type_spec)
            # A `data` block: reconcile its serialized type with `type_spec`,
            # materialize the payload, then embed it in the target executor.
            serialized_type = type_serialization.deserialize_type(value.type)
            if type_spec is None:
                type_spec = serialized_type
            else:
                type_spec.check_equivalent_to(serialized_type)
            materialized = await self._data_backend.materialize(
                value.data, type_spec)
            return await self._target_executor.create_value(
                materialized, type_spec)

        if isinstance(type_spec, computation_types.StructType):
            struct_value = (value if isinstance(value, structure.Struct) else
                            structure.from_container(value))
            leaves_and_types = zip(
                structure.flatten(struct_value), structure.flatten(type_spec))
            # Embed every leaf (recursively) concurrently, then embed the
            # reassembled structure as a struct in the target executor.
            embedded_leaves = await asyncio.gather(*[
                self.create_value(leaf, leaf_type)
                for leaf, leaf_type in leaves_and_types
            ])
            return await self._target_executor.create_struct(
                structure.pack_sequence_as(struct_value, embedded_leaves))

        return await self._target_executor.create_value(value, type_spec)
示例#9
0
    def _fn_to_return(arg, param_fns, wrapped_fn):
        """Calls `wrapped_fn` on `arg` with manual TF resource cleanup around it.

        `result_fns` and `result_type` are not parameters; they are resolved
        from an enclosing scope that is not visible here.

        Args:
          arg: `None` to call `wrapped_fn` with no arguments; otherwise a value
            that `structure.flatten` splits into one part per `param_fns`
            entry.
          param_fns: Callables applied position-wise to the flattened parts of
            `arg` before they are passed to `wrapped_fn`.
          wrapped_fn: The function to invoke. Outputs of its `HashTableV2` ops
            are destroyed before the call, and outputs of its `VarHandleOp` ops
            after it.

        Returns:
          `structure.pack_sequence_as(result_type, ...)` applied to the outputs
          of `result_fns` on the parts returned by `wrapped_fn`.

        Raises:
          RuntimeError: If the number of flattened parts of `arg` differs from
            `len(param_fns)`.
        """

        def _destroy_outputs_of_ops(op_type):
            # Destroys every output handle of the `op_type` ops in
            # `wrapped_fn`'s graph, fetching them through a pruned function.
            handles = [
                output
                for op in wrapped_fn.graph.get_operations()
                if op.type == op_type
                for output in op.outputs
            ]
            if handles:
                for handle in wrapped_fn.prune(feeds={}, fetches=handles)():
                    tf.raw_ops.DestroyResourceOp(resource=handle)

        # TODO(b/166479382): This cleanup-before-invocation pattern is a workaround
        # to square the circle of TF data expecting to lazily reference this
        # resource on iteration, as well as usages that expect to reinitialize a
        # table with new data. Revisit the semantics implied by this cleanup
        # pattern.
        _destroy_outputs_of_ops('HashTableV2')

        if arg is None:
            param_elements = []
        else:
            arg_parts = structure.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            param_elements = [
                convert(part) for part, convert in zip(arg_parts, param_fns)
            ]
        result_parts = wrapped_fn(*param_elements)

        # Workaround for b/144127474: variables created by
        # tf.import_graph_def(...) inside tf.wrap_function(...) are not
        # destroyed automatically, so destroy them manually here.
        # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
        # is fixed.
        _destroy_outputs_of_ops('VarHandleOp')

        result_elements = [
            finalize(part) for part, finalize in zip(result_parts, result_fns)
        ]
        return structure.pack_sequence_as(result_type, result_elements)
示例#10
0
def _call_embedded_tf(*, arg, param_fns, result_fns, result_type, wrapped_fn,
                      destroy_before_invocation, destroy_after_invocation):
    """Function to be run upon EagerTFExecutor.create_call invocation.

    As this function is run completely synchronously, and
    `EagerTFExecutor.create_call` invocations represent the main work of the
    program, this function should be kept as thin a wrapper around delegation
    to the eager TensorFlow runtime as possible.

    Args:
      arg: Argument on which to invoke embedded function, or `None` to invoke
        it with no arguments.
      param_fns: Functions to be applied to elements of `arg` before passing to
        `wrapped_fn`, to prepare these arguments for ingestion by the eager TF
        runtime.
      result_fns: Functions to be applied to results of calling `wrapped_fn` on
        arg before re-embedding as EagerTFExecutor values.
      result_type: TFF Type signature of the result of `wrapped_fn`.
      wrapped_fn: Result of `tf.compat.v1.wrap_function` to run in the eager TF
        runtime.
      destroy_before_invocation: eager TF runtime resources which should be
        destroyed before invoking `wrapped_fn`. Examples might include hashtable
        resources.
      destroy_after_invocation: eager TF runtime resources which should be
        destroyed after invoking `wrapped_fn`. Examples might include resource
        variables.

    Returns:
      A `structure.Struct` representing the result of invoking `wrapped_fn` on
      `arg`.

    Raises:
      RuntimeError: If `arg` and `param_fns` have different numbers of elements.
    """
    span_scope = 'EagerTFExecutor.create_call'

    # TODO(b/166479382): This cleanup-before-invocation pattern is a workaround
    # to square the circle of TF data expecting to lazily reference this
    # resource on iteration, as well as usages that expect to reinitialize a
    # table with new data. Revisit the semantics implied by this cleanup
    # pattern.
    with tracing.span(span_scope, 'resource_cleanup_before_invocation',
                      span=True):
        for handle in destroy_before_invocation:
            tf.raw_ops.DestroyResourceOp(resource=handle)

    if arg is None:
        param_elements = []
    else:
        with tracing.span(span_scope, 'arg_ingestion', span=True):
            arg_parts = structure.flatten(arg)
            if len(arg_parts) != len(param_fns):
                raise RuntimeError('Expected {} arguments, found {}.'.format(
                    len(param_fns), len(arg_parts)))
            param_elements = [
                convert(part) for part, convert in zip(arg_parts, param_fns)
            ]
    result_parts = wrapped_fn(*param_elements)

    # Workaround for b/144127474: variables created by tf.import_graph_def(...)
    # inside tf.wrap_function(...) are not destroyed automatically, so the
    # caller-supplied handles are destroyed manually here.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    with tracing.span(span_scope, 'resource_cleanup_after_invocation',
                      span=True):
        for handle in destroy_after_invocation:
            tf.raw_ops.DestroyResourceOp(resource=handle)

    with tracing.span(span_scope, 'result_packing', span=True):
        result_elements = [
            finalize(part) for part, finalize in zip(result_parts, result_fns)
        ]
        return structure.pack_sequence_as(result_type, result_elements)
示例#11
0
    flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
    tensor_indexes = list(np.argsort([x.tensor_index for x in flattened_obj]))

    context = jax_computation_context.JaxComputationContext()
    with context_stack.install(context):
        tracer_callable = jax.xla_computation(traced_fn,
                                              tuple_args=True,
                                              return_shape=True)
        compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(
            returned_shape)
    else:
        returned_type_spec = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    computation_type = computation_types.FunctionType(parameter_type,
                                                      returned_type_spec)
    return xla_serialization.create_xla_tff_computation(
        compiled_xla, tensor_indexes, computation_type)


# Registers TFF's `structure.Struct` as a JAX pytree node so that JAX's
# tree-traversal utilities can walk through it. JAX's flatten function returns
# `(children, aux_data)`, and JAX calls the unflatten function as
# `unflatten_func(aux_data, children)`. Here the children are the Struct's
# flattened leaves. The aux data is the Struct itself, which serves as the
# template for `structure.pack_sequence_as`.
jax.tree_util.register_pytree_node(
    structure.Struct,
    lambda struct_value: (structure.flatten(struct_value), struct_value),
    lambda template, children: structure.pack_sequence_as(
        template, list(children)))
示例#12
0
def fetch_value_in_session(sess, value):
  """Fetches `value` in the session `sess`.

  Args:
    sess: The `tf.compat.v1.Session` in which to perform the fetch.
    value: A Python object of a form analogous to that constructed by the
      function `assemble_result_from_graph`, made of tensors and anonymous
      tuples, or a `tf.data.Dataset`.

  Returns:
    A Python object with structure similar to `value`, but with tensors
    replaced with their values and data sets replaced with lists of their
    elements. All tensors are fetched together in a single `sess.run()` call.
    Each data set is drained with one `sess.run()` per element. An empty data
    set nested inside a structure is replaced by a one-element list holding a
    dummy element, so the result still carries the element type information.
    If `value` itself is a string tensor, its fetched `bytes` are decoded to
    `str` (UTF-8).

  Raises:
    ValueError: If `value`, or a flattened element of it, is neither a
      `tf.data.Dataset` nor a tensor.
  """
  py_typecheck.check_type(sess, tf.compat.v1.Session)
  # TODO(b/113123634): Investigate handling `list`s and `tuple`s of
  # `tf.data.Dataset`s and what the API would look like to support this.
  if isinstance(value, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
    with sess.graph.as_default():
      iterator = tf.compat.v1.data.make_one_shot_iterator(value)
      next_element = iterator.get_next()
    elements = []
    while True:
      try:
        elements.append(sess.run(next_element))
      except tf.errors.OutOfRangeError:
        break
    return elements

  flattened_value = structure.flatten(value)
  dataset_results = {}
  flat_tensors = []
  for idx, v in enumerate(flattened_value):
    if isinstance(v, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
      dataset_tensors = fetch_value_in_session(sess, v)
      if not dataset_tensors:
        # An empty list has been returned; we must pack the shape information
        # back in or the result won't typecheck.
        element_structure = v.element_spec
        dummy_elem = make_dummy_element_for_type_spec(element_structure)
        dataset_tensors = [dummy_elem]
      dataset_results[idx] = dataset_tensors
    elif tf.is_tensor(v):
      flat_tensors.append(v)
    else:
      raise ValueError('Unsupported value type {}.'.format(v))
  # `flat_tensors` holds only tensors (see the loop above), so it is empty
  # exactly when there is nothing to fetch; skip `sess.run` in that case.
  if flat_tensors:
    flat_computed_tensors = sess.run(flat_tensors)
  else:
    flat_computed_tensors = flat_tensors
  flattened_results = _interleave_dataset_results_and_tensors(
      dataset_results, flat_computed_tensors)

  def _to_unicode(v):
    if isinstance(v, bytes):
      return v.decode('utf-8')
    return v

  # NOTE(review): only a single top-level string tensor is decoded; string
  # tensors nested inside a structure are returned as fetched.
  if tf.is_tensor(value) and value.dtype == tf.string:
    flattened_results = [_to_unicode(result) for result in flattened_results]
  return structure.pack_sequence_as(value, flattened_results)
示例#13
0
    async def create_value(self, value, type_spec=None):
        """Creates a value in this executor.

        The following kinds of `value` are supported as the input:

        * An instance of TFF computation proto containing one of the supported
          sequence intrinsics as its sole body.

        * An instance of eager TF dataset.

        * Anything that is supported by the target executor (as a pass-through).

        * A nested structure of any of the above.

        Args:
          value: The input for which to create a value.
          type_spec: An optional TFF type (required if `value` is not an
            instance of `typed_object.TypedObject`, otherwise it can be `None`).

        Returns:
          An instance of `SequenceExecutorValue` that represents the embedded
          value.
        """
        if type_spec is not None:
            type_spec = computation_types.to_type(type_spec)
        else:
            py_typecheck.check_type(value, typed_object.TypedObject)
            type_spec = value.type_signature

        if isinstance(type_spec, computation_types.SequenceType):
            return SequenceExecutorValue(
                _SequenceFromPayload(value, type_spec), type_spec)

        if isinstance(value, pb.Computation):
            serialized_type = type_serialization.deserialize_type(value.type)
            serialized_type.check_equivalent_to(type_spec)
            # NOTE: If not a supported type of intrinsic, we let it fall through and
            # be handled by embedding in the target executor (below).
            if value.WhichOneof('computation') == 'intrinsic':
                uri = value.intrinsic.uri
                intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(uri)
                if intrinsic_def is None:
                    raise ValueError(
                        'Encountered an unrecognized intrinsic "{}".'.format(
                            uri))
                sequence_op_type = (
                    SequenceExecutor._SUPPORTED_INTRINSIC_TO_SEQUENCE_OP.get(
                        intrinsic_def.uri))
                if sequence_op_type is not None:
                    type_analysis.check_concrete_instance_of(
                        type_spec, intrinsic_def.type_signature)
                    return SequenceExecutorValue(
                        sequence_op_type(type_spec), type_spec)

        if isinstance(type_spec, computation_types.StructType):
            struct_value = (value if isinstance(value, structure.Struct) else
                            structure.from_container(value))
            leaves_and_types = zip(
                structure.flatten(struct_value), structure.flatten(type_spec))
            embedded_leaves = await asyncio.gather(*[
                self.create_value(leaf, leaf_type)
                for leaf, leaf_type in leaves_and_types
            ])
            return await self.create_struct(
                structure.pack_sequence_as(struct_value, embedded_leaves))

        embedded_in_target = await self._target_executor.create_value(
            value, type_spec)
        return SequenceExecutorValue(embedded_in_target, type_spec)
示例#14
0
 def test_pack_sequence_as_fails_non_struct(self):
   # The plain dict stored under 'b' is expected to make packing fail.
   template = structure.Struct.named(a=10, b={'d': 20}, c=30)
   with self.assertRaisesRegex(TypeError, 'Cannot pack sequence'):
     structure.pack_sequence_as(template, [10, 20, 30])