def _federated_computation_wrapper_fn(parameter_type, name):
  """Wrapper function to plug orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.
  """
  ctx_stack = context_stack_impl.context_stack
  # The serializer needs a name for the parameter only when there is one.
  if parameter_type is None:
    parameter_name = None
  else:
    parameter_name = 'arg'
  fn_generator = federated_computation_utils.federated_computation_serializer(
      parameter_name=parameter_name,
      parameter_type=parameter_type,
      context_stack=ctx_stack,
      suggested_name=name)
  # First phase of the serializer coroutine: it yields the placeholder
  # argument with which the wrapped Python function is to be invoked.
  arg = next(fn_generator)
  try:
    # Hand the placeholder to our caller; the caller sends back the result of
    # invoking the wrapped function on it.
    result = yield arg
  except Exception as e:  # pylint: disable=broad-except
    # Forward the failure into the serializer so it can unwind; `throw`
    # re-raises here unless the serializer swallows the exception.
    fn_generator.throw(e)
  target_lambda, extra_type_spec = fn_generator.send(result)
  fn_generator.close()
  yield computation_impl.ComputationImpl(target_lambda.proto, ctx_stack,
                                         extra_type_spec)
def test_returns_value_with_computation_impl(self, proto, type_signature):
  """Embedding a `ComputationImpl` yields an `ExecutorValue` of equal type."""
  comp = computation_impl.ComputationImpl(proto,
                                          context_stack_impl.context_stack)
  executor = create_test_executor()
  embedded = self.run_sync(executor.create_value(comp, type_signature))
  self.assertIsInstance(embedded, executor_value_base.ExecutorValue)
  self.assertEqual(embedded.type_signature.compact_representation(),
                   type_signature.compact_representation())
def test_set_local_execution_context_and_run_simple_xla_computation(self):
  """End-to-end: a trivial XLA computation runs in the local context."""
  # Build an XLA computation of TFF type ( -> int32) returning the constant
  # 10 and ignoring its (empty-tuple) parameter.
  xla_builder = xla_client.XlaBuilder('comp')
  xla_client.ops.Parameter(xla_builder, 0,
                           xla_client.shape_from_pyval(tuple()))
  xla_client.ops.Constant(xla_builder, np.int32(10))
  fn_type = computation_types.FunctionType(None, np.int32)
  serialized_pb = xla_serialization.create_xla_tff_computation(
      xla_builder.build(), [], fn_type)
  comp = computation_impl.ComputationImpl(serialized_pb,
                                          context_stack_impl.context_stack)
  execution_contexts.set_local_execution_context()
  self.assertEqual(comp(), 10)
def test_invoke_raises_value_error_with_federated_computation(self):
  """Invoking a non-TensorFlow computation in a TF context must fail."""
  fn_type = computation_types.to_type(
      computation_types.FunctionType(tf.int32, tf.int32))
  bogus_proto = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      reference=pb.Reference(name='boogledy'))
  non_tf_computation = computation_impl.ComputationImpl(
      bogus_proto, context_stack_impl.context_stack)
  context = tensorflow_computation_context.TensorFlowComputationContext(
      tf.compat.v1.get_default_graph())
  expected_pattern = ('Can only invoke TensorFlow in the body of '
                      'a TensorFlow computation')
  with self.assertRaisesRegex(ValueError, expected_pattern):
    context.invoke(non_tf_computation, None)
def deserialize_computation(
    computation_proto: pb.Computation) -> computation_base.Computation:
  """Deserializes a `pb.Computation` into a `tff.Computation`.

  Args:
    computation_proto: An instance of `pb.Computation`.

  Returns:
    The corresponding instance of `tff.Computation`.

  Raises:
    TypeError: If the argument is of the wrong type.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  return computation_impl.ComputationImpl(computation_proto,
                                          context_stack_impl.context_stack)
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    fn_to_wrap: The Python function containing JAX code to be serialized as a
      computation containing XLA.
    fn_name: The name for the constructed computation (currently ignored).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if there's none.
    unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.

  Returns:
    An instance of `computation_impl.ComputationImpl` with the constructed
    computation.
  """
  del fn_name  # Unused.
  ctx_stack = context_stack_impl.context_stack
  unpacking_fn = function_utils.create_argument_unpacking_fn(
      fn_to_wrap, parameter_type, unpack=unpack)
  serialized_pb = jax_serialization.serialize_jax_computation(
      fn_to_wrap, unpacking_fn, parameter_type, ctx_stack)
  return computation_impl.ComputationImpl(serialized_pb, ctx_stack)
def test_something(self):
  # TODO(b/113112108): Revise these tests after a more complete implementation
  # is in place.

  # A well-formed intrinsic proto with a well-formed type should construct
  # without raising.
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  well_formed_proto = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      intrinsic=pb.Intrinsic(uri='whatever'))
  computation_impl.ComputationImpl(well_formed_proto,
                                   context_stack_impl.context_stack)

  # An empty proto is not well-formed, so construction must fail.
  with self.assertRaises(TypeError):
    computation_impl.ComputationImpl(pb.Computation(),
                                     context_stack_impl.context_stack)

  # 10 is not an instance of pb.Computation, so construction must fail.
  with self.assertRaises(TypeError):
    computation_impl.ComputationImpl(10, context_stack_impl.context_stack)
def _tf_wrapper_fn(parameter_type, name):
  """Wrapper function to plug Tensorflow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.
  """
  del name  # Unused.
  # Reject parameter types that cannot be represented in a TF graph before
  # attempting serialization.
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError(
        '`tf_computation`s can accept only parameter types with '
        'constituents `SequenceType`, `StructType` '
        'and `TensorType`; you have attempted to create one '
        'with the type {}.'.format(parameter_type))
  ctx_stack = context_stack_impl.context_stack
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  # First phase of the serializer coroutine: it yields the placeholder
  # argument with which the wrapped Python function is to be invoked.
  arg = next(tf_serializer)
  try:
    # Hand the placeholder to our caller; the caller sends back the result of
    # invoking the wrapped function on it.
    result = yield arg
  except Exception as e:  # pylint: disable=broad-except
    # Forward the failure into the serializer so it can unwind; `throw`
    # re-raises here unless the serializer swallows the exception.
    tf_serializer.throw(e)
  comp_pb, extra_type_spec = tf_serializer.send(result)
  yield computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
def building_block_to_computation(building_block):
  """Wraps a computation building block as a `computation_impl.ComputationImpl`."""
  py_typecheck.check_type(building_block,
                          building_blocks.ComputationBuildingBlock)
  proto = building_block.proto
  return computation_impl.ComputationImpl(proto,
                                          context_stack_impl.context_stack)