def basic_federated_select_args(self):
  """Builds the shared argument tuple used by the federated_select tests.

  Returns:
    A tuple `(client_keys, max_key, server_val, select_fn)`:
      * `client_keys`: CLIENTS-placed int32 tensor of three random keys drawn
        with `tf.random.uniform` in `[0, len(values))`.
      * `max_key`: SERVER-placed Python int equal to the number of values.
      * `server_val`: SERVER-placed list of three strings.
      * `select_fn`: computation gathering the state element at a given key.
  """
  server_strings = ['first', 'second', 'third']
  server_val = intrinsics.federated_value(server_strings, placements.SERVER)
  num_values = len(server_strings)
  max_key = intrinsics.federated_value(num_values, placements.SERVER)

  def get_three_random_keys_fn():
    # Keys are drawn below `num_values` so each one indexes `server_strings`.
    return tf.random.uniform(
        shape=[3], minval=0, maxval=num_values, dtype=tf.int32)

  keys_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      get_three_random_keys_fn, None)
  keys_comp = computation_impl.ConcreteComputation(
      keys_proto, context_stack_impl.context_stack)
  client_keys = intrinsics.federated_eval(keys_comp, placements.CLIENTS)

  state_type = server_val.type_signature.member

  def _select_fn(arg):
    state = type_conversions.type_to_py_container(arg[0], state_type)
    key = arg[1]
    return tf.gather(state, key)

  select_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _select_fn, computation_types.StructType([state_type, tf.int32]))
  select_fn = computation_impl.ConcreteComputation(
      select_proto, context_stack_impl.context_stack)
  return (client_keys, max_key, server_val, select_fn)
def test_federated_aggregate_with_unknown_dimension(self):
  """Checks aggregation into an accumulator whose tensor length is unknown.

  Each accumulate/merge step concatenates onto `samples`, so its length is not
  statically known. The aggregate is asserted to have type
  `<samples=int32[?]>@SERVER`.
  """
  Accumulator = collections.namedtuple('Accumulator', ['samples'])  # pylint: disable=invalid-name
  accumulator_type = computation_types.to_type(
      Accumulator(
          samples=computation_types.TensorType(dtype=tf.int32, shape=[None])))

  x = _mock_data_of_type(computation_types.at_clients(tf.int32))

  def initialize_fn():
    return Accumulator(samples=tf.zeros(shape=[0], dtype=tf.int32))

  initialize_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      initialize_fn, None)
  initialize = computation_impl.ConcreteComputation(
      initialize_proto, context_stack_impl.context_stack)
  # The zero accumulator holds an empty (length-0) `samples` tensor.
  zero = initialize()

  # The operator to use during the first stage appends the client element to
  # `samples`, increasing its length by one.
  def _accumulate(arg):
    return Accumulator(samples=tf.concat(
        [arg[0].samples, tf.expand_dims(arg[1], axis=0)], axis=0))

  accumulate_type = computation_types.StructType(
      [accumulator_type, tf.int32])
  accumulate_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _accumulate, accumulate_type)
  accumulate = computation_impl.ConcreteComputation(
      accumulate_proto, context_stack_impl.context_stack)

  # The operator to use during the second stage concatenates the `samples` of
  # two partial accumulators.
  def _merge(arg):
    return Accumulator(
        samples=tf.concat([arg[0].samples, arg[1].samples], axis=0))

  merge_type = computation_types.StructType(
      [accumulator_type, accumulator_type])
  merge_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _merge, merge_type)
  merge = computation_impl.ConcreteComputation(
      merge_proto, context_stack_impl.context_stack)

  # The operator to use during the final stage is the identity: the merged
  # accumulator is reported unchanged.
  report_proto, _ = tensorflow_computation_factory.create_identity(
      accumulator_type)
  report = computation_impl.ConcreteComputation(
      report_proto, context_stack_impl.context_stack)

  value = intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
  self.assert_value(value, '<samples=int32[?]>@SERVER')
def test_federated_aggregate_with_client_int(self):
  """Averaging int32 client values via federated_aggregate yields float32@SERVER."""
  # The representation used during the aggregation process will be a named
  # tuple with 2 elements - the integer 'total' that represents the sum of
  # elements encountered, and the integer element 'count'.
  Accumulator = collections.namedtuple('Accumulator', 'total count')  # pylint: disable=invalid-name
  accumulator_type = computation_types.to_type(
      Accumulator(
          total=computation_types.TensorType(dtype=tf.int32),
          count=computation_types.TensorType(dtype=tf.int32)))
  stack = context_stack_impl.context_stack

  def _as_concrete(py_fn, parameter_type):
    # Serializes `py_fn` as a TF computation and binds it to the context stack.
    proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
        py_fn, parameter_type)
    return computation_impl.ConcreteComputation(proto, stack)

  x = _mock_data_of_type(computation_types.at_clients(tf.int32))
  zero = Accumulator(total=0, count=0)

  # The operator to use during the first stage simply adds an element to the
  # total and updates the count.
  def _accumulate(arg):
    acc, element = arg[0], arg[1]
    return Accumulator(acc.total + element, acc.count + 1)

  accumulate = _as_concrete(
      _accumulate, computation_types.StructType([accumulator_type, tf.int32]))

  # The operator to use during the second stage simply adds total and count.
  def _merge(arg):
    left, right = arg[0], arg[1]
    return Accumulator(left.total + right.total, left.count + right.count)

  merge = _as_concrete(
      _merge,
      computation_types.StructType([accumulator_type, accumulator_type]))

  # The operator to use during the final stage simply computes the ratio.
  def _report(arg):
    numerator = tf.cast(arg.total, tf.float32)
    denominator = tf.cast(arg.count, tf.float32)
    return numerator / denominator

  report = _as_concrete(_report, accumulator_type)

  value = intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
  self.assert_value(value, 'float32@SERVER')
def test_federated_select_autozips_server_val(self, federated_select):
  """Checks that `federated_select` accepts a dict of SERVER-placed values.

  Per the test name, the dict is expected to be auto-zipped; the result is
  asserted to be a CLIENTS-placed sequence of string pairs.
  """
  client_keys, max_key, _, _ = self.basic_federated_select_args()

  element = intrinsics.federated_value(
      ['first', 'second', 'third'], placements.SERVER)
  server_val_dict = collections.OrderedDict(e1=element, e2=element)
  member_type = element.type_signature.member
  state_type = computation_types.StructType([
      ('e1', member_type),
      ('e2', member_type),
  ])

  def _select_fn_dict(arg):
    state = type_conversions.type_to_py_container(arg[0], state_type)
    key = arg[1]
    return (tf.gather(state['e1'], key), tf.gather(state['e2'], key))

  dict_select_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _select_fn_dict, computation_types.StructType([state_type, tf.int32]))
  dict_select_fn = computation_impl.ConcreteComputation(
      dict_select_proto, context_stack_impl.context_stack)

  selected = federated_select(client_keys, max_key, server_val_dict,
                              dict_select_fn)
  self.assert_value(selected, '{<string,string>*}@CLIENTS')
def call_concrete(*args):
  """Invokes the enclosing-scope `comp` as a ConcreteComputation.

  Struct-typed results are converted with
  `structure.from_container(..., recursive=True)`; other results are returned
  as-is.
  """
  concrete_comp = computation_impl.ConcreteComputation(
      comp.proto, context_stack_impl.context_stack)
  output = concrete_comp(*args)
  if not comp.type_signature.result.is_struct():
    return output
  return structure.from_container(output, recursive=True)
def test_returns_value_with_computation_impl(self, proto, type_signature):
  """`create_value` on a ConcreteComputation yields a value of the given type."""
  test_executor = create_test_executor()
  comp = computation_impl.ConcreteComputation(
      proto, context_stack_impl.context_stack)

  executor_value = self.run_sync(
      test_executor.create_value(comp, type_signature))

  self.assertIsInstance(executor_value, executor_value_base.ExecutorValue)
  actual_repr = executor_value.type_signature.compact_representation()
  expected_repr = type_signature.compact_representation()
  self.assertEqual(actual_repr, expected_repr)
def _create_computation_greater_than_10_with_unused_parameter(
) -> computation_base.Computation:
  """Builds a computation on a pair of int32 values that tests `arg[0] > 10`.

  The second element of the parameter is accepted but never read.
  """
  parameter_type = computation_types.StructType(
      [computation_types.TensorType(tf.int32) for _ in range(2)])

  def _first_exceeds_10(arg):
    return arg[0] > 10

  proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _first_exceeds_10, parameter_type)
  comp = computation_impl.ConcreteComputation(
      proto, context_stack_impl.context_stack)
  return comp
def test_serialize_deserialize_round_trip(self):
  """Serializing then deserializing a computation yields an equal computation."""
  int32_type = computation_types.TensorType(tf.int32)
  add_proto, _ = tensorflow_computation_factory.create_binary_operator(
      tf.add, int32_type, int32_type)
  original = computation_impl.ConcreteComputation(
      add_proto, context_stack_impl.context_stack)

  round_tripped = computation_serialization.deserialize_computation(
      computation_serialization.serialize_computation(original))

  self.assertIsInstance(round_tripped, computation_base.Computation)
  self.assertEqual(round_tripped, original)
def test_to_value_for_computations(self):
  """`to_value` wraps a no-argument int32 computation as a `( -> int32)` Value."""
  const_proto, _ = tensorflow_computation_factory.create_constant(
      10, computation_types.TensorType(tf.int32))
  const_comp = computation_impl.ConcreteComputation(
      const_proto, context_stack_impl.context_stack)

  wrapped = value_impl.to_value(const_comp, None)

  self.assertIsInstance(wrapped, value_impl.Value)
  self.assertEqual(
      wrapped.type_signature.compact_representation(), '( -> int32)')
def test_intrinsic_construction_raises_outside_symbol_binding_context(
    self):
  """Building an intrinsic under a RuntimeErrorContext raises ContextError."""
  constant_proto, _ = tensorflow_computation_factory.create_constant(
      2, computation_types.TensorType(tf.int32))
  return_2 = computation_impl.ConcreteComputation(
      constant_proto, context_stack_impl.context_stack)

  install_error_context = context_stack_impl.context_stack.install(
      runtime_error_context.RuntimeErrorContext())
  with install_error_context, self.assertRaises(context_base.ContextError):
    intrinsics.federated_eval(return_2, placements.SERVER)
def test_infers_accumulate_return_as_merge_arg_merge_return_as_report_arg(
    self):
  """Aggregating into int64[?] accumulators yields an `int64[?]@SERVER` value."""
  type_spec = computation_types.TensorType(dtype=tf.int64, shape=[None])
  x = _mock_data_of_type(computation_types.at_clients(tf.int64))
  stack = context_stack_impl.context_stack

  def initialize_fn():
    return tf.constant([], dtype=tf.int64, shape=[0])

  init_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      initialize_fn, None)
  # Evaluate the initializer immediately to obtain an empty int64 zero.
  zero = computation_impl.ConcreteComputation(init_proto, stack)()

  def _accumulate(arg):
    return tf.concat([arg[0], [arg[1]]], 0)

  accumulate_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _accumulate, computation_types.StructType([type_spec, tf.int64]))
  accumulate = computation_impl.ConcreteComputation(accumulate_proto, stack)

  def _merge(arg):
    return tf.concat([arg[0], arg[1]], 0)

  merge_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _merge, computation_types.StructType([type_spec, type_spec]))
  merge = computation_impl.ConcreteComputation(merge_proto, stack)

  report_proto, _ = tensorflow_computation_factory.create_identity(type_spec)
  report = computation_impl.ConcreteComputation(report_proto, stack)

  value = intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
  self.assert_value(value, 'int64[?]@SERVER')
def test_invoke_returns_value_with_correct_type(self):
  """Invoking an int32 constant in a federated context yields an int32 Value."""
  const_proto, _ = tensorflow_computation_factory.create_constant(
      10, computation_types.TensorType(tf.int32))
  const_comp = computation_impl.ConcreteComputation(
      const_proto, context_stack_impl.context_stack)
  fc_context = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)

  with context_stack_impl.context_stack.install(fc_context):
    invoked = fc_context.invoke(const_comp, None)

  self.assertIsInstance(invoked, value_impl.Value)
  self.assertEqual(str(invoked.type_signature), 'int32')
def test_federated_select_keys_must_be_fixed_length(
    self, federated_select):
  """Client keys typed with an unknown length are rejected with TypeError."""
  client_keys, max_key, server_val, select_fn = (
      self.basic_federated_select_args())

  # Mapping an identity typed on int32[?] over the keys gives them an
  # unknown-length type.
  identity_proto, _ = tensorflow_computation_factory.create_identity(
      computation_types.TensorType(tf.int32, [None]))
  to_unknown_length = computation_impl.ConcreteComputation(
      identity_proto, context_stack_impl.context_stack)
  bad_client_keys = intrinsics.federated_map(to_unknown_length, client_keys)

  with self.assertRaises(TypeError):
    federated_select(bad_client_keys, max_key, server_val, select_fn)
def test_set_local_python_execution_context_and_run_simple_xla_computation(
    self):
  """A trivial XLA computation runs under the local Python execution context."""
  xla_builder = xla_client.XlaBuilder('comp')
  # One empty-tuple parameter, then a constant 10 (the value asserted below).
  xla_client.ops.Parameter(
      xla_builder, 0, xla_client.shape_from_pyval(tuple()))
  xla_client.ops.Constant(xla_builder, np.int32(10))
  built_xla = xla_builder.build()

  tff_proto = xla_serialization.create_xla_tff_computation(
      built_xla, [], computation_types.FunctionType(None, np.int32))
  comp = computation_impl.ConcreteComputation(
      tff_proto, context_stack_impl.context_stack)

  execution_contexts.set_local_python_execution_context()
  self.assertEqual(comp(), 10)
def test_with_type_raises_non_assignable_type(self):
  """`with_type` raises when an int32 result is re-typed as a list struct."""
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  intrinsic_proto = pb.Computation(
      type=type_serialization.serialize_type(int_to_int),
      intrinsic=pb.Intrinsic(uri='whatever'))
  original_comp = computation_impl.ConcreteComputation(
      intrinsic_proto, context_stack_impl.context_stack)

  int_to_list = computation_types.FunctionType(
      tf.int32,
      computation_types.StructWithPythonType([(None, tf.int32)], list))

  with self.assertRaises(computation_types.TypeNotAssignableError):
    computation_impl.ConcreteComputation.with_type(original_comp, int_to_list)
def test_federated_select_server_val_must_be_server_placed(
    self, federated_select):
  """A `server_val` that is not SERVER-placed is rejected with TypeError.

  The bad value is the result of invoking a TF constant computation directly.
  """
  client_keys, max_key, _, select_fn = self.basic_federated_select_args()

  constant_proto, _ = tensorflow_computation_factory.create_constant(
      tf.constant(['first', 'second', 'third']),
      computation_types.TensorType(dtype=tf.string, shape=[3]))
  constant_comp = computation_impl.ConcreteComputation(
      constant_proto, context_stack_impl.context_stack)
  unplaced_server_val = constant_comp()

  with self.assertRaises(TypeError):
    federated_select(client_keys, max_key, unplaced_server_val, select_fn)
def test_invoke_raises_value_error_with_federated_computation(self):
  """A TF computation context refuses to invoke a non-TensorFlow computation."""
  reference_proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.to_type(
              computation_types.FunctionType(tf.int32, tf.int32))),
      reference=pb.Reference(name='boogledy'))
  non_tf_computation = computation_impl.ConcreteComputation(
      reference_proto, context_stack_impl.context_stack)
  tf_context = tensorflow_computation_context.TensorFlowComputationContext(
      tf.compat.v1.get_default_graph(), tf.constant('bogus_token'))

  expected_message = ('Can only invoke TensorFlow in the body of '
                      'a TensorFlow computation')
  with self.assertRaisesRegex(ValueError, expected_message):
    tf_context.invoke(non_tf_computation, None)
def deserialize_computation(
    computation_proto: pb.Computation) -> computation_base.Computation:
  """Deserializes a `pb.Computation` into a `tff.Computation`.

  Args:
    computation_proto: An instance of `pb.Computation`.

  Returns:
    The corresponding instance of `tff.Computation`: a
    `computation_impl.ConcreteComputation` bound to
    `context_stack_impl.context_stack`.

  Raises:
    TypeError: If the argument is not a `pb.Computation`.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  return computation_impl.ConcreteComputation(computation_proto,
                                              context_stack_impl.context_stack)
def test_with_type_preserves_python_container(self):
  """`with_type` adopts the `list` container of the annotated result type."""
  struct_return_type = computation_types.FunctionType(
      tf.int32, computation_types.StructType([(None, tf.int32)]))
  intrinsic_proto = pb.Computation(
      type=type_serialization.serialize_type(struct_return_type),
      intrinsic=pb.Intrinsic(uri='whatever'))
  original_comp = computation_impl.ConcreteComputation(
      intrinsic_proto, context_stack_impl.context_stack)

  list_return_type = computation_types.FunctionType(
      tf.int32,
      computation_types.StructWithPythonType([(None, tf.int32)], list))
  annotated = computation_impl.ConcreteComputation.with_type(
      original_comp, list_return_type)

  type_test_utils.assert_types_identical(list_return_type,
                                         annotated.type_signature)
def test_federated_aggregate_with_federated_zero_fails(self):
  """A SERVER-placed `zero` is rejected by `federated_aggregate`.

  `zero` must be assignable to the unplaced accumulator type (int32). Passing
  an `int32@SERVER` value is expected to raise the TypeError asserted below.
  """
  x = _mock_data_of_type(computation_types.at_clients(tf.int32))
  # Deliberately SERVER-placed: this is the invalid argument under test.
  zero = intrinsics.federated_value(0, placements.SERVER)
  # The operator to use during the first stage adds a client element to the
  # running int32 sum.
  accumulate = _create_computation_add()
  # The operator to use during the second stage adds two partial int32 sums.
  merge = _create_computation_add()
  # The operator to use during the final stage is the int32 identity.
  report_type = computation_types.TensorType(tf.int32)
  report_proto, _ = tensorflow_computation_factory.create_identity(
      report_type)
  report = computation_impl.ConcreteComputation(
      report_proto, context_stack_impl.context_stack)

  with self.assertRaisesRegex(
      TypeError, 'Expected `zero` to be assignable to type int32, '
      'but was of incompatible type int32@SERVER'):
    intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
def test_federated_select_keys_must_be_int32(self, federated_select):
  """Client keys of dtype int64 are rejected with TypeError."""
  _, max_key, server_val, select_fn = self.basic_federated_select_args()

  def get_three_random_keys_fn():
    return tf.random.uniform(shape=[3], minval=0, maxval=3, dtype=tf.int64)

  int64_keys_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      get_three_random_keys_fn, None)
  int64_keys_comp = computation_impl.ConcreteComputation(
      int64_keys_proto, context_stack_impl.context_stack)
  bad_client_keys = intrinsics.federated_eval(int64_keys_comp,
                                              placements.CLIENTS)

  with self.assertRaises(TypeError):
    federated_select(bad_client_keys, max_key, server_val, select_fn)
def test_federated_select_fn_must_take_int32_keys(self, federated_select):
  """A `select_fn` whose key parameter is int64 is rejected with TypeError."""
  client_keys, max_key, server_val, _ = self.basic_federated_select_args()
  state_type = server_val.type_signature.member

  def _bad_select_fn(arg):
    state = type_conversions.type_to_py_container(arg[0], state_type)
    key = arg[1]
    return tf.gather(state, key)

  # The key slot is int64 instead of the required int32.
  int64_key_param = computation_types.StructType([state_type, tf.int64])
  bad_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _bad_select_fn, int64_key_param)
  bad_select_fn = computation_impl.ConcreteComputation(
      bad_proto, context_stack_impl.context_stack)

  with self.assertRaises(TypeError):
    federated_select(client_keys, max_key, server_val, bad_select_fn)
def _tf_wrapper_fn(parameter_type, name):
  """Wrapper function to plug Tensorflow logic into the TFF framework.

  This is a generator driven by its caller in two steps:

  1. The first `yield` hands out the argument obtained from the TF computation
     serializer for `parameter_type`. The caller then `send`s back a result
     (presumably the output of the user's TF logic on that argument) or
     `throw`s an exception into this generator.
  2. The second `yield` produces the resulting
     `computation_impl.ConcreteComputation`.

  Args:
    parameter_type: The TFF type of the computation parameter; must satisfy
      `type_analysis.is_tensorflow_compatible_type`.
    name: Unused.

  Yields:
    First the serializer's argument, then the constructed
    `computation_impl.ConcreteComputation`.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible.
  """
  del name  # Unused.
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError(
        '`tf_computation`s can accept only parameter types with '
        'constituents `SequenceType`, `StructType` '
        'and `TensorType`; you have attempted to create one '
        'with the type {}.'.format(parameter_type))
  ctx_stack = context_stack_impl.context_stack
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  # Advance the serializer to its first yield to obtain the argument.
  arg = next(tf_serializer)
  try:
    result = yield arg
  except Exception as e:  # pylint: disable=broad-except
    # Re-raise the caller's failure inside the serializer generator.
    # NOTE(review): if `throw` returns instead of raising, `result` is unbound
    # at the `send` below; presumably the serializer always re-raises -
    # confirm.
    tf_serializer.throw(e)
  # `send` is expected to return `(comp_pb, extra_type_spec)`.
  comp_pb, extra_type_spec = tf_serializer.send(result)
  tf_serializer.close()
  yield computation_impl.ConcreteComputation(comp_pb, ctx_stack,
                                             extra_type_spec)
def test_something(self):
  # TODO(b/113112108): Revise these tests after a more complete implementation
  # is in place.
  stack = context_stack_impl.context_stack

  # At the moment, this should succeed, as both the computation body and the
  # type are well-formed.
  well_formed_proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(tf.int32, tf.int32)),
      intrinsic=pb.Intrinsic(uri='whatever'))
  computation_impl.ConcreteComputation(well_formed_proto, stack)

  # This should fail, as the proto is not well-formed.
  with self.assertRaises(TypeError):
    computation_impl.ConcreteComputation(pb.Computation(), stack)

  # This should fail, as "10" is not an instance of pb.Computation.
  with self.assertRaises(TypeError):
    computation_impl.ConcreteComputation(10, stack)
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    fn_to_wrap: The Python function whose JAX code is serialized into a
      computation containing XLA.
    fn_name: The name for the constructed computation (currently ignored).
    parameter_type: An instance of `computation_types.Type` giving the TFF type
      of the computation parameter, or `None` if there is no parameter.
    unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.

  Returns:
    A `computation_impl.ConcreteComputation` wrapping the serialized proto.
  """
  del fn_name  # Unused.
  argument_unpacker = function_utils.create_argument_unpacking_fn(
      fn_to_wrap, parameter_type, unpack=unpack)
  stack = context_stack_impl.context_stack
  serialized = jax_serialization.serialize_jax_computation(
      fn_to_wrap, argument_unpacker, parameter_type, stack)
  return computation_impl.ConcreteComputation(serialized, stack)
def _create_computation_add() -> computation_base.Computation:
  """Builds a concrete computation applying `tf.add` to two int32 operands."""
  int32_type = computation_types.TensorType(tf.int32)
  add_proto, _ = tensorflow_computation_factory.create_binary_operator(
      tf.add, int32_type, int32_type)
  add_comp = computation_impl.ConcreteComputation(
      add_proto, context_stack_impl.context_stack)
  return add_comp
def _create_computation_greater_than_10() -> computation_base.Computation:
  """Builds a concrete computation testing whether an int32 exceeds 10."""

  def _exceeds_10(value):
    return value > 10

  proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _exceeds_10, computation_types.TensorType(tf.int32))
  comp = computation_impl.ConcreteComputation(
      proto, context_stack_impl.context_stack)
  return comp
def _create_computation_random() -> computation_base.Computation:
  """Builds a no-argument computation returning a `tf.random.normal` scalar."""

  def _sample_normal():
    return tf.random.normal([])

  proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _sample_normal, None)
  comp = computation_impl.ConcreteComputation(
      proto, context_stack_impl.context_stack)
  return comp
def _create_computation_reduce() -> computation_base.Computation:
  """Builds a computation summing an int32 sequence via its `reduce` method."""

  def _sum_sequence(ds):
    # Start from an int32 zero so the accumulator dtype matches the elements.
    return ds.reduce(np.int32(0), lambda acc, elem: acc + elem)

  proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
      _sum_sequence, computation_types.SequenceType(tf.int32))
  comp = computation_impl.ConcreteComputation(
      proto, context_stack_impl.context_stack)
  return comp