def test_invoke_returns_result_with_tf_computation(self):
    """Invokes a TF computation that nests other computations and variables.

    `foo` calls `make_10`, `add_one` and two variable-creating computations,
    so the context's init ops (when present) are run before the result tensor
    is fetched.
    """
    make_10 = tensorflow_computation.tf_computation(lambda: tf.constant(10))
    add_one = tensorflow_computation.tf_computation(
        lambda x: tf.add(x, 1), tf.int32)

    @tensorflow_computation.tf_computation
    def add_one_with_v1(x):
        v1 = tf.Variable(1, name='v1')
        return x + v1

    @tensorflow_computation.tf_computation
    def add_one_with_v2(x):
        v2 = tf.Variable(1, name='v2')
        return x + v2

    @tensorflow_computation.tf_computation
    def foo():
        zero = tf.Variable(0, name='zero')
        ten = tf.Variable(make_10())
        # 10 -> 11 -> 12 -> 13, one increment per nested computation.
        eleven = add_one(make_10())
        twelve = add_one_with_v1(eleven)
        thirteen = add_one_with_v2(twelve)
        return thirteen + zero + ten - ten

    with tf.compat.v1.Graph().as_default() as graph:
        session_token = tf.constant('bogus_token')
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph, session_token)
        signature = foo.type_signature.compact_representation()
        self.assertEqual(signature, '( -> int32)')
        output_tensor = context.invoke(foo, None)

    with tf.compat.v1.Session(graph=graph) as session:
        init_ops = context.init_ops
        if init_ops:
            session.run(init_ops)
        fetched = session.run(output_tensor)
    self.assertEqual(fetched, 13)
def test_measured_process_output_as_state_raises(self):
    """Expects an error when the state itself is a MeasuredProcessOutput."""
    empty_output = lambda: MeasuredProcessOutput((), (), ())
    initialize_fn = tensorflow_computation.tf_computation(empty_output)

    # `next_fn` takes the initializer's result type as its state argument.
    state_type = initialize_fn.type_signature.result
    wrap_next = tensorflow_computation.tf_computation(state_type)
    next_fn = wrap_next(lambda state: empty_output())

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        measured_process.MeasuredProcess(initialize_fn, next_fn)
def test_construction_with_empty_state_does_not_raise(self):
    """Checks that a MeasuredProcess can be built when the state is empty.

    Any exception raised by the constructor is reported as a test failure.
    """
    initialize_fn = tensorflow_computation.tf_computation()(lambda: ())
    next_fn = tensorflow_computation.tf_computation(
        ())(lambda x: MeasuredProcessOutput(x, (), ()))
    try:
        measured_process.MeasuredProcess(initialize_fn, next_fn)
    except Exception:  # pylint: disable=broad-except
        # `Exception` instead of a bare `except:` so KeyboardInterrupt and
        # SystemExit propagate instead of being reported as test failures.
        self.fail('Could not construct an MeasuredProcess with empty state.')
def get_bounds(state):
    """Returns the `(upper_bound, lower_bound)` pair for `state`.

    The upper bound is `process.report(state)` cast to `bound_dtype`; the
    lower bound is the upper bound multiplied by -1.0.
    """
    cast_fn = tensorflow_computation.tf_computation(
        lambda x: tf.cast(x, bound_dtype))
    reported = process.report(state)
    upper_bound = intrinsics.federated_map(cast_fn, reported)

    negate_fn = tensorflow_computation.tf_computation(lambda x: x * -1.0)
    lower_bound = intrinsics.federated_map(negate_fn, upper_bound)
    return upper_bound, lower_bound
def test_construction_with_empty_state_does_not_raise(self):
    """Checks that an IterativeProcess can be built when the state is empty.

    Any exception raised by the constructor is reported as a test failure.
    """
    initialize_fn = tensorflow_computation.tf_computation()(lambda: ())
    next_fn = tensorflow_computation.tf_computation(())(lambda x: (x, 1.0))
    try:
        iterative_process.IterativeProcess(initialize_fn, next_fn)
    except Exception:  # pylint: disable=broad-except
        # `Exception` instead of a bare `except:` so KeyboardInterrupt and
        # SystemExit propagate instead of being reported as test failures.
        self.fail(
            'Could not construct an IterativeProcess with empty state.')
def update_state(state, value_min, value_max):
    """Advances both bound processes and zips their outputs.

    `value_min` is cast to `min_dtype` and `value_max` to `max_dtype` first.
    `state[0]` is passed to the upper-bound process together with the max
    value, and `state[1]` to the lower-bound process with the min value.
    """
    cast_to_min_dtype = tensorflow_computation.tf_computation(
        lambda x: tf.cast(x, min_dtype))
    value_min = intrinsics.federated_map(cast_to_min_dtype, value_min)

    cast_to_max_dtype = tensorflow_computation.tf_computation(
        lambda x: tf.cast(x, max_dtype))
    value_max = intrinsics.federated_map(cast_to_max_dtype, value_max)

    upper_state = upper_bound_process.next(state[0], value_max)
    lower_state = lower_bound_process.next(state[1], value_min)
    return intrinsics.federated_zip((upper_state, lower_state))
def next_fn(state, value):
    """Increments the state, sums `value` and adds one to each summed leaf.

    Returns:
      A `MeasuredProcessOutput` whose measurements are `MEASUREMENT_CONSTANT`
      placed at the server.
    """
    increment_fn = tensorflow_computation.tf_computation(lambda x: x + 1)
    state = intrinsics.federated_map(increment_fn, state)

    increment_structure_fn = tensorflow_computation.tf_computation(
        lambda x: tf.nest.map_structure(lambda y: y + 1, x))
    summed_value = intrinsics.federated_sum(value)
    result = intrinsics.federated_map(increment_structure_fn, summed_value)

    measurements = intrinsics.federated_value(
        MEASUREMENT_CONSTANT, placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, result, measurements)
def test_invoke_with_typed_fn(self):
    """Wraps a one-argument function with an explicit int32 parameter type."""

    def foo(x):
        return x > 10

    foo = tensorflow_computation.tf_computation(foo, tf.int32)
    signature = foo.type_signature.compact_representation()
    self.assertEqual(signature, '(int32 -> bool)')
def test_invoke_with_no_arg_fn(self):
    """Wraps a no-argument function and checks its type signature."""

    def foo():
        return 10

    foo = tensorflow_computation.tf_computation(foo)
    signature = foo.type_signature.compact_representation()
    self.assertEqual(signature, '( -> int32)')
def one_round_computation(examples):
    """The TFF computation to compute the aggregated IBLT sketch."""
    # Use federated secure modular sum for IBLT sketches, because IBLT
    # sketches are decoded by taking modulo over the field size.
    use_secure_sum = secure_sum_bitwidth is not None
    sketch_sum_fn = (
        secure_modular_sum if use_secure_sum else intrinsics.federated_sum)
    count_sum_fn = secure_sum if use_secure_sum else intrinsics.federated_sum

    timestamp_fn = tensorflow_computation.tf_computation(
        lambda: tf.cast(tf.timestamp(), tf.int64))
    round_timestamp = intrinsics.federated_eval(
        timestamp_fn, placements.SERVER)

    # Each client contributes the value 1 to this sum.
    one_per_client = intrinsics.federated_value(1, placements.CLIENTS)
    clients = count_sum_fn(one_per_client)

    sketch, count_tensor = intrinsics.federated_map(compute_sketch, examples)
    sketch = sketch_sum_fn(sketch)
    count_tensor = count_sum_fn(count_tensor)

    decoded = intrinsics.federated_map(
        decode_heavy_hitters, (sketch, count_tensor))
    (heavy_hitters, heavy_hitters_unique_counts, heavy_hitters_counts,
     num_not_decoded) = decoded

    return intrinsics.federated_zip(
        ServerOutput(
            clients=clients,
            heavy_hitters=heavy_hitters,
            heavy_hitters_unique_counts=heavy_hitters_unique_counts,
            heavy_hitters_counts=heavy_hitters_counts,
            num_not_decoded=num_not_decoded,
            round_timestamp=round_timestamp))
def test_roundtrip(self):
    """Converts a computation to broadcast form and back again.

    Checks the labels, type signatures and outputs of the broadcast form, then
    checks that the rebuilt computation has an equivalent type signature and
    produces the expected sums.
    """
    add = tensorflow_computation.tf_computation(lambda x, y: x + y)
    server_data_type = computation_types.at_server(tf.int32)
    client_data_type = computation_types.at_clients(tf.int32)

    @federated_computation.federated_computation(server_data_type,
                                                 client_data_type)
    def add_server_number_plus_one(server_number, client_numbers):
        one = intrinsics.federated_value(1, placements.SERVER)
        server_context = intrinsics.federated_map(add, (one, server_number))
        client_context = intrinsics.federated_broadcast(server_context)
        return intrinsics.federated_map(add, (client_context, client_numbers))

    bf = form_utils.get_broadcast_form_for_computation(
        add_server_number_plus_one)
    self.assertEqual(bf.server_data_label, 'server_number')
    self.assertEqual(bf.client_data_label, 'client_numbers')

    expected_server_context_type = computation_types.FunctionType(
        tf.int32, (tf.int32,))
    type_test_utils.assert_types_equivalent(
        bf.compute_server_context.type_signature,
        expected_server_context_type)
    self.assertEqual(2, bf.compute_server_context(1)[0])

    expected_client_processing_type = computation_types.FunctionType(
        ((tf.int32,), tf.int32), tf.int32)
    type_test_utils.assert_types_equivalent(
        bf.client_processing.type_signature,
        expected_client_processing_type)
    self.assertEqual(3, bf.client_processing((1,), 2))

    round_trip_comp = form_utils.get_computation_for_broadcast_form(bf)
    type_test_utils.assert_types_equivalent(
        round_trip_comp.type_signature,
        add_server_number_plus_one.type_signature)
    # 2 (server data) + 1 (constant in comp) + 2 (client data) = 5 (output)
    round_trip_output = round_trip_comp(2, [2, 3, 4])
    self.assertEqual([5, 6, 7], round_trip_output)
def init_fn():
    """Zips the initial optimizer state with the initial aggregator state.

    The optimizer state is evaluated at the server from
    `optimizer.initialize` applied to the trainable weight specs.
    """
    specs = weight_tensor_specs.trainable
    initialize_optimizer = tensorflow_computation.tf_computation(
        lambda: optimizer.initialize(specs))
    optimizer_state = intrinsics.federated_eval(
        initialize_optimizer, placements.SERVER)
    aggregator_state = full_gradient_aggregator.initialize()
    initial_state = (optimizer_state, aggregator_state)
    return intrinsics.federated_zip(initial_state)
def test_next_return_namedtuple_raises(self):
    """A look-alike namedtuple result is not accepted as MeasuredProcessOutput."""
    field_names = ['state', 'result', 'measurements']
    measured_process_output = collections.namedtuple(
        'MeasuredProcessOutput', field_names)

    wrap_next = tensorflow_computation.tf_computation(tf.int32)
    namedtuple_next_fn = wrap_next(
        lambda state: measured_process_output(state, (), ()))

    with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
        measured_process.MeasuredProcess(
            test_initialize_fn, namedtuple_next_fn)
def test_roundtrip_no_broadcast(self):
    """Round-trips a computation that ignores its server argument.

    Checks the labels, type signatures and output of the broadcast form, then
    checks the computation rebuilt from that form.
    """
    add_five = tensorflow_computation.tf_computation(lambda x: x + 5)
    server_data_type = computation_types.at_server(())
    client_data_type = computation_types.at_clients(tf.int32)

    @federated_computation.federated_computation(server_data_type,
                                                 client_data_type)
    def add_five_at_clients(naught_at_server, client_numbers):
        del naught_at_server
        return intrinsics.federated_map(add_five, client_numbers)

    bf = form_utils.get_broadcast_form_for_computation(add_five_at_clients)
    self.assertEqual(bf.server_data_label, 'naught_at_server')
    self.assertEqual(bf.client_data_label, 'client_numbers')

    expected_server_context_type = computation_types.FunctionType((), ())
    type_test_utils.assert_types_equivalent(
        bf.compute_server_context.type_signature,
        expected_server_context_type)

    expected_client_processing_type = computation_types.FunctionType(
        ((), tf.int32), tf.int32)
    type_test_utils.assert_types_equivalent(
        bf.client_processing.type_signature,
        expected_client_processing_type)
    self.assertEqual(6, bf.client_processing((), 1))

    round_trip_comp = form_utils.get_computation_for_broadcast_form(bf)
    type_test_utils.assert_types_equivalent(
        round_trip_comp.type_signature, add_five_at_clients.type_signature)
    round_trip_output = round_trip_comp((), [5, 6, 7])
    self.assertEqual([10, 11, 12], round_trip_output)
def update_state(state, value_min, value_max):
    """Feeds `max(|value_min|, |value_max|)` into `process.next`.

    The maximum is cast to `expected_dtype` before being passed on.
    """

    abs_max_fn = tensorflow_computation.tf_computation(
        lambda x, y: tf.cast(tf.maximum(tf.abs(x), tf.abs(y)), expected_dtype))
    min_and_max = (value_min, value_max)
    abs_value_max = intrinsics.federated_map(abs_max_fn, min_and_max)
    return process.next(state, abs_value_max)
def next_fn(strings, val):
    """Appends 'abc' to the string state and sums `val`.

    Returns:
      A `MeasuredProcessOutput` whose measurements are the value 1 placed at
      the server.
    """
    new_state_fn = tensorflow_computation.tf_computation()(
        lambda s: tf.concat([s, tf.constant(['abc'])], axis=0))
    new_state = intrinsics.federated_map(new_state_fn, strings)
    summed = intrinsics.federated_sum(val)
    measurements = intrinsics.federated_value(1, placements.SERVER)
    return MeasuredProcessOutput(new_state, summed, measurements)
def _create_next_fn(self, inner_agg_next, state_type, value_type):
    """Builds the federated `next` computation for modular clipping.

    Args:
      inner_agg_next: Callable taking `(state, clipped_value)`; its output must
        expose `state`, `result` and `measurements`.
      state_type: Type of the `state` argument of the returned computation.
      value_type: Type of the client value; it is placed at clients here.

    Returns:
      A federated computation mapping `(state, value)` to a
      `MeasuredProcessOutput`.
    """
    modular_clip_by_value_fn = tensorflow_computation.tf_computation(
        _modular_clip_by_value)
    client_value_type = computation_types.at_clients(value_type)

    @federated_computation.federated_computation(state_type, client_value_type)
    def next_fn(state, value):
        clip_lower = intrinsics.federated_value(
            self._clip_range_lower, placements.SERVER)
        clip_upper = intrinsics.federated_value(
            self._clip_range_upper, placements.SERVER)

        # Modular clip values before aggregation.
        lower_at_clients = intrinsics.federated_broadcast(clip_lower)
        upper_at_clients = intrinsics.federated_broadcast(clip_upper)
        clipped_value = intrinsics.federated_map(
            modular_clip_by_value_fn,
            (value, lower_at_clients, upper_at_clients))

        inner_agg_output = inner_agg_next(state, clipped_value)

        # Clip the aggregate to the same range again (not considering summands).
        clipped_agg_output_result = intrinsics.federated_map(
            modular_clip_by_value_fn,
            (inner_agg_output.result, clip_lower, clip_upper))

        measurements = collections.OrderedDict(
            modclip=inner_agg_output.measurements)
        zipped_measurements = intrinsics.federated_zip(measurements)
        return measured_process.MeasuredProcessOutput(
            state=inner_agg_output.state,
            result=clipped_agg_output_result,
            measurements=zipped_measurements)

    return next_fn
def test_next_state_not_assignable_tuple_result(self):
    """A `next` returning a float32 state in a tuple result is rejected."""
    wrap_next = tensorflow_computation.tf_computation(tf.float32, tf.float32)
    float_next_fn = wrap_next(
        lambda state, x: (tf.cast(state, tf.float32), x))

    with self.assertRaises(errors.TemplateStateNotAssignableError):
        estimation_process.EstimationProcess(
            test_initialize_fn, float_next_fn, test_report_fn)
def test_map_estimate_not_assignable(self):
    """`map` with an int32 identity computation raises EstimateNotAssignableError."""
    wrap_map = tensorflow_computation.tf_computation(tf.int32)
    map_fn = wrap_map(lambda estimate: estimate)
    process = estimation_process.EstimationProcess(
        test_initialize_fn, test_next_fn, test_report_fn)

    with self.assertRaises(estimation_process.EstimateNotAssignableError):
        process.map(map_fn)
def test_takes_tuple_typed(self):
    """Wraps a `tf.function` that takes a pair of int32 values as one argument."""

    @tf.function
    def foo(t):
        return t[0] + t[1]

    pair_type = (tf.int32, tf.int32)
    foo = tensorflow_computation.tf_computation(foo, pair_type)
    signature = foo.type_signature.compact_representation()
    self.assertEqual(signature, '(<int32,int32> -> int32)')
def _run_in_tf_computation(optimizer, spec):
    """Runs three optimizer steps through TF computations.

    Weights and gradients are both all-ones structures built from the shapes
    and dtypes in `spec`.

    Returns:
      A `(state_history, weights_history)` pair of lists, each holding the
      initial value followed by the value after each of the three steps.
    """

    def ones_like_spec():
        return tf.nest.map_structure(
            lambda s: tf.ones(s.shape, s.dtype), spec)

    weights = ones_like_spec()
    gradients = ones_like_spec()

    init_fn = tensorflow_computation.tf_computation(
        lambda: optimizer.initialize(spec))
    next_fn = tensorflow_computation.tf_computation(optimizer.next)

    state = init_fn()
    state_history, weights_history = [state], [weights]
    num_steps = 3
    for _ in range(num_steps):
        state, weights = next_fn(state, weights, gradients)
        state_history.append(state)
        weights_history.append(weights)
    return state_history, weights_history
def next_fn(state, weights, updates):
    """Adds `updates` to the trainable weights and leaves `state` unchanged.

    The result is a zipped `ModelWeights` whose second field is an empty
    tuple; the measurements come from `empty_at_server()`.
    """
    add_fn = tensorflow_computation.tf_computation(lambda x, y: x + y)
    updated_trainable = intrinsics.federated_map(
        add_fn, (weights.trainable, updates))
    new_weights = intrinsics.federated_zip(
        model_utils.ModelWeights(updated_trainable, ()))
    measurements = empty_at_server()
    return measured_process.MeasuredProcessOutput(
        state, new_weights, measurements)
def test_non_federated_init_next_raises(self):
    """Plain TF computations for init/next raise AggregationNotFederatedError."""
    initialize_fn = tensorflow_computation.tf_computation(lambda: 0)

    @tensorflow_computation.tf_computation(tf.int32, tf.float32)
    def next_fn(state, val):
        return MeasuredProcessOutput(state, val, ())

    expected_error = aggregation_process.AggregationNotFederatedError
    with self.assertRaises(expected_error):
        aggregation_process.AggregationProcess(initialize_fn, next_fn)
def next_comp(state, value):
    """Maps `_add_one` over the state and broadcasts `value` as the result.

    The measurements are the global norm of the flattened `value` plus 3.0.
    """
    new_state = intrinsics.federated_map(_add_one, state)
    broadcast_value = intrinsics.federated_broadcast(value)

    # Arbitrary metrics for testing.
    norm_plus_three_fn = tensorflow_computation.tf_computation(
        lambda v: tf.linalg.global_norm(tf.nest.flatten(v)) + 3.0)
    test_measurements = intrinsics.federated_map(norm_plus_three_fn, value)

    return measured_process.MeasuredProcessOutput(
        state=new_state,
        result=broadcast_value,
        measurements=test_measurements)
def _create_test_iterative_process(state_type, state_init):
    """Builds an IterativeProcess whose `next` returns its state unchanged.

    Args:
      state_type: Type used for the single `state` argument of `next`.
      state_init: Value wrapped in `tf.constant` by the initializer.
    """

    @tensorflow_computation.tf_computation(state_type)
    def next_fn(state):
        return state

    initialize_fn = tensorflow_computation.tf_computation(
        lambda: tf.constant(state_init))
    return iterative_process.IterativeProcess(
        initialize_fn=initialize_fn, next_fn=next_fn)
def _test_map_reduce_form_computations():
    """Builds a fixed set of stub TF computations for tests.

    Each computation ignores its arguments and returns constants.

    Returns:
      The tuple `(initialize, prepare, work, zero, accumulate, merge, report,
      bitwidth, max_input, modulus, update)`, where `bitwidth`, `max_input` and
      `modulus` are all the same computation returning an empty list.
    """

    @tensorflow_computation.tf_computation
    def initialize():
        return tf.constant(0)

    @tensorflow_computation.tf_computation(tf.int32)
    def prepare(server_state):
        del server_state  # Unused
        return tf.constant(1.0)

    @tensorflow_computation.tf_computation(
        computation_types.SequenceType(tf.float32), tf.float32)
    def work(client_data, client_input):
        del client_data, client_input  # Unused
        return True, [], [], []

    @tensorflow_computation.tf_computation
    def zero():
        return tf.constant(0), tf.constant(0)

    @tensorflow_computation.tf_computation((tf.int32, tf.int32), tf.bool)
    def accumulate(accumulator, client_update):
        del accumulator, client_update  # Unused
        return tf.constant(1), tf.constant(1)

    @tensorflow_computation.tf_computation((tf.int32, tf.int32),
                                           (tf.int32, tf.int32))
    def merge(accumulator1, accumulator2):
        del accumulator1, accumulator2  # Unused
        return tf.constant(1), tf.constant(1)

    @tensorflow_computation.tf_computation(tf.int32, tf.int32)
    def report(accumulator):
        del accumulator  # Unused
        return tf.constant(1.0)

    # One shared computation stands in for the three secure-sum parameters.
    unit_comp = tensorflow_computation.tf_computation(lambda: [])
    bitwidth = max_input = modulus = unit_comp

    unit_type = computation_types.to_type([])

    @tensorflow_computation.tf_computation(
        tf.int32, (tf.float32, unit_type, unit_type, unit_type))
    def update(server_state, global_update):
        del server_state, global_update  # Unused
        return tf.constant(1), []

    return (initialize, prepare, work, zero, accumulate, merge, report,
            bitwidth, max_input, modulus, update)
def test_invoke_with_polymorphic_lambda(self):
    """Specializes an untyped lambda computation for int32 and float32."""
    foo = tensorflow_computation.tf_computation(lambda x: x > 10)

    expected_signatures = (
        (tf.int32, '(int32 -> bool)'),
        (tf.float32, '(float32 -> bool)'),
    )
    for dtype, expected in expected_signatures:
        concrete_fn = foo.fn_for_argument_type(
            computation_types.TensorType(dtype))
        signature = concrete_fn.type_signature.compact_representation()
        self.assertEqual(signature, expected)
def hierarchical_histogram_computation(federated_client_data):
    """Runs one `process` step on the client histograms and timestamps it.

    Returns:
      A zipped `ServerOutput` holding the `result` of `process.next` and an
      int64 timestamp evaluated at the server.
    """
    timestamp_fn = tensorflow_computation.tf_computation(
        lambda: tf.cast(tf.timestamp(), tf.int64))
    round_timestamp = intrinsics.federated_eval(
        timestamp_fn, placements.SERVER)

    client_histogram = intrinsics.federated_map(
        client_work, federated_client_data)

    # The process starts from a fresh initial state on every call.
    initial_state = process.initialize()
    process_output = process.next(initial_state, client_histogram)

    return intrinsics.federated_zip(
        ServerOutput(process_output.result, round_timestamp))
def _compute_measurements(self, upper_bound, lower_bound, value_max,
                          value_min):
    """Creates measurements to be reported. All values are summed securely.

    A client counts as upper-clipped when the broadcast upper bound is below
    its `value_max`, and as lower-clipped when the broadcast lower bound is
    above its `value_min`. Both indicators are cast to `COUNT_TF_TYPE` and
    summed with a bitwidth of 1.
    """
    exceeds_upper_fn = tensorflow_computation.tf_computation(
        lambda bound, value: tf.cast(bound < value, COUNT_TF_TYPE))
    upper_at_clients = intrinsics.federated_broadcast(upper_bound)
    is_max_clipped = intrinsics.federated_map(
        exceeds_upper_fn, (upper_at_clients, value_max))
    max_clipped_count = intrinsics.federated_secure_sum_bitwidth(
        is_max_clipped, bitwidth=1)

    below_lower_fn = tensorflow_computation.tf_computation(
        lambda bound, value: tf.cast(bound > value, COUNT_TF_TYPE))
    lower_at_clients = intrinsics.federated_broadcast(lower_bound)
    is_min_clipped = intrinsics.federated_map(
        below_lower_fn, (lower_at_clients, value_min))
    min_clipped_count = intrinsics.federated_secure_sum_bitwidth(
        is_min_clipped, bitwidth=1)

    measurements = collections.OrderedDict(
        secure_upper_clipped_count=max_clipped_count,
        secure_lower_clipped_count=min_clipped_count,
        secure_upper_threshold=upper_bound,
        secure_lower_threshold=lower_bound)
    return intrinsics.federated_zip(measurements)
def test_takes_namedtuple_typed(self):
    """Wraps a `tf.function` whose argument is typed by a namedtuple.

    The wrapped function also asserts that it receives a `MyType` instance.
    """
    MyType = collections.namedtuple('MyType', ['x', 'y'])  # pylint: disable=invalid-name

    @tf.function
    def foo(x):
        self.assertIsInstance(x, MyType)
        return x.x + x.y

    parameter_type = MyType(tf.int32, tf.int32)
    foo = tensorflow_computation.tf_computation(foo, parameter_type)
    signature = foo.type_signature.compact_representation()
    self.assertEqual(signature, '(<x=int32,y=int32> -> int32)')