def test_passes_with_noarg_top_level_computation(self):
  """A form whose `up_to_merge` takes no argument can be constructed."""
  ctx_lib = mergeable_comp_execution_context
  noarg_counter = build_noarg_count_clients_computation()
  whimsy_merge = build_whimsy_merge_computation(tf.int32)
  whimsy_after = build_whimsy_after_merge_computation(
      noarg_counter.type_signature.parameter,
      whimsy_merge.type_signature.result)
  form = ctx_lib.MergeableCompForm(
      up_to_merge=noarg_counter,
      merge=whimsy_merge,
      after_merge=whimsy_after)
  self.assertIsInstance(form, ctx_lib.MergeableCompForm)
def test_passes_with_correct_signatures(self):
  """A form with mutually consistent signatures can be constructed."""
  ctx_lib = mergeable_comp_execution_context
  server_int = computation_types.at_server(tf.int32)
  clients_int = computation_types.at_clients(tf.int32)
  summing_comp = build_sum_client_arg_computation(server_int, clients_int)
  whimsy_merge = build_whimsy_merge_computation(tf.int32)
  whimsy_after = build_whimsy_after_merge_computation(
      summing_comp.type_signature.parameter,
      whimsy_merge.type_signature.result)
  form = ctx_lib.MergeableCompForm(
      up_to_merge=summing_comp,
      merge=whimsy_merge,
      after_merge=whimsy_after)
  self.assertIsInstance(form, ctx_lib.MergeableCompForm)
def test_raises_no_top_level_argument_in_after_agg(self):
  """An `after_merge` that omits the original top-level argument is rejected."""
  summing_comp = build_sum_client_arg_computation(
      computation_types.at_server(tf.int32),
      computation_types.at_clients(tf.int32))
  whimsy_merge = build_whimsy_merge_computation(tf.int32)
  merged_at_server = computation_types.at_server(
      whimsy_merge.type_signature.result)

  # Accepts only the merged value, not (original_arg, merged_value).
  @federated_computation.federated_computation(merged_at_server)
  def bad_after_merge(x):
    return x

  with self.assertRaises(computation_types.TypeNotAssignableError):
    mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=summing_comp,
        merge=whimsy_merge,
        after_merge=bad_after_merge)
def test_raises_up_to_merge_returns_non_server_placed_result(self):
  """An `up_to_merge` whose result is not SERVER-placed is rejected."""

  # Use the same `federated_computation.federated_computation` decorator as the
  # rest of this file rather than a separate `computations` alias, so the test
  # depends only on a name that the surrounding code already relies on.
  @federated_computation.federated_computation(
      computation_types.at_server(tf.int32))
  def bad_up_to_merge(x):
    # Returns a tuple of values rather than a single SERVER-placed value.
    return x, x

  merge = build_whimsy_merge_computation(tf.int32)
  after_merge = build_whimsy_after_merge_computation(
      bad_up_to_merge.type_signature.parameter, merge.type_signature.result)
  with self.assertRaises(mergeable_comp_execution_context.UpToMergeTypeError):
    mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=bad_up_to_merge, merge=merge, after_merge=after_merge)
def test_raises_mismatched_up_to_merge_and_merge(self):
  """A `merge` whose types disagree with `up_to_merge` is rejected."""
  ctx_lib = mergeable_comp_execution_context
  summing_comp = build_sum_client_arg_computation(
      computation_types.at_server(tf.int32),
      computation_types.at_clients(tf.int32))
  # The merge is built for float32, while `up_to_merge` was built from int32
  # types.
  float_merge = build_whimsy_merge_computation(tf.float32)

  @federated_computation.federated_computation(
      summing_comp.type_signature.parameter,
      computation_types.at_server(float_merge.type_signature.result))
  def after_merge(x, y):
    return (x, y)

  with self.assertRaises(ctx_lib.MergeTypeNotAssignableError):
    ctx_lib.MergeableCompForm(
        up_to_merge=summing_comp,
        merge=float_merge,
        after_merge=after_merge)
def test_computes_sum_of_all_values(self, arg, expected_sum):
  """Running the sum form on five executor stacks yields `expected_sum`."""
  ctx_lib = mergeable_comp_execution_context
  summing_comp = build_sum_client_arg_computation(
      computation_types.at_server(tf.int32),
      computation_types.at_clients(tf.int32))
  sum_merge = build_sum_merge_computation(tf.int32)
  combine_after = build_sum_merge_with_first_arg_computation(
      summing_comp.type_signature.parameter, sum_merge.type_signature.result)
  form = ctx_lib.MergeableCompForm(
      up_to_merge=summing_comp,
      merge=sum_merge,
      after_merge=combine_after)

  # One independent local executor factory per stack.
  executor_factories = [
      python_executor_stacks.local_executor_factory() for _ in range(5)
  ]
  context = ctx_lib.MergeableCompExecutionContext(executor_factories)

  want = type_conversions.type_to_py_container(
      expected_sum, combine_after.type_signature.result)
  got = context.invoke(form, arg)
  self.assertEqual(want, got)
def test_raises_merge_computation_not_assignable_result(self):
  """A `merge` with a result type that does not fit is rejected."""
  ctx_lib = mergeable_comp_execution_context
  summing_comp = build_sum_client_arg_computation(
      computation_types.at_server(tf.int32),
      computation_types.at_clients(tf.int32))

  @tensorflow_computation.tf_computation(tf.int32, tf.int32)
  def bad_merge(x, y):
    del x, y  # Unused.
    return 1.0  # A float result, although both inputs are int32.

  @federated_computation.federated_computation(
      summing_comp.type_signature.parameter,
      computation_types.at_server(bad_merge.type_signature.result))
  def after_merge(x, y):
    return (x, y)

  with self.assertRaises(ctx_lib.MergeTypeNotAssignableError):
    ctx_lib.MergeableCompForm(
        up_to_merge=summing_comp,
        merge=bad_merge,
        after_merge=after_merge)
def test_counts_clients_with_noarg_computation(self):
  """Client counts from several executor stacks are merged into one total."""
  num_clients = 100
  num_executors = 5
  # Floor division keeps this an int without a float round-trip. Deriving the
  # expected total from the per-executor count (below) keeps the assertion
  # correct even if `num_clients` is not a multiple of `num_executors`.
  clients_per_executor = num_clients // num_executors

  up_to_merge = build_noarg_count_clients_computation()
  merge = build_sum_merge_computation(tf.int32)
  after_merge = build_return_merge_result_with_no_first_arg_computation(
      merge.type_signature.result)
  mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
      up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)

  ex_factories = [
      python_executor_stacks.local_executor_factory(
          default_num_clients=clients_per_executor)
      for _ in range(num_executors)
  ]
  mergeable_comp_context = (
      mergeable_comp_execution_context.MergeableCompExecutionContext(
          ex_factories))

  expected_result = clients_per_executor * num_executors
  result = mergeable_comp_context.invoke(mergeable_comp_form, None)
  self.assertEqual(result, expected_result)
def test_raises_with_aggregation_in_after_agg(self):
  """An aggregation inside `after_merge` is rejected, naming the intrinsic."""
  ctx_lib = mergeable_comp_execution_context
  server_int = computation_types.at_server(tf.int32)
  clients_int = computation_types.at_clients(tf.int32)
  summing_comp = build_sum_client_arg_computation(server_int, clients_int)
  whimsy_merge = build_whimsy_merge_computation(tf.int32)

  @federated_computation.federated_computation(
      summing_comp.type_signature.parameter,
      computation_types.at_server(whimsy_merge.type_signature.result))
  def after_merge_with_sum(original_arg, merged_arg):
    del merged_arg  # Unused.
    # The second element of the original argument is the clients-placed value.
    clients_value = original_arg[1]
    return intrinsics.federated_sum(clients_value)

  with self.assertRaisesRegex(ctx_lib.AfterMergeStructureError,
                              'federated_sum'):
    ctx_lib.MergeableCompForm(
        up_to_merge=summing_comp,
        merge=whimsy_merge,
        after_merge=after_merge_with_sum)
def test_runs_computation_with_clients_placed_return_values(self, arg):
  """Clients-placed values returned by `after_merge` come back intact."""
  ctx_lib = mergeable_comp_execution_context
  summing_comp = build_sum_client_arg_computation(
      computation_types.at_server(tf.int32),
      computation_types.at_clients(tf.int32))
  whimsy_merge = build_whimsy_merge_computation(tf.int32)
  # The whimsy after-merge simply returns the original argument.
  echo_after = build_whimsy_after_merge_computation(
      summing_comp.type_signature.parameter,
      whimsy_merge.type_signature.result)
  form = ctx_lib.MergeableCompForm(
      up_to_merge=summing_comp,
      merge=whimsy_merge,
      after_merge=echo_after)

  executor_factories = [
      python_executor_stacks.local_executor_factory() for _ in range(5)
  ]
  context = ctx_lib.MergeableCompExecutionContext(executor_factories)

  # Package as a Struct up front: type_to_py_container short-circuits when it
  # is handed a non-Struct argument.
  struct_arg = structure.Struct.unnamed(*arg)
  expected = type_conversions.type_to_py_container(
      struct_arg, echo_after.type_signature.result)
  actual = context.invoke(form, struct_arg)
  self.assertEqual(actual, expected)
def compile_to_mergeable_comp_form(
    comp: computation_impl.ConcreteComputation
) -> mergeable_comp_execution_context.MergeableCompForm:
  """Compiles a computation with a single aggregation to `MergeableCompForm`.

  Compilation splits `comp` on its lone `federated_aggregate` and rebuilds it
  as three pieces:

  * `up_to_merge`: everything before the aggregation, plus the aggregation
    itself with an identity report. Its SERVER result is therefore the merged
    accumulator rather than the reported value.
  * `merge`: the aggregation's own merge computation.
  * `after_merge`: applies the aggregation's original report to the merged
    accumulator, then runs the remainder of `comp`. Its result type is
    re-annotated to `comp`'s declared result type.

  Args:
    comp: Instance of `computation_impl.ConcreteComputation` to compile.
      Assumed to be representable as a computation with a single aggregation in
      its body. For example, two parallel aggregations are allowed, but
      multiple dependent aggregations are disallowed. Additionally assumed to
      be of functional type.

  Returns:
    A semantically equivalent instance of
    `mergeable_comp_execution_context.MergeableCompForm`.

  Raises:
    TypeError: If `comp` is not of functional TFF type (presumably raised by
      `_ensure_lambda`; TODO confirm).
    ValueError: If `comp` cannot be represented as a computation with at most
      one aggregation in its body (presumably raised by
      `tree_analysis.check_aggregate_not_dependent_on_aggregate` or by the
      split step; TODO confirm).
  """
  # Kept so that `after_merge` can be re-annotated with the caller-visible
  # result type at the end.
  original_return_type = comp.type_signature.result
  building_block = comp.to_building_block()
  lam = _ensure_lambda(building_block)
  # Replace intrinsics with their lower-level bodies. The second return value
  # is unused here.
  lowered_bb, _ = tree_transformations.replace_intrinsics_with_bodies(lam)

  # We transform the body of this computation to easily preserve the top-level
  # lambda required by force-aligning.
  call_dominant_body_bb = transformations.to_call_dominant(lowered_bb.result)
  call_dominant_bb = building_blocks.Lambda(lowered_bb.parameter_name,
                                            lowered_bb.parameter_type,
                                            call_dominant_body_bb)

  # This check should not throw false positives because we just ensured we are
  # in call-dominant form.
  tree_analysis.check_aggregate_not_dependent_on_aggregate(call_dominant_bb)

  # Split around the federated_aggregate. `before_agg` produces the
  # aggregate's arguments under the 'federated_aggregate_param' key (indexed
  # below). `after_agg` consumes the aggregate's result.
  before_agg, after_agg = transformations.force_align_and_split_by_intrinsics(
      call_dominant_bb,
      [building_block_factory.create_null_federated_aggregate()])

  # Construct a report function which accepts the result of merge.
  # Position 3 of the aggregate's arguments is the merge function
  # (value, zero, accumulate, merge, report). An identity over its result type
  # lets `up_to_merge` emit the merged accumulator unchanged. The real report
  # is applied later, in `after_merge`.
  merge_fn_type = before_agg.type_signature.result[
      'federated_aggregate_param'][3]
  identity_report = computation_impl.ConcreteComputation.from_building_block(
      building_block_factory.create_compiled_identity(merge_fn_type.result))

  zero_comp, accumulate_comp, merge_comp, report_comp = _extract_federated_aggregate_computations(before_agg)

  # Wrap the split halves so they can be invoked from traced federated
  # computations below.
  before_agg_callable = computation_impl.ConcreteComputation.from_building_block(
      before_agg)
  after_agg_callable = computation_impl.ConcreteComputation.from_building_block(
      after_agg)

  if before_agg.type_signature.parameter is not None:
    # TODO(b/147499373): If None-arguments were uniformly represented as empty
    # tuples, we would be able to avoid this (and related) ugly casing.

    @federated_computation.federated_computation(
        before_agg.type_signature.parameter)
    def up_to_merge_computation(arg):
      # Re-issue the aggregation with the identity report so its SERVER result
      # is the merged, not-yet-reported accumulator.
      federated_aggregate_args = before_agg_callable(arg)['federated_aggregate_param']
      value_to_aggregate = federated_aggregate_args[0]
      zero = zero_comp()
      return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                            accumulate_comp, merge_comp,
                                            identity_report)

    @federated_computation.federated_computation(
        before_agg.type_signature.parameter,
        computation_types.at_server(identity_report.type_signature.result))
    def after_merge_computation(top_level_arg, merge_result):
      # Apply the original report, then run the post-aggregation logic with
      # the original argument and the list of intrinsic results.
      reported_result = intrinsics.federated_map(report_comp, merge_result)
      return after_agg_callable(top_level_arg, [reported_result])

  else:

    @federated_computation.federated_computation()
    def up_to_merge_computation():
      # Same as the argument-taking variant above, but `before_agg` takes no
      # argument.
      federated_aggregate_args = before_agg_callable()['federated_aggregate_param']
      value_to_aggregate = federated_aggregate_args[0]
      zero = zero_comp()
      return intrinsics.federated_aggregate(value_to_aggregate, zero,
                                            accumulate_comp, merge_comp,
                                            identity_report)

    @federated_computation.federated_computation(
        computation_types.at_server(identity_report.type_signature.result))
    def after_merge_computation(merge_result):
      reported_result = intrinsics.federated_map(report_comp, merge_result)
      # NOTE(review): with no top-level argument, `after_agg` is called with a
      # doubly nested list instead of (arg, [result]); confirm this matches
      # `after_agg`'s parameter structure in the no-argument case.
      return after_agg_callable([[reported_result]])

  # Re-annotate `after_merge` so its declared result type is exactly the
  # original computation's result type.
  annotated_type_signature = computation_types.FunctionType(
      after_merge_computation.type_signature.parameter, original_return_type)
  after_merge_computation = computation_impl.ConcreteComputation.with_type(
      after_merge_computation, annotated_type_signature)

  return mergeable_comp_execution_context.MergeableCompForm(
      up_to_merge=up_to_merge_computation,
      merge=merge_comp,
      after_merge=after_merge_computation)