def test_replaces_lambda_to_called_composition_of_tf_blocks_with_tf_of_same_type_named_param(
      self):
    """Two chained TF blocks over a named-struct param collapse into one block.

    The lambda selects element 0 of its `<a=int32, b=float32>` argument with one
    compiled computation, then adds one with a second compiled computation.
    `parse_tff_to_tf` is expected to replace the whole lambda with a single
    `CompiledComputation` of equivalent type that computes the same value.
    """
    named_param = [('a', tf.int32), ('b', tf.float32)]
    select_first = _create_compiled_computation(
        lambda x: x[0], computation_types.StructType(named_param))
    increment = _create_compiled_computation(
        lambda x: x + 1, computation_types.TensorType(tf.int32))
    selected = building_blocks.Call(select_first,
                                    building_blocks.Reference('x', named_param))
    incremented = building_blocks.Call(increment, selected)
    lambda_wrapper = building_blocks.Lambda('x', named_param, incremented)

    parsed, modified = parse_tff_to_tf(lambda_wrapper)
    to_computation = computation_wrapper_instances.building_block_to_computation
    exec_lambda = to_computation(lambda_wrapper)
    exec_tf = to_computation(parsed)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)

    named_arg = {'a': 15, 'b': 16.}
    self.assertEqual(exec_lambda(named_arg), exec_tf(named_arg))
예제 #2
0
    def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
            self):
        """A TF identity applied to a tuple of arg selections reduces to TF.

        The lambda takes `<int32, float32, bool>`, builds `<x[0], x[2]>` and
        passes it to a compiled identity. `parse_tff_to_tf` should replace the
        lambda with one `CompiledComputation` of the same type signature that
        returns the same value.
        """
        identity_tf_block = building_block_factory.create_compiled_identity(
            [tf.int32, tf.bool])
        tuple_ref = building_blocks.Reference('x',
                                              [tf.int32, tf.float32, tf.bool])
        selected_int = building_blocks.Selection(tuple_ref, index=0)
        selected_bool = building_blocks.Selection(tuple_ref, index=2)
        created_tuple = building_blocks.Tuple([selected_int, selected_bool])
        called_tf_block = building_blocks.Call(identity_tf_block,
                                               created_tuple)
        lambda_wrapper = building_blocks.Lambda(
            'x', [tf.int32, tf.float32, tf.bool], called_tf_block)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        # Each executable is built exactly once and reused for the comparison.
        exec_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        exec_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([7, 8., True]), exec_tf([7, 8., True]))
예제 #3
0
    def test_replaces_lambda_to_called_graph_on_selection_from_arg_with_tf_of_same_type_with_names(
            self):
        """A TF identity applied to `x[0]` of a named-struct arg reduces to TF.

        The resulting `CompiledComputation` must have exactly the lambda's type
        signature and return the same value for a named input.
        """
        named_param = [('a', tf.int32), ('b', tf.float32)]
        int_identity = building_block_factory.create_compiled_identity(
            tf.int32)
        first_field = building_blocks.Selection(
            building_blocks.Reference('x', named_param), index=0)
        lambda_wrapper = building_blocks.Lambda(
            'x', named_param, building_blocks.Call(int_identity, first_field))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        named_arg = {'a': 5, 'b': 6.}
        self.assertEqual(exec_lambda(named_arg), exec_tf(named_arg))
예제 #4
0
    def test_replaces_lambda_to_called_composition_of_tf_blocks_with_tf_of_same_type_named_param(
            self):
        """Select-then-increment over a named-struct arg collapses into one TF block.

        The parsed `CompiledComputation` must keep the lambda's type signature
        and produce the same result for a named input.
        """
        named_param = [('a', tf.int32), ('b', tf.float32)]
        select_first = _create_compiled_computation(lambda x: x[0],
                                                    named_param)
        increment = _create_compiled_computation(lambda x: x + 1, tf.int32)
        selected = building_blocks.Call(
            select_first, building_blocks.Reference('x', named_param))
        incremented = building_blocks.Call(increment, selected)
        lambda_wrapper = building_blocks.Lambda('x', named_param, incremented)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        named_arg = {'a': 15, 'b': 16.}
        self.assertEqual(exec_lambda(named_arg), exec_tf(named_arg))
예제 #5
0
    def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
            self):
        """A named tuple of TF identity calls on arg elements reduces to TF.

        The lambda returns `<a=id_int(x[0]), b=id_float(x[1])>`. The parsed
        `CompiledComputation` must have the same type signature and value.
        """
        make_identity = building_block_factory.create_compiled_identity
        int_identity = make_identity(tf.int32)
        float_identity = make_identity(tf.float32)
        tuple_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
        selected_int, selected_float = (
            building_blocks.Selection(tuple_ref, index=i) for i in range(2))

        named_calls = building_blocks.Tuple([
            ('a', building_blocks.Call(int_identity, selected_int)),
            ('b', building_blocks.Call(float_identity, selected_float)),
        ])
        lambda_wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                                named_calls)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([13, 14.]), exec_tf([13, 14.]))
    def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
            self):
        """A named Struct of TF identity calls on arg elements reduces to TF.

        Identity blocks are built from explicit `TensorType`s. The parsed block's
        type only needs to be equivalent (not equal) to the lambda's type.
        """
        int_identity = building_block_factory.create_compiled_identity(
            computation_types.TensorType(tf.int32))
        float_identity = building_block_factory.create_compiled_identity(
            computation_types.TensorType(tf.float32))
        tuple_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
        selected_int, selected_float = (
            building_blocks.Selection(tuple_ref, index=i) for i in range(2))

        named_calls = building_blocks.Struct([
            ('a', building_blocks.Call(int_identity, selected_int)),
            ('b', building_blocks.Call(float_identity, selected_float)),
        ])
        lambda_wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                                named_calls)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(
            lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([13, 14.]), exec_tf([13, 14.]))
    def test_replaces_lambda_to_selection_from_called_graph_with_tf_of_same_type(
            self):
        """Selecting element 1 from a TF identity call result reduces to TF."""
        pair_type = computation_types.StructType([tf.int32, tf.float32])
        pair_identity = building_block_factory.create_compiled_identity(
            pair_type)
        arg_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
        second_of_call = building_blocks.Selection(
            building_blocks.Call(pair_identity, arg_ref), index=1)
        lambda_wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                                second_of_call)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(
            lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([0, 1.]), exec_tf([0, 1.]))
예제 #8
0
    def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
            self):
        """A TF identity applied to a Struct of arg selections reduces to TF.

        The lambda takes `<int32, float32, bool>`, builds `<x[0], x[2]>` and
        passes it to a compiled identity over `StructType([int32, bool])`.
        `parse_tff_to_tf` should replace the lambda with one
        `CompiledComputation` of equivalent type that returns the same value.
        """
        identity_tf_block_type = computation_types.StructType(
            [tf.int32, tf.bool])
        identity_tf_block = building_block_factory.create_compiled_identity(
            identity_tf_block_type)
        tuple_ref = building_blocks.Reference('x',
                                              [tf.int32, tf.float32, tf.bool])
        selected_int = building_blocks.Selection(tuple_ref, index=0)
        selected_bool = building_blocks.Selection(tuple_ref, index=2)
        created_tuple = building_blocks.Struct([selected_int, selected_bool])
        called_tf_block = building_blocks.Call(identity_tf_block,
                                               created_tuple)
        lambda_wrapper = building_blocks.Lambda(
            'x', [tf.int32, tf.float32, tf.bool], called_tf_block)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        # Each executable is built exactly once and reused for the comparison.
        exec_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        exec_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(
            lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([7, 8., True]), exec_tf([7, 8., True]))
예제 #9
0
 def test_reduces_unplaced_lambda_to_equivalent_tf(self):
   """Extracted local processing of an int32 identity lambda matches the lambda."""
   identity_lam = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Reference('x', tf.int32))
   to_computation = computation_wrapper_instances.building_block_to_computation
   executable_tf = to_computation(
       transformations.consolidate_and_extract_local_processing(identity_lam))
   executable_lam = to_computation(identity_lam)
   for value in range(10):
     self.assertEqual(executable_tf(value), executable_lam(value))
예제 #10
0
 def test_reduces_federated_apply_to_equivalent_function(self):
   """Local TF extracted from a federated map/apply computes the mapped lambda."""
   lam = building_blocks.Lambda('x', tf.int32,
                                building_blocks.Reference('x', tf.int32))
   clients_arg = building_blocks.Reference(
       'arg', computation_types.FederatedType(tf.int32, placements.CLIENTS))
   mapped_fn = building_block_factory.create_federated_map_or_apply(
       lam, clients_arg)
   extracted_tf = (
       mapreduce_transformations.consolidate_and_extract_local_processing(
           mapped_fn))
   self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation)

   to_computation = computation_wrapper_instances.building_block_to_computation
   executable_tf = to_computation(extracted_tf)
   executable_lam = to_computation(lam)
   for value in range(10):
     self.assertEqual(executable_tf(value), executable_lam(value))
예제 #11
0
def transform_to_native_form(
        comp: computation_base.Computation) -> computation_base.Computation:
    """Compiles a computation for execution in the TFF native runtime.

    This function transforms the proto underlying `comp` to call-dominant form
    (see `tff.framework.transform_to_call_dominant` for definition) and wraps
    the result back into a `computation_base.Computation`.

    If compilation raises a `ValueError`, the error and a compact
    representation of the input are logged at debug level and `comp` is
    returned unchanged.

    Args:
      comp: Instance of `computation_base.Computation` to compile.

    Returns:
      A new `computation_base.Computation` representing the compiled version of
      `comp`, or `comp` itself if compilation failed with a `ValueError`.
    """
    proto = computation_impl.ComputationImpl.get_proto(comp)
    computation_building_block = building_blocks.ComputationBuildingBlock.from_proto(
        proto)
    try:
        logging.debug('Compiling TFF computation.')
        call_dominant_form, _ = transformations.transform_to_call_dominant(
            computation_building_block)
        logging.debug('Computation compiled to:')
        logging.debug(call_dominant_form.formatted_representation())
        return computation_wrapper_instances.building_block_to_computation(
            call_dominant_form)
    except ValueError as e:
        # Best-effort compilation: fall back to the uncompiled computation.
        logging.debug('Compilation for native runtime failed with error %s', e)
        logging.debug('computation: %s',
                      computation_building_block.compact_representation())
        return comp
 def compiler(comp):
     """Returns `comp` with secure-sum intrinsics replaced by insecure bodies.

     For testing purposes only: `secure_sum` and `secure_sum_bitwidth`
     intrinsics are rewritten into plain (insecure) TensorFlow computations.
     """
     building_block = comp.to_building_block()
     insecure_block, _ = (
         intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(
             building_block))
     return computation_wrapper_instances.building_block_to_computation(
         insecure_block)
예제 #13
0
def transform_mathematical_functions_to_tensorflow(
        comp: computation_base.Computation) -> computation_base.Computation:
    """Compiles all mathematical functions in `comp` to TensorFlow blocks.

    This is not guaranteed to be a performance win. The compilation does not
    deduplicate across the boundaries of communication operators, so compiling
    eagerly to TensorFlow may hide an opportunity to use a dynamic cache.

    Args:
      comp: Instance of `computation_base.Computation` to compile.

    Returns:
      A new `computation_base.Computation` representing the compiled version of
      `comp`, or `comp` itself if compilation raised a `ValueError`.
    """
    building_block = building_blocks.ComputationBuildingBlock.from_proto(
        computation_impl.ComputationImpl.get_proto(comp))
    try:
        logging.debug('Compiling local computations to TensorFlow.')
        tf_compiled, _ = transformations.compile_local_computation_to_tensorflow(
            building_block)
        logging.debug('Local computations compiled to TF:')
        logging.debug(tf_compiled.formatted_representation())
        result = computation_wrapper_instances.building_block_to_computation(
            tf_compiled)
    except ValueError as err:
        # Best-effort: log the failure and hand back the input untouched.
        logging.debug(
            'Compilation of local computation to TensorFlow failed with error %s',
            err)
        logging.debug('computation: %s',
                      building_block.compact_representation())
        result = comp
    return result
 def test_identity_lambda_executes_as_identity(self):
     """An int32 identity Lambda, once made executable, echoes its input."""
     identity = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     executable = computation_wrapper_instances.building_block_to_computation(
         identity)
     for value in range(10):
         self.assertEqual(executable(value), value)
예제 #15
0
def transform_to_native_form(
        comp: computation_base.Computation) -> computation_base.Computation:
    """Compiles a computation for execution in the TFF native runtime.

    The proto underlying `comp` is rewritten into call-dominant form (see
    `tff.framework.transform_to_call_dominant` for definition). A `ValueError`
    during compilation is logged at debug level and `comp` is returned as-is.

    Args:
      comp: Instance of `computation_base.Computation` to compile.

    Returns:
      A new `computation_base.Computation` representing the compiled version of
      `comp`, or `comp` itself if compilation failed with a `ValueError`.
    """
    building_block = building_blocks.ComputationBuildingBlock.from_proto(
        computation_impl.ComputationImpl.get_proto(comp))
    try:
        logging.debug('Compiling TFF computation.')
        call_dominant, _ = transformations.transform_to_call_dominant(
            building_block)
        logging.debug('Computation compiled to:')
        logging.debug(call_dominant.formatted_representation())
        compiled = computation_wrapper_instances.building_block_to_computation(
            call_dominant)
    except ValueError as err:
        logging.debug('Compilation for native runtime failed with error %s', err)
        logging.debug('computation: %s',
                      building_block.compact_representation())
        compiled = comp
    return compiled
예제 #16
0
 def test_reduces_federated_value_at_server_to_equivalent_noarg_function(self):
   """A SERVER-placed constant 0 reduces to a no-arg TF block returning 0."""
   server_value = intrinsics.federated_value(0, placements.SERVER)
   extracted_tf = (
       mapreduce_transformations.consolidate_and_extract_local_processing(
           server_value._comp))
   executable_tf = computation_wrapper_instances.building_block_to_computation(
       extracted_tf)
   self.assertEqual(executable_tf(), 0)
 def test_converts_building_block_to_computation(self):
     """`building_block_to_computation` wraps a Lambda as a `ComputationImpl`."""
     identity = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     converted = computation_wrapper_instances.building_block_to_computation(
         identity)
     self.assertIsInstance(converted, computation_impl.ComputationImpl)
예제 #18
0
def get_broadcast_form_for_computation(
    comp: computation_base.Computation,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.BroadcastForm:
    """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation.

    Args:
      comp: An instance of `tff.Computation` that is compatible with broadcast
        form. Computations are only compatible if they take in a single value
        placed at server, return a single value placed at clients, and do not
        contain any aggregations.
      grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure
        Grappler graph optimization of the TensorFlow graphs backing the
        resulting `tff.backends.mapreduce.BroadcastForm`. These options are
        combined with a set of defaults that aggressively configure Grappler.
        If `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler is
        bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the
      provided `tff.Computation`.

    Raises:
      TypeError: If `comp` or `grappler_config` is of the wrong type (checked
        via `py_typecheck.check_type`).
      ValueError: If `comp` contains any aggregations.
    """
    py_typecheck.check_type(comp, computation_base.Computation)
    _check_function_signature_compatible_with_broadcast_form(
        comp.type_signature)
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    bb = comp.to_building_block()
    bb, _ = intrinsic_reductions.replace_intrinsics_with_bodies(bb)
    bb = _replace_lambda_body_with_call_dominant_form(bb)

    tree_analysis.check_contains_only_reducible_intrinsics(bb)
    aggregations = tree_analysis.find_aggregations_in_tree(bb)
    if aggregations:
        # Every piece is an f-string so that `{aggregations}` is interpolated
        # rather than emitted literally.
        raise ValueError(
            f'`get_broadcast_form_for_computation` called with computation '
            f'containing {len(aggregations)} aggregations, but broadcast form '
            f'does not allow aggregation. Full list of aggregations:\n'
            f'{aggregations}')

    before_broadcast, after_broadcast = _split_ast_on_broadcast(bb)
    compute_server_context = _extract_compute_server_context(
        before_broadcast, grappler_config)
    client_processing = _extract_client_processing(after_broadcast,
                                                   grappler_config)

    # Wrap both extracted building blocks as executable computations.
    compute_server_context, client_processing = (
        computation_wrapper_instances.building_block_to_computation(block)
        for block in (compute_server_context, client_processing))

    comp_param_names = structure.name_list_with_nones(
        comp.type_signature.parameter)
    server_data_label, client_data_label = comp_param_names
    return forms.BroadcastForm(compute_server_context,
                               client_processing,
                               server_data_label=server_data_label,
                               client_data_label=client_data_label)
예제 #19
0
def get_map_reduce_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.MapReduceForm:
    """Constructs `tff.backends.mapreduce.MapReduceForm` given iterative process.

    Args:
      ip: An instance of `tff.templates.IterativeProcess` that is compatible with
        MapReduce form. Iterative processes are only compatible if:
        - `initialize_fn` returns a single federated value placed at `SERVER`;
        - `next` takes exactly two arguments, the first of which must be the
          state value placed at `SERVER`; and
        - `next` returns exactly two values.
      grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to
        configure Grappler graph optimization of the TensorFlow graphs backing
        the resulting `tff.backends.mapreduce.MapReduceForm`. These options are
        combined with a set of defaults that aggressively configure Grappler. If
        the input `grappler_config` has
        `graph_options.rewrite_options.disable_meta_optimizer=True`, Grappler is
        bypassed.

    Returns:
      An instance of `tff.backends.mapreduce.MapReduceForm` equivalent to the
      provided `tff.templates.IterativeProcess`.

    Raises:
      TypeError: If the arguments are of the wrong types.
      transformations.MapReduceFormCompilationError: If the compilation
        process fails.
    """
    py_typecheck.check_type(ip, iterative_process.IterativeProcess)
    initialize_bb, next_bb = (
        check_iterative_process_compatible_with_map_reduce_form(ip))
    py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
    grappler_config = _merge_grappler_config_with_default(grappler_config)

    next_bb, _ = tree_transformations.uniquify_reference_names(next_bb)
    before_broadcast, after_broadcast = _split_ast_on_broadcast(next_bb)
    before_aggregate, after_aggregate = _split_ast_on_aggregate(
        after_broadcast)

    initialize = transformations.consolidate_and_extract_local_processing(
        initialize_bb, grappler_config)
    prepare = _extract_prepare(before_broadcast, grappler_config)
    work = _extract_work(before_aggregate, grappler_config)
    zero, accumulate, merge, report = _extract_federated_aggregate_functions(
        before_aggregate, grappler_config)
    bitwidth = _extract_federated_secure_sum_bitwidth_functions(
        before_aggregate, grappler_config)
    update = _extract_update(after_aggregate, grappler_config)

    next_parameter_names = structure.name_list_with_nones(
        ip.next.type_signature.parameter)
    server_state_label, client_data_label = next_parameter_names
    # Order must match the positional parameters of `forms.MapReduceForm`.
    comps = (computation_wrapper_instances.building_block_to_computation(block)
             for block in (initialize, prepare, work, zero, accumulate, merge,
                           report, bitwidth, update))
    return forms.MapReduceForm(*comps,
                               server_state_label=server_state_label,
                               client_data_label=client_data_label)
예제 #20
0
    def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
        """A Lambda that only calls a TF int32 identity is replaced by TF."""
        int_identity = building_block_factory.create_compiled_identity(
            tf.int32)
        lambda_wrapper = building_blocks.Lambda(
            'x', tf.int32,
            building_blocks.Call(int_identity,
                                 building_blocks.Reference('x', tf.int32)))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda(2), exec_tf(2))
예제 #21
0
    def test_replaces_lambda_to_called_tf_block_with_replicated_lambda_arg_with_tf_block_of_same_type(
            self):
        """Passing the lambda arg twice into a TF sum block still reduces to TF.

        The lambda computes `x + x + 1` via a compiled `<int32, int32>` block.
        """
        sum_and_add_one = _create_compiled_computation(
            lambda x: x[0] + x[1] + 1, [tf.int32, tf.int32])
        int_ref = building_blocks.Reference('x', tf.int32)
        duplicated = building_blocks.Tuple((int_ref, int_ref))
        lambda_wrapper = building_blocks.Lambda(
            'x', tf.int32, building_blocks.Call(sum_and_add_one, duplicated))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda(17), exec_tf(17))
  def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
    """A Lambda calling a TF identity built from a `TensorType` reduces to TF."""
    int_identity = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    lambda_wrapper = building_blocks.Lambda(
        'x', tf.int32,
        building_blocks.Call(int_identity,
                             building_blocks.Reference('x', tf.int32)))

    parsed, modified = parse_tff_to_tf(lambda_wrapper)
    to_computation = computation_wrapper_instances.building_block_to_computation
    exec_lambda = to_computation(lambda_wrapper)
    exec_tf = to_computation(parsed)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)
    self.assertEqual(exec_lambda(2), exec_tf(2))
예제 #23
0
 def test_reduces_federated_value_at_server_to_equivalent_noarg_function(self):
   """A SERVER-placed scalar TF constant 0 reduces to a no-arg block returning 0."""
   scalar_int = computation_types.TensorType(tf.int32, shape=[])
   zero = building_block_factory.create_tensorflow_constant(scalar_int, 0)
   federated_value = building_block_factory.create_federated_value(
       zero, placements.SERVER)
   executable_tf = computation_wrapper_instances.building_block_to_computation(
       transformations.consolidate_and_extract_local_processing(
           federated_value, DEFAULT_GRAPPLER_CONFIG))
   self.assertEqual(executable_tf(), 0)
 def test_next_computation_returning_tensor_fails_well(self):
   """Canonical-form extraction raises TypeError for an identity `next`.

   `next` is replaced by an identity Lambda over the initialize result type,
   which `get_canonical_form_for_iterative_process` must reject.
   """
   cf = test_utils.get_temperature_sensor_example()
   it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
   state_type = it.initialize.type_signature.result
   identity_next = building_blocks.Lambda(
       'x', state_type, building_blocks.Reference('x', state_type))
   next_comp = computation_wrapper_instances.building_block_to_computation(
       identity_next)
   bad_it = iterative_process.IterativeProcess(it.initialize, next_comp)
   with self.assertRaises(TypeError):
     canonical_form_utils.get_canonical_form_for_iterative_process(bad_it)
예제 #25
0
    def test_replaces_lambda_to_selection_from_called_graph_with_tf_of_same_type(
            self):
        """Selecting element 1 of a TF identity call on the arg reduces to TF."""
        pair_identity = building_block_factory.create_compiled_identity(
            [tf.int32, tf.float32])
        arg_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
        second_of_call = building_blocks.Selection(
            building_blocks.Call(pair_identity, arg_ref), index=1)
        lambda_wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                                second_of_call)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([0, 1.]), exec_tf([0, 1.]))
  def test_replaces_lambda_to_called_tf_block_with_replicated_lambda_arg_with_tf_block_of_same_type(
      self):
    """Feeding the arg twice into a TF `<int32, int32>` sum block reduces to TF.

    The lambda computes `x + x + 1`. The parsed block's type only needs to be
    equivalent to the lambda's type.
    """
    pair_type = computation_types.StructType([tf.int32, tf.int32])
    sum_and_add_one = _create_compiled_computation(lambda x: x[0] + x[1] + 1,
                                                   pair_type)
    int_ref = building_blocks.Reference('x', tf.int32)
    duplicated = building_blocks.Struct((int_ref, int_ref))
    lambda_wrapper = building_blocks.Lambda(
        'x', tf.int32, building_blocks.Call(sum_and_add_one, duplicated))

    parsed, modified = parse_tff_to_tf(lambda_wrapper)
    to_computation = computation_wrapper_instances.building_block_to_computation
    exec_lambda = to_computation(lambda_wrapper)
    exec_tf = to_computation(parsed)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)

    self.assertEqual(exec_lambda(17), exec_tf(17))
예제 #27
0
def deserialize_computation(
        computation_proto: pb.Computation) -> computation_base.Computation:
    """Deserializes a `pb.Computation` proto into a `tff.Computation`.

    The proto is first converted to a building block, which is then wrapped
    as an invocable computation.

    Args:
      computation_proto: An instance of `pb.Computation`.

    Returns:
      The corresponding instance of `tff.Computation`.

    Raises:
      TypeError: If `computation_proto` is not a `pb.Computation`.
    """
    py_typecheck.check_type(computation_proto, pb.Computation)
    building_block = building_blocks.ComputationBuildingBlock.from_proto(
        computation_proto)
    return computation_wrapper_instances.building_block_to_computation(
        building_block)
# Example #28
# 0
def _do_not_use_transform_to_native_form(comp):
    """Use `tff.backends.native.transform_to_native_form`.

    Private helper, not for direct use. It converts `comp` to call-dominant
    form. If that conversion (or wrapping the result) raises `ValueError`, the
    failure and a compact representation of the computation are logged at
    debug level and `comp` is returned unchanged.
    """
    building_block = building_blocks.ComputationBuildingBlock.from_proto(
        computation_impl.ComputationImpl.get_proto(comp))
    try:
        logging.debug('Compiling TFF computation.')
        cdf_block, _ = transformations.transform_to_call_dominant(
            building_block)
        logging.debug('Computation compiled to:')
        logging.debug(cdf_block.formatted_representation())
        return computation_wrapper_instances.building_block_to_computation(
            cdf_block)
    except ValueError as err:
        # Best-effort compilation: fall back to the untransformed computation.
        logging.debug('Compilation for native runtime failed with error %s',
                      err)
        logging.debug('computation: %s', building_block.compact_representation())
        return comp
# Example #29
# 0
def transform_to_native_form(
        comp: computation_base.Computation,
        transform_math_to_tf: bool = False) -> computation_base.Computation:
    """Compiles a computation for execution in the TFF native runtime.

    This function transforms the proto underlying `comp` to call-dominant form
    (see `tff.framework.transform_to_call_dominant` for the definition). It
    then optionally compiles local computations to TensorFlow. Finally it
    rewrites TensorFlow call ops with
    `tree_transformations.transform_tf_call_ops_to_disable_grappler`.

    Compilation is best-effort. If any of these steps raises `ValueError`, the
    error is logged at debug level and `comp` is returned unchanged.

    Args:
      comp: Instance of `computation_base.Computation` to compile.
      transform_math_to_tf: Whether to additionally transform local math into
        TensorFlow graphs. Necessary if running on an execution stack without
        ReferenceResolvingExecutors underneath FederatingExecutors.

    Returns:
      A new `computation_base.Computation` representing the compiled version of
      `comp`, or `comp` itself if compilation failed with a `ValueError`.
    """
    proto = computation_impl.ComputationImpl.get_proto(comp)
    computation_building_block = building_blocks.ComputationBuildingBlock.from_proto(
        proto)
    try:
        logging.debug('Compiling TFF computation to CDF.')
        call_dominant_form, _ = transformations.transform_to_call_dominant(
            computation_building_block)
        logging.debug('Computation compiled to:')
        logging.debug(call_dominant_form.formatted_representation())
        if transform_math_to_tf:
            logging.debug('Compiling local computations to TensorFlow.')
            call_dominant_form, _ = transformations.compile_local_computation_to_tensorflow(
                call_dominant_form)
            logging.debug('Computation compiled to:')
            logging.debug(call_dominant_form.formatted_representation())
        # Applied unconditionally, after the optional TF compilation above.
        call_dominant_form, _ = tree_transformations.transform_tf_call_ops_to_disable_grappler(
            call_dominant_form)
        return computation_wrapper_instances.building_block_to_computation(
            call_dominant_form)
    except ValueError as e:
        # Best-effort: log the failure and fall back to the original `comp`.
        logging.debug('Compilation for native runtime failed with error %s', e)
        logging.debug('computation: %s',
                      computation_building_block.compact_representation())
        return comp
# Example #30
# 0
  def test_broadcast_dependent_on_aggregate_fails_well(self):
    """Chaining `next` on its own output is rejected with a descriptive error.

    Builds a lambda that calls `next` once. It then calls `next` again on a
    struct made of element 0 of the first call's result and element 1 of the
    original argument. The test checks that
    `get_map_reduce_form_for_iterative_process` raises a `ValueError` matching
    'broadcast dependent on aggregate'.
    """
    example_mrf = mapreduce_test_utils.get_temperature_sensor_example()
    source_process = form_utils.get_iterative_process_for_map_reduce_form(
        example_mrf)
    next_bb = source_process.next.to_building_block()
    param_name = next_bb.parameter_name
    param_type = next_bb.parameter_type

    outer_arg = building_blocks.Reference(param_name, param_type)
    first_call = building_blocks.Call(next_bb, outer_arg)
    # Feeding output of the first call into the second call is what makes the
    # resulting process impossible to express as a single MapReduceForm round.
    chained_arg = building_blocks.Struct([
        building_blocks.Selection(first_call, index=0),
        building_blocks.Selection(outer_arg, index=1)
    ])
    second_call = building_blocks.Call(next_bb, chained_arg)
    chained_next = building_blocks.Lambda(param_name, param_type, second_call)

    chained_process = iterative_process.IterativeProcess(
        source_process.initialize,
        computation_wrapper_instances.building_block_to_computation(
            chained_next))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      form_utils.get_map_reduce_form_for_iterative_process(chained_process)