Esempio n. 1
0
    def test_fetch_value_with_dataset_and_tensor(self):
        """Fetches a [tensor, dataset, tensor] result and checks each slot."""

        def return_dataset_and_tensor():
            return [tf.constant(0), tf.data.Dataset.range(5), tf.constant(5)]

        ctx_stack = context_stack_impl.context_stack
        serialize = tensorflow_serialization.serialize_py_fn_as_tf_computation
        # Only the first element of the serialization result is used here.
        comp_pb = serialize(return_dataset_and_tensor, None, ctx_stack)[0]
        computation = computation_impl.ComputationImpl(comp_pb, ctx_stack)

        fetched = computation()
        # The dataset slot is compared against a plain Python list.
        expected_by_index = ((0, 0), (1, list(range(5))), (2, 5))
        for index, expected in expected_by_index:
            self.assertEqual(fetched[index], expected)
Esempio n. 2
0
  def test_fetch_value_with_empty_dataset_and_tensors(self):
    """Fetches a float tensor together with a dataset truncated by `take(0)`."""

    def return_dataset():
      source = tf.data.Dataset.from_tensor_slices([[1, 1], [1, 1]])
      return [tf.constant([0., 0.]), source.batch(5).take(0)]

    ctx_stack = context_stack_impl.context_stack
    serialize = tensorflow_serialization.serialize_py_fn_as_tf_computation
    # Only the first element of the serialization result is used here.
    comp_pb = serialize(return_dataset, None, ctx_stack)[0]
    computation = computation_impl.ComputationImpl(comp_pb, ctx_stack)

    fetched = computation()
    for position in range(2):
      self.assertEqual(fetched[0][position], 0.)
    # The dataset slot is checked through its string rendering, which must
    # match that of an empty int32 array of shape [0, 2].
    expected_repr = str(np.zeros([0, 2], dtype=np.int32))
    self.assertEqual(str(fetched[1][0]), expected_repr)
Esempio n. 3
0
def _federated_computation_wrapper_fn(parameter_type, name):
  """Wrapper function to plug orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.

  This is a generator that yields twice:
    1. the first value produced by the federated computation serializer; the
       caller is expected to `send` a result back in response;
    2. the `computation_impl.ComputationImpl` built from what the serializer
       returns when that result is forwarded to it.

  Args:
    parameter_type: The type of the computation parameter, or `None` if the
      computation takes no parameter.
    name: Passed to the serializer as `suggested_name`.

  Yields:
    First the serializer's initial value, then the constructed
    `computation_impl.ComputationImpl`.
  """
  ctx_stack = context_stack_impl.context_stack
  # Compare against `None` explicitly rather than relying on truthiness: only
  # the absence of a parameter should drop the parameter name, and a type
  # object is free to define its own truth value.
  parameter_name = None if parameter_type is None else 'arg'
  fn_generator = federated_computation_utils.federated_computation_serializer(
      parameter_name,
      parameter_type,
      ctx_stack,
      suggested_name=name)
  result = yield next(fn_generator)
  # Forward the caller's result to the serializer to obtain the final lambda.
  target_lambda, extra_type_spec = fn_generator.send(result)
  yield computation_impl.ComputationImpl(target_lambda.proto, ctx_stack,
                                         extra_type_spec)
Esempio n. 4
0
    def test_fetch_value_with_empty_dataset_and_tensors(self):
        """Fetches a float tensor together with a dataset that has no elements."""

        def return_dataset():
            source = tf.data.Dataset.from_tensor_slices([[1, 1], [1, 1]])
            return [tf.constant([0., 0.]), source.batch(5).take(0)]

        ctx_stack = context_stack_impl.context_stack
        serialize = tensorflow_serialization.serialize_py_fn_as_tf_computation
        # Only the first element of the serialization result is used here.
        comp_pb = serialize(return_dataset, None, ctx_stack)[0]
        computation = computation_impl.ComputationImpl(comp_pb, ctx_stack)

        fetched = computation()
        self.assertAllEqual(fetched[0], [0., 0.])

        expected_spec = tf.TensorSpec(shape=(None, 2), dtype=tf.int32)
        self.assertEqual(fetched[1].element_spec, expected_spec)

        # Iterating must stop immediately: the dataset was truncated by
        # `take(0)`.
        with self.assertRaises(StopIteration):
            next(iter(fetched[1]))
def _federated_computation_wrapper_fn(target_fn,
                                      parameter_type,
                                      unpack,
                                      name=None):
    """Wrapper function to plug orchestration logic in to TFF framework.

    Args:
      target_fn: The Python function containing the orchestration logic.
      parameter_type: The type of the computation parameter, or `None` if the
        computation takes no parameter.
      unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
      name: Optional name, passed on as `suggested_name`.

    Returns:
      A `computation_impl.ComputationImpl` built from the proto of the
      building block constructed for `target_fn`.
    """
    target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    ctx_stack = context_stack_impl.context_stack
    # Compare against `None` explicitly rather than relying on truthiness:
    # only the absence of a parameter should drop the parameter name, and a
    # type object is free to define its own truth value.
    parameter_name = None if parameter_type is None else 'arg'
    target_lambda = (
        federated_computation_utils.zero_or_one_arg_fn_to_building_block(
            target_fn,
            parameter_name,
            parameter_type,
            ctx_stack,
            suggested_name=name))
    return computation_impl.ComputationImpl(target_lambda.proto, ctx_stack)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
    """Builds a TFF computation from a Python function with TensorFlow logic.

    Args:
      target_fn: The Python function to serialize.
      parameter_type: The type of the computation parameter.
      unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
      name: Unused.

    Returns:
      A `computation_impl.ComputationImpl` wrapping the serialized computation.

    Raises:
      TypeError: If `parameter_type` is not a TensorFlow-compatible type.
    """
    del name  # Unused.
    wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    is_compatible = type_utils.is_tensorflow_compatible_type(parameter_type)
    if not is_compatible:
        raise TypeError(
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `NamedTupleType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type))
    stack = context_stack_impl.context_stack
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        wrapped_fn, parameter_type, stack)
    comp_pb, extra_type_spec = serialized
    return computation_impl.ComputationImpl(comp_pb, stack, extra_type_spec)
Esempio n. 7
0
  def test_fetch_value_with_datasets_nested_at_second_level(self):
    """Fetches a tensor plus a nested pair of datasets and checks each one."""

    def return_two_datasets():
      return [
          tf.constant(0), [tf.data.Dataset.range(5),
                           tf.data.Dataset.range(5)]
      ]

    ctx_stack = context_stack_impl.context_stack
    serialize = tensorflow_serialization.serialize_py_fn_as_tf_computation
    # Only the first element of the serialization result is used here.
    comp_pb = serialize(return_two_datasets, None, ctx_stack)[0]
    computation = computation_impl.ComputationImpl(comp_pb, ctx_stack)

    fetched = computation()
    self.assertEqual(fetched[0], 0)
    # Both nested datasets are compared against a plain Python list.
    for nested_index in (0, 1):
      self.assertEqual(fetched[1][nested_index], list(range(5)))
Esempio n. 8
0
    def compile(self, computation_to_compile):
        """Compiles `computation_to_compile`.

        The computation's proto is converted into a building block, run
        through a fixed sequence of transformations, and the transformed
        proto is wrapped in a new `computation_impl.ComputationImpl`.

        Args:
          computation_to_compile: An instance of `computation_base.Computation`
            to compile.

        Returns:
          An instance of `computation_base.Computation` that represents the
          result.
        """
        py_typecheck.check_type(computation_to_compile,
                                computation_base.Computation)
        computation_proto = computation_impl.ComputationImpl.get_proto(
            computation_to_compile)
        py_typecheck.check_type(computation_proto, pb.Computation)
        comp = building_blocks.ComputationBuildingBlock.from_proto(
            computation_proto)

        # TODO(b/113123410): Add a compiler options argument that characterizes the
        # desired form of the output. To be driven by what the specific backend the
        # pipeline is targeting is able to understand. Pending a more fleshed out
        # design of the backend API.

        # Replace intrinsics with their bodies, for now manually in a fixed order.
        # TODO(b/113123410): Replace this with a more automated implementation that
        # does not rely on manual maintenance.
        comp, _ = value_transformations.replace_all_intrinsics_with_bodies(
            comp, self._context_stack)

        # Replaces called lambdas with LET constructs with a single local symbol.
        comp, _ = transformations.replace_called_lambda_with_block(comp)

        # Removes mapped or applied identities.
        comp, _ = transformations.remove_mapped_or_applied_identity(comp)

        # Remove duplicate computations. This is important! Otherwise the
        # semantics of non-deterministic computations (e.g. a
        # `tff.tf_computation` depending on `tf.random`) will give unexpected
        # behavior. Additionally, this may reduce the amount of calls into TF
        # for some ASTs.
        comp, _ = transformations.uniquify_reference_names(comp)
        comp, _ = transformations.extract_computations(comp)
        comp, _ = transformations.remove_duplicate_computations(comp)

        return computation_impl.ComputationImpl(comp.proto,
                                                self._context_stack)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Builds a TFF computation from a Python function with TensorFlow logic.

  This function is passed through `computation_wrapper.ComputationWrapper`;
  its arguments are documented inside the definition of that class.

  Args:
    target_fn: The Python function to serialize.
    parameter_type: The type of the computation parameter.
    unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
    name: Unused.

  Returns:
    A `computation_impl.ComputationImpl` wrapping the serialized computation.

  Raises:
    TypeError: If `parameter_type` is not a TensorFlow-compatible type.
  """
  del name  # Unused.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  is_compatible = type_analysis.is_tensorflow_compatible_type(parameter_type)
  if not is_compatible:
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      wrapped_fn, parameter_type, stack)
  comp_pb, extra_type_spec = serialized
  return computation_impl.ComputationImpl(comp_pb, stack, extra_type_spec)
Esempio n. 10
0
def _tf_wrapper_fn(parameter_type, name):
  """Generator that plugs TensorFlow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`;
  its arguments are documented inside the definition of that class.

  Because this is a generator, the compatibility check on `parameter_type`
  runs when the generator is first advanced, not when it is created.

  Args:
    parameter_type: The type of the computation parameter.
    name: Unused.

  Yields:
    First the initial value produced by the TensorFlow serializer (the caller
    is expected to `send` a result back), then the constructed
    `computation_impl.ComputationImpl`.

  Raises:
    TypeError: If `parameter_type` is not a TensorFlow-compatible type.
  """
  del name  # Unused.
  is_compatible = type_analysis.is_tensorflow_compatible_type(parameter_type)
  if not is_compatible:
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, stack)
  first_value = next(serializer)
  sent_back = yield first_value
  # Hand the caller's result to the serializer to obtain the final proto.
  comp_pb, extra_type_spec = serializer.send(sent_back)
  yield computation_impl.ComputationImpl(comp_pb, stack, extra_type_spec)
Esempio n. 11
0
    def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
        """Fetches a float tensor and a structured dataset cut by `take(0)`."""

        def return_dataset():
            source = tf.data.Dataset.from_tensor_slices(
                collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
            return [tf.constant([0., 0.]), source.batch(5).take(0)]

        ctx_stack = context_stack_impl.context_stack
        serialize = tensorflow_serialization.serialize_py_fn_as_tf_computation
        # Only the first element of the serialization result is used here.
        comp_pb = serialize(return_dataset, None, ctx_stack)[0]
        computation = computation_impl.ComputationImpl(comp_pb, ctx_stack)

        fetched = computation()
        for position in range(2):
            self.assertEqual(fetched[0][position], 0.)

        # Fields `a` and `b` of the first fetched dataset element must each
        # equal an empty int32 vector.
        a_matches = np.array_equal(fetched[1][0].a,
                                   np.zeros([0], dtype=np.int32))
        self.assertTrue(a_matches)
        b_matches = np.array_equal(fetched[1][0].b,
                                   np.zeros([0], dtype=np.int32))
        self.assertTrue(b_matches)
Esempio n. 12
0
    def compile(self, computation_to_compile):
        """Compiles `computation_to_compile`.

        The computation's proto is converted into a building block, each
        registered intrinsic is replaced by its body, called lambdas are
        replaced with blocks, and the transformed proto is wrapped in a new
        `computation_impl.ComputationImpl`.

        Args:
          computation_to_compile: An instance of `computation_base.Computation`
            to compile.

        Returns:
          An instance of `computation_base.Computation` that represents the
          result.
        """
        py_typecheck.check_type(computation_to_compile,
                                computation_base.Computation)
        computation_proto = computation_impl.ComputationImpl.get_proto(
            computation_to_compile)

        # TODO(b/113123410): Add a compiler options argument that characterizes the
        # desired form of the output. To be driven by what the specific backend the
        # pipeline is targeting is able to understand. Pending a more fleshed out
        # design of the backend API.

        py_typecheck.check_type(computation_proto, pb.Computation)
        comp = computation_building_blocks.ComputationBuildingBlock.from_proto(
            computation_proto)

        # Replace intrinsics with their bodies, for now manually in a fixed order.
        # TODO(b/113123410): Replace this with a more automated implementation that
        # does not rely on manual maintenance.
        for uri, body in six.iteritems(self._intrinsic_bodies):
            comp, _ = transformations.replace_intrinsic_with_callable(
                comp, uri, body, self._context_stack)

        # Replaces called lambdas with LET constructs with a single local symbol.
        comp, _ = transformations.replace_called_lambda_with_block(comp)
        # TODO(b/113123410): Add more transformations to simplify and optimize the
        # structure, e.g., such as:
        # * removing unnecessary lambdas,
        # * flattening the structure,
        # * merging TensorFlow blocks where appropriate,
        # * ...and so on.

        return computation_impl.ComputationImpl(comp.proto,
                                                self._context_stack)
def _federated_computation_wrapper_fn(target_fn,
                                      parameter_type,
                                      unpack,
                                      name=None):
    """Wrapper function to plug orchestration logic into the TFF framework.

    This function is passed through `computation_wrapper.ComputationWrapper`.
    Documentation for its arguments can be found inside the definition of that
    class.

    Args:
      target_fn: The Python function containing the orchestration logic.
      parameter_type: The type of the computation parameter, or `None` if the
        computation takes no parameter.
      unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
      name: Optional name, passed on as `suggested_name`.

    Returns:
      A `computation_impl.ComputationImpl` built from the proto of the
      building block constructed for `target_fn`.
    """
    target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    ctx_stack = context_stack_impl.context_stack
    # Compare against `None` explicitly rather than relying on truthiness:
    # only the absence of a parameter should drop the parameter name, and a
    # type object is free to define its own truth value.
    parameter_name = None if parameter_type is None else 'arg'
    target_lambda = (
        federated_computation_utils.zero_or_one_arg_fn_to_building_block(
            target_fn,
            parameter_name,
            parameter_type,
            ctx_stack,
            suggested_name=name))
    return computation_impl.ComputationImpl(target_lambda.proto, ctx_stack)
Esempio n. 14
0
    def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
        """Fetches a float tensor and a structured dataset with no elements."""

        def return_dataset():
            source = tf.data.Dataset.from_tensor_slices(
                collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
            return [tf.constant([0., 0.]), source.batch(5).take(0)]

        ctx_stack = context_stack_impl.context_stack
        serialize = tensorflow_serialization.serialize_py_fn_as_tf_computation
        # Only the first element of the serialization result is used here.
        comp_pb = serialize(return_dataset, None, ctx_stack)[0]
        computation = computation_impl.ComputationImpl(comp_pb, ctx_stack)

        fetched = computation()
        self.assertAllEqual(fetched[0], [0., 0.])

        # Both fields share the same spec: an int32 vector of unknown length.
        expected_structure = collections.OrderedDict([
            ('a', tf.TensorSpec(shape=(None, ), dtype=tf.int32)),
            ('b', tf.TensorSpec(shape=(None, ), dtype=tf.int32)),
        ])
        actual_structure = tf.data.experimental.get_structure(fetched[1])
        self.assertEqual(actual_structure, expected_structure)

        # Iterating must stop immediately: the dataset was truncated by
        # `take(0)`.
        with self.assertRaises(StopIteration):
            next(iter(fetched[1]))
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
    """Serializes a Python function containing JAX code as a TFF computation.

    Args:
      fn_to_wrap: The Python function containing JAX code to be serialized
        as a computation containing XLA.
      fn_name: The name for the constructed computation (currently ignored).
      parameter_type: An instance of `computation_types.Type` that represents
        the TFF type of the computation parameter, or `None` if there's none.
      unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.

    Returns:
      An instance of `computation_impl.ComputationImpl` with the constructed
      computation.
    """
    del fn_name  # Unused.
    unpacker = function_utils.create_argument_unpacking_fn(
        fn_to_wrap, parameter_type, unpack=unpack)
    stack = context_stack_impl.context_stack
    serialized = jax_serialization.serialize_jax_computation(
        fn_to_wrap, unpacker, parameter_type, stack)
    return computation_impl.ComputationImpl(serialized, stack)
Esempio n. 16
0
  def test_something(self):
    """Exercises `ComputationImpl` construction with good and bad inputs."""
    # TODO(b/113112108): Revise these tests after a more complete implementation
    # is in place.
    ctx_stack = context_stack_impl.context_stack

    # At the moment, this should succeed, as both the computation body and the
    # type are well-formed.
    int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
    well_formed_proto = pb.Computation(
        type=type_serialization.serialize_type(int_to_int),
        intrinsic=pb.Intrinsic(uri='whatever'))
    computation_impl.ComputationImpl(well_formed_proto, ctx_stack)

    # This should fail, as the proto is not well-formed. The empty proto is
    # built outside the `assertRaises` scope so that only the constructor
    # call is checked.
    empty_proto = pb.Computation()
    with self.assertRaises(TypeError):
      computation_impl.ComputationImpl(empty_proto, ctx_stack)

    # This should fail, as "10" is not an instance of pb.Computation.
    with self.assertRaises(TypeError):
      computation_impl.ComputationImpl(10, ctx_stack)
def _federated_computation_wrapper_fn(parameter_type, name):
  """Generator that plugs orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`;
  its arguments are documented inside the definition of that class.

  Args:
    parameter_type: The type of the computation parameter, or `None` if the
      computation takes no parameter.
    name: Passed to the serializer as `suggested_name`.

  Yields:
    First the initial value produced by the federated computation serializer
    (the caller is expected to `send` a result back), then the constructed
    `computation_impl.ComputationImpl`.
  """
  stack = context_stack_impl.context_stack
  # The parameter is only given a name when a parameter type was supplied.
  parameter_name = None if parameter_type is None else 'arg'
  serializer = federated_computation_utils.federated_computation_serializer(
      parameter_name=parameter_name,
      parameter_type=parameter_type,
      context_stack=stack,
      suggested_name=name)
  first_value = next(serializer)
  try:
    result = yield first_value
  except Exception as err:  # pylint: disable=broad-except
    # Forward errors thrown in by the caller to the serializer.
    # NOTE(review): `throw` is presumably expected to re-raise here; if it
    # returned normally, `result` below would be unbound - confirm.
    serializer.throw(err)
  target_lambda, extra_type_spec = serializer.send(result)
  yield computation_impl.ComputationImpl(target_lambda.proto, stack,
                                         extra_type_spec)
Esempio n. 18
0
def create_dummy_computation_impl():
    """Returns a `tff.ComputationImpl` and its type signature.

    The computation and type come from
    `create_dummy_computation_tensorflow_identity`.
    """
    proto, signature = create_dummy_computation_tensorflow_identity()
    impl = computation_impl.ComputationImpl(
        proto, context_stack_impl.context_stack)
    return impl, signature
def create_dummy_computation_impl():
    """Returns a `ComputationImpl` for a dummy identity lambda and its type.

    The proto comes from
    `executor_test_utils.create_dummy_identity_lambda_computation`; the type
    signature returned alongside it is `type_factory.unary_op(tf.int32)`.
    """
    identity_proto = (
        executor_test_utils.create_dummy_identity_lambda_computation())
    impl = computation_impl.ComputationImpl(
        identity_proto, context_stack_impl.context_stack)
    return impl, type_factory.unary_op(tf.int32)
Esempio n. 20
0
def building_block_to_computation(building_block):
    """Wraps the proto of a computation building block in a `ComputationImpl`.

    Args:
      building_block: A `building_blocks.ComputationBuildingBlock`; its type
        is checked before conversion.

    Returns:
      A `computation_impl.ComputationImpl` built from `building_block.proto`.
    """
    expected_cls = building_blocks.ComputationBuildingBlock
    py_typecheck.check_type(building_block, expected_cls)
    stack = context_stack_impl.context_stack
    return computation_impl.ComputationImpl(building_block.proto, stack)
Esempio n. 21
0
def _to_computation_impl(building_block):
    """Wraps `building_block.proto` in a `computation_impl.ComputationImpl`."""
    proto = building_block.proto
    stack = context_stack_impl.context_stack
    return computation_impl.ComputationImpl(proto, stack)