Esempio n. 1
0
 def test_init_does_not_raise_type_error(self):
     """Checks that `forms.BroadcastForm` can be built without a `TypeError`.

     The two computations come from the `_test_broadcast_form_computations`
     helper and are passed positionally to the constructor.
     """
     server_context_comp, client_comp = _test_broadcast_form_computations()
     try:
         forms.BroadcastForm(server_context_comp, client_comp)
     except TypeError:
         # Report a test failure (rather than an error) if construction
         # rejects the pair of computations.
         self.fail('Raised TypeError unexpectedly.')
Esempio n. 2
0
def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.BroadcastForm:
    """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

    Args:
      comp: An instance of `tff.Computation` that is compatible with
        broadcast form. Its type signature is validated by
        `_check_function_signature_compatible_with_broadcast_form`, its
        parameter must unpack into exactly two element names (used as the
        server-data and client-data labels of the result), and it must not
        contain any aggregations.
        NOTE(review): earlier documentation described this as "a single
        value placed at server" in, "a single value placed at clients" out;
        confirm the exact rule against the signature check.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization of the Tensorflow graphs backing the
        resulting `tff.backends.mapreduce.BroadcastForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler
        is bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
      provided `tff.Computation`.

    Raises:
      ValueError: If `comp` contains any aggregations after its intrinsics
        have been replaced with their bodies.
      TypeError: Presumably raised by the `py_typecheck.check_type` calls when
        an argument has the wrong type (verify against `py_typecheck`).
    """
    py_typecheck.check_type(comp, computation_base.Computation)
    _check_function_signature_compatible_with_broadcast_form(
        comp.type_signature)
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    # Normalize the computation's AST before analysing and splitting it.
    bb = comp.to_building_block()
    bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(bb)
    bb = _replace_lambda_body_with_call_dominant_form(bb)

    tree_analysis.check_contains_only_reducible_intrinsics(bb)
    aggregations = tree_analysis.find_aggregations_in_tree(bb)
    if aggregations:
        # Every fragment containing a placeholder must be an f-string;
        # otherwise the placeholder is emitted literally and the list of
        # offending aggregations is lost from the error message.
        raise ValueError(
            '`get_broadcast_form_for_computation` called with computation '
            f'containing {len(aggregations)} aggregations, but broadcast form '
            f'does not allow aggregation. Full list of aggregations:\n{aggregations}'
        )

    # Split the AST around the broadcast and turn each half into its own
    # building block.
    before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
    compute_server_context = _extract_compute_server_context(
        before_broadcast, grappler_config)
    client_processing = _extract_client_processing(after_broadcast,
                                                   grappler_config)

    # Convert both building blocks into computations. The loop variable is
    # deliberately not named `bb` so it cannot be confused with the full AST
    # bound above.
    compute_server_context, client_processing = (
        computation_wrapper_instances.building_block_to_computation(block)
        for block in (compute_server_context, client_processing))

    # Carry the original parameter's element names (or `None` where unnamed)
    # over to the resulting form as labels.
    comp_param_names = structure.name_list_with_nones(
        comp.type_signature.parameter)
    server_data_label, client_data_label = comp_param_names
    return forms.BroadcastForm(compute_server_context,
                               client_processing,
                               server_data_label=server_data_label,
                               client_data_label=client_data_label)
Esempio n. 3
0
    def test_raises_type_error_with_mismatched_context_type(self):
        """Checks that `forms.BroadcastForm` rejects mismatched context types.

        `compute_server_context` returns `tf.int32`, while `client_processing`
        declares its first (context) argument as `tf.float32`.
        """

        @tensorflow_computation.tf_computation(tf.int32)
        def compute_server_context(x):
            return x

        # The first argument type below (`tf.float32`) intentionally differs
        # from the `tf.int32` context produced by `compute_server_context`.
        @tensorflow_computation.tf_computation(tf.float32, tf.int32)
        def client_processing(context, client_data):
            # Both inputs are unused; only the declared types matter here.
            del context, client_data
            return 'some string output on the clients'

        self.assertRaises(TypeError, forms.BroadcastForm,
                          compute_server_context, client_processing)