def test_init_does_not_raise_type_error(self):
    """Checks that building a `BroadcastForm` from the test helpers succeeds."""
    # The helper is invoked outside the `try` so that only errors coming from
    # the `BroadcastForm` constructor itself are turned into test failures.
    server_context_fn, client_fn = _test_broadcast_form_computations()
    try:
        forms.BroadcastForm(server_context_fn, client_fn)
    except TypeError:
        self.fail('Raised TypeError unexpectedly.')
def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.BroadcastForm:
    """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

    Args:
      comp: An instance of `tff.Computation` that is compatible with broadcast
        form. Computations are only compatible if they take in a single value
        placed at server, return a single value placed at clients, and do not
        contain any aggregations.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization of the Tensorflow graphs backing the
        resulting `tff.backends.mapreduce.BroadcastForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
        is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
      provided `tff.Computation`.

    Raises:
      ValueError: If `comp` contains one or more aggregations after its
        intrinsics have been reduced.
    """
    # Validate the inputs before doing any transformation work.
    py_typecheck.check_type(comp, computation_base.Computation)
    _check_function_signature_compatible_with_broadcast_form(
        comp.type_signature)
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    # Normalize the computation's AST prior to analysing and splitting it.
    bb = comp.to_building_block()
    bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(bb)
    bb = _replace_lambda_body_with_call_dominant_form(bb)

    tree_analysis.check_contains_only_reducible_intrinsics(bb)
    aggregations = tree_analysis.find_aggregations_in_tree(bb)
    if aggregations:
        # Every segment that contains a placeholder must be an f-string,
        # otherwise the placeholder text is emitted verbatim.
        raise ValueError(
            '`get_broadcast_form_for_computation` called with computation '
            f'containing {len(aggregations)} aggregations, but broadcast form '
            'does not allow aggregation. '
            f'Full list of aggregations:\n{aggregations}'
        )

    # Split into the server-side part (before the broadcast) and the
    # client-side part (after the broadcast), then extract each half.
    before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
    compute_server_context = _extract_compute_server_context(
        before_broadcast, grappler_config)
    client_processing = _extract_client_processing(after_broadcast,
                                                   grappler_config)

    # Wrap both building blocks back into computations.
    compute_server_context, client_processing = (
        computation_wrapper_instances.building_block_to_computation(block)
        for block in (compute_server_context, client_processing))

    # The unpacking below requires the parameter to carry exactly two names:
    # the server data label followed by the client data label.
    comp_param_names = structure.name_list_with_nones(
        comp.type_signature.parameter)
    server_data_label, client_data_label = comp_param_names
    return forms.BroadcastForm(
        compute_server_context,
        client_processing,
        server_data_label=server_data_label,
        client_data_label=client_data_label)
def test_raises_type_error_with_mismatched_context_type(self):
    """Checks that `BroadcastForm` rejects a mismatched context type."""

    @tensorflow_computation.tf_computation(tf.int32)
    def compute_server_context(x):
        return x

    # The first parameter type below (`tf.float32`) deliberately disagrees
    # with the `tf.int32` context produced by `compute_server_context`.
    @tensorflow_computation.tf_computation(tf.float32, tf.int32)
    def client_processing(context, client_data):
        del context, client_data  # Unused.
        return 'some string output on the clients'

    self.assertRaises(
        TypeError,
        forms.BroadcastForm,
        compute_server_context,
        client_processing,
    )