Example #1
0
def _federated_computation_wrapper_fn(parameter_type, name):
    """Wrapper function to plug orchestration logic into the TFF framework.

    This function is passed through `computation_wrapper.ComputationWrapper`.
    Documentation of its arguments can be found inside the definition of that
    class.

    This is a two-step generator:
      1. It first yields the argument object produced by the federated
         computation serializer; the caller invokes the user's Python function
         with it and sends the result back in via `send()`.
      2. It then forwards that result to the serializer and yields a
         `computation_impl.ComputationImpl` built from the serialized lambda.

    An exception thrown into this generator while it waits for the result is
    forwarded into the serializer via `throw()` and then re-raised.
    """
    ctx_stack = context_stack_impl.context_stack
    # A computation without a parameter gets no parameter name; otherwise the
    # single parameter is named 'arg'.
    parameter_name = None if parameter_type is None else 'arg'
    fn_generator = federated_computation_utils.federated_computation_serializer(
        parameter_name=parameter_name,
        parameter_type=parameter_type,
        context_stack=ctx_stack,
        suggested_name=name)
    arg = next(fn_generator)
    try:
        result = yield arg
    except Exception as e:  # pylint: disable=broad-except
        # Forward the failure into the serializer. It is presumably expected to
        # re-raise; that behavior is not visible here.
        fn_generator.throw(e)
        # If the serializer swallowed the exception instead, `result` would be
        # unbound below and we would fail with a misleading UnboundLocalError.
        # Re-raise the original error instead.
        raise
    target_lambda, extra_type_spec = fn_generator.send(result)
    fn_generator.close()
    yield computation_impl.ComputationImpl(target_lambda.proto, ctx_stack,
                                           extra_type_spec)
    def test_returns_value_with_computation_impl(self, proto, type_signature):
        """`create_value` accepts a `ComputationImpl` and keeps its type.

        The created value must be an `ExecutorValue` whose type signature has
        the same compact representation as the `type_signature` passed in.
        """
        test_executor = create_test_executor()
        comp = computation_impl.ComputationImpl(
            proto, context_stack_impl.context_stack)

        created = self.run_sync(test_executor.create_value(comp, type_signature))

        self.assertIsInstance(created, executor_value_base.ExecutorValue)
        actual_repr = created.type_signature.compact_representation()
        expected_repr = type_signature.compact_representation()
        self.assertEqual(actual_repr, expected_repr)
 def test_set_local_execution_context_and_run_simple_xla_computation(self):
     """A parameterless XLA computation returning 10 runs in the local context."""
     # Build an XLA computation with one empty-tuple parameter and an int32
     # constant 10.
     xla_builder = xla_client.XlaBuilder('comp')
     unit_shape = xla_client.shape_from_pyval(tuple())
     xla_client.ops.Parameter(xla_builder, 0, unit_shape)
     xla_client.ops.Constant(xla_builder, np.int32(10))
     built_xla = xla_builder.build()

     # Wrap it as a TFF function with no parameter and an int32 result.
     fn_type = computation_types.FunctionType(None, np.int32)
     serialized = xla_serialization.create_xla_tff_computation(
         built_xla, [], fn_type)
     computation = computation_impl.ComputationImpl(
         serialized, context_stack_impl.context_stack)

     execution_contexts.set_local_execution_context()
     self.assertEqual(computation(), 10)
Example #4
0
    def test_invoke_raises_value_error_with_federated_computation(self):
        """Invoking a non-TensorFlow computation in a TF context raises."""
        # A computation whose body is a bare reference, i.e. not TensorFlow.
        int_to_int = computation_types.to_type(
            computation_types.FunctionType(tf.int32, tf.int32))
        serialized_type = type_serialization.serialize_type(int_to_int)
        reference_body = pb.Reference(name='boogledy')
        bogus_proto = pb.Computation(
            type=serialized_type, reference=reference_body)
        non_tf_computation = computation_impl.ComputationImpl(
            bogus_proto, context_stack_impl.context_stack)

        default_graph = tf.compat.v1.get_default_graph()
        context = tensorflow_computation_context.TensorFlowComputationContext(
            default_graph)

        expected_message = ('Can only invoke TensorFlow in the body of '
                            'a TensorFlow computation')
        with self.assertRaisesRegex(ValueError, expected_message):
            context.invoke(non_tf_computation, None)
def deserialize_computation(
        computation_proto: pb.Computation) -> computation_base.Computation:
    """Deserializes a `pb.Computation` into a `tff.Computation`.

    Args:
      computation_proto: An instance of `pb.Computation`.

    Returns:
      The corresponding instance of `tff.Computation`: a
      `computation_impl.ComputationImpl` wrapping `computation_proto` and
      bound to `context_stack_impl.context_stack`.

    Raises:
      TypeError: If the argument is of the wrong type.
    """
    # Validate eagerly so callers get a TypeError for non-proto input.
    py_typecheck.check_type(computation_proto, pb.Computation)
    return computation_impl.ComputationImpl(computation_proto,
                                            context_stack_impl.context_stack)
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
    """Serializes a Python function containing JAX code as a TFF computation.

    Args:
      fn_to_wrap: The Python function containing JAX code, to be serialized as
        a computation containing XLA.
      fn_name: The name for the constructed computation. Currently ignored.
      parameter_type: An instance of `computation_types.Type` giving the TFF
        type of the computation parameter, or `None` if there is no parameter.
      unpack: Passed through as `unpack` to
        `function_utils.create_argument_unpacking_fn`; see that function.

    Returns:
      A `computation_impl.ComputationImpl` wrapping the serialized computation.
    """
    del fn_name  # Not used by the JAX serialization path.
    arg_unpacker = function_utils.create_argument_unpacking_fn(
        fn_to_wrap, parameter_type, unpack=unpack)
    stack = context_stack_impl.context_stack
    serialized = jax_serialization.serialize_jax_computation(
        fn_to_wrap, arg_unpacker, parameter_type, stack)
    return computation_impl.ComputationImpl(serialized, stack)
    def test_something(self):
        """`ComputationImpl` accepts well-formed protos and rejects others."""
        # TODO(b/113112108): Revise these tests after a more complete
        # implementation is in place.

        # Accepted: both the computation body and its function type are
        # well-formed.
        well_formed_proto = pb.Computation(
            type=type_serialization.serialize_type(
                computation_types.FunctionType(tf.int32, tf.int32)),
            intrinsic=pb.Intrinsic(uri='whatever'))
        computation_impl.ComputationImpl(well_formed_proto,
                                         context_stack_impl.context_stack)

        # Rejected: an empty proto is not well-formed.
        empty_proto = pb.Computation()
        with self.assertRaises(TypeError):
            computation_impl.ComputationImpl(empty_proto,
                                             context_stack_impl.context_stack)

        # Rejected: the integer 10 is not a pb.Computation at all.
        with self.assertRaises(TypeError):
            computation_impl.ComputationImpl(10,
                                             context_stack_impl.context_stack)
def _tf_wrapper_fn(parameter_type, name):
    """Wrapper function to plug TensorFlow logic into the TFF framework.

    This function is passed through `computation_wrapper.ComputationWrapper`.
    Documentation of its arguments can be found inside the definition of that
    class.

    This is a two-step generator:
      1. It first yields the argument object produced by the TensorFlow
         computation serializer; the caller invokes the user's Python function
         with it and sends the result back in via `send()`.
      2. It then forwards that result to the serializer and yields a
         `computation_impl.ComputationImpl` built from the serialized proto.

    An exception thrown into this generator while it waits for the result is
    forwarded into the serializer via `throw()` and then re-raised.

    Raises:
      TypeError: If `parameter_type` is not TensorFlow-compatible. Because this
        is a generator function, the error surfaces only when the generator is
        first advanced, not when it is created.
    """
    del name  # Unused.
    if not type_analysis.is_tensorflow_compatible_type(parameter_type):
        raise TypeError(
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `StructType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type))
    ctx_stack = context_stack_impl.context_stack
    tf_serializer = tensorflow_serialization.tf_computation_serializer(
        parameter_type, ctx_stack)
    arg = next(tf_serializer)
    try:
        result = yield arg
    except Exception as e:  # pylint: disable=broad-except
        # Forward the failure into the serializer. It is presumably expected to
        # re-raise; that behavior is not visible here.
        tf_serializer.throw(e)
        # If the serializer swallowed the exception instead, `result` would be
        # unbound below and we would fail with a misleading UnboundLocalError.
        # Re-raise the original error instead.
        raise
    comp_pb, extra_type_spec = tf_serializer.send(result)
    yield computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
Example #9
0
def building_block_to_computation(building_block):
    """Wraps a computation building block as a `ComputationImpl`.

    Args:
      building_block: An instance of
        `building_blocks.ComputationBuildingBlock`.

    Returns:
      A `computation_impl.ComputationImpl` built from the block's `proto` and
      bound to `context_stack_impl.context_stack`.

    Raises:
      TypeError: If `building_block` is not a `ComputationBuildingBlock`
        (raised by `py_typecheck.check_type`).
    """
    expected_type = building_blocks.ComputationBuildingBlock
    py_typecheck.check_type(building_block, expected_type)
    block_proto = building_block.proto
    stack = context_stack_impl.context_stack
    return computation_impl.ComputationImpl(block_proto, stack)