Example #1
0
    def test_aggregate_with_selection_from_block_by_name_results_in_single_aggregate(
            self):
        """Two aggregates over named selections from one block merge into one."""
        clients_int = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
        data = building_blocks.Reference('a', clients_int)
        block_holding_tup = building_blocks.Block(
            [], building_blocks.Tuple([('a', data), ('b', data)]))
        select_a = building_blocks.Selection(source=block_holding_tup, name='a')
        select_b = building_blocks.Selection(source=block_holding_tup, name='b')

        result = building_blocks.Data('aggregation_result', tf.int32)
        zero = building_blocks.Data('zero', tf.int32)
        accumulate = building_blocks.Lambda('accumulate_param',
                                            [tf.int32, tf.int32], result)
        merge = building_blocks.Lambda('merge_param', [tf.int32, tf.int32],
                                       result)
        report = building_blocks.Lambda('report_param', tf.int32, result)

        def _aggregate(value):
            # All aggregates share the same zero/accumulate/merge/report.
            return building_block_factory.create_federated_aggregate(
                value, zero, accumulate, merge, report)

        comp = building_blocks.Tuple((_aggregate(select_a),
                                      _aggregate(select_b)))

        merged, modified = transformations.dedupe_and_merge_tuple_intrinsics(
            comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
        self.assertTrue(modified)

        aggregate_intrinsics = []

        def _collect_aggregate_intrinsics(node):
            # Records the intrinsic of every called federated_aggregate and
            # leaves the tree untouched.
            is_aggregate_call = (
                isinstance(node, building_blocks.Call) and
                isinstance(node.function, building_blocks.Intrinsic) and
                node.function.uri == intrinsic_defs.FEDERATED_AGGREGATE.uri)
            if is_aggregate_call:
                aggregate_intrinsics.append(node.function)
            return node, False

        transformation_utils.transform_postorder(merged,
                                                 _collect_aggregate_intrinsics)
        self.assertLen(aggregate_intrinsics, 1)
        merged_value_type = aggregate_intrinsics[0].type_signature.parameter[0]
        self.assertEqual(merged_value_type.compact_representation(),
                         '{<int32>}@CLIENTS')
 def federated_weighted_mean(arg):
     """Constructs a weighted-mean computation over `arg`.

     Args:
       arg: A `building_blocks.ComputationBuildingBlock`. Element 1 is used as
         the weight (presumably `arg` is a `<value, weight>` pair -- confirm
         against callers).

     Returns:
       The building block produced by `generic_divide` applied to the
       federated sum of the zipped `<generic_multiply(arg), weight>` pair.
     """
     py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock)
     weight = building_blocks.Selection(arg, index=1)
     weighted_values = generic_multiply(arg)
     zipped = building_block_factory.create_federated_zip(
         building_blocks.Struct([(None, weighted_values), (None, weight)]))
     return generic_divide(federated_sum(zipped))
Example #3
0
 def test_propogates_dependence_up_through_selection(self):
     # A selection whose source is the intrinsic must itself be reported as
     # consuming that intrinsic.
     whimsy_intrinsic = building_blocks.Intrinsic(
         'whimsy_intrinsic', computation_types.StructType([tf.int32]))
     selection = building_blocks.Selection(whimsy_intrinsic, index=0)
     consumers = tree_analysis.extract_nodes_consuming(
         selection, whimsy_intrinsic_predicate)
     self.assertIn(selection, consumers)
Example #4
0
 def test_returns_called_tf_computation_with_truct(self):
     """A struct of repeated selections from a TF constant compiles to TF."""
     constant_tuple = building_block_factory.create_tensorflow_constant(
         computation_types.StructType([tf.int32, tf.float32]), 1)
     first_element = building_blocks.Selection(source=constant_tuple, index=0)
     repeated = building_blocks.Struct([first_element] * 3)
     self.assert_compiles_to_tensorflow(repeated)
 def test_with_structure_replacing_federated_map(self):
   """A block binding `<fn, arg>` then calling `fn(arg)` loses all lambdas."""
   function_type = computation_types.FunctionType(tf.int32, tf.int32)
   tuple_ref = building_blocks.Reference('arg', [function_type, tf.int32])
   called_fn = building_blocks.Call(
       building_blocks.Selection(tuple_ref, index=0),
       building_blocks.Selection(tuple_ref, index=1))
   identity = building_blocks.Lambda(
       'x', tf.int32, building_blocks.Reference('x', tf.int32))
   bound_arg = building_blocks.Tuple(
       [identity, building_blocks.Data('a', tf.int32)])
   generated_structure = building_blocks.Block([('arg', bound_arg)], called_fn)

   result, modified = compiler_transformations.remove_lambdas_and_blocks(
       generated_structure)
   self.assertTrue(modified)
   self.assertNoLambdasOrBlocks(result)
Example #6
0
 def _traverse_selection(comp, transform, context_tree, identifier_seq):
     """Helper function holding traversal logic for selection nodes.

     Transforms `comp.source` first (postorder), rebuilds the selection if the
     source changed, then applies `transform` to the selection itself.

     Returns:
       A `(comp, modified)` tuple, where `modified` is true if either the
       source or the selection itself was changed.
     """
     # Advance the identifier sequence for this node; the value is unused.
     _ = next(identifier_seq)
     source, source_modified = _transform_postorder_with_symbol_bindings_switch(
         comp.source, transform, context_tree, identifier_seq)
     if source_modified:
         # The transformed source may no longer carry element names, so
         # select by index, resolving a name against the ORIGINAL source's
         # type signature.
         if comp.index is not None:
             index = comp.index
         else:
             index = structure.name_to_index_map(
                 comp.source.type_signature)[comp.name]
         comp = building_blocks.Selection(source, index=index)
     comp, comp_modified = transform(comp, context_tree)
     return comp, comp_modified or source_modified
Example #7
0
 def _build(comp, scope):
     """Transforms `comp` to CDF, possibly adding bindings to `scope`.

     Args:
       comp: The building block to transform; dispatched on its kind.
       scope: The scope used to resolve references and to record new
         bindings; child scopes are created for lambdas, blocks and inlined
         lambda calls.

     Returns:
       The transformed building block.

     Raises:
       ValueError: If `comp` is a placement or of an unrecognized kind.
         `_disallow_higher_order` may also raise for leaf computations
         (NOTE(review): its exception type is not visible here).
     """
     # The structure returned by this function is a generalized version of
     # call-dominant form. This function may result in the patterns specified in
     # the top-level function's docstring.
     # NOTE(review): `global_comp` is a free variable, presumably the top-level
     # computation from the enclosing function; it is only used in checks and
     # error messages.
     if comp.is_reference():
         return scope.resolve(comp.name)
     elif comp.is_selection():
         source = _build(comp.source, scope)
         # Index a literal struct directly instead of emitting a Selection.
         if source.is_struct():
             return source[comp.as_index()]
         return building_blocks.Selection(source, index=comp.as_index())
     elif comp.is_struct():
         elements = []
         for (name, value) in structure.iter_elements(comp):
             value = _build(value, scope)
             elements.append((name, value))
         return building_blocks.Struct(elements)
     elif comp.is_call():
         function = _build(comp.function, scope)
         argument = None if comp.argument is None else _build(
             comp.argument, scope)
         if function.is_lambda():
             # Inline calls to lambdas: bind the argument to the parameter
             # name in a child scope and build the lambda body there.
             if argument is not None:
                 scope = scope.new_child()
                 scope.add_local(function.parameter_name, argument)
             return _build(function.result, scope)
         else:
             # Any other call is recorded as a binding in the current scope;
             # whatever `create_binding` returns stands in for the call.
             return scope.create_binding(
                 building_blocks.Call(function, argument))
     elif comp.is_lambda():
         # Lambdas collect their own bindings in a fresh child scope.
         scope = scope.new_child_with_bindings()
         if comp.parameter_name:
             # Uses of the parameter inside the body resolve to a plain
             # reference to the parameter itself.
             scope.add_local(
                 comp.parameter_name,
                 building_blocks.Reference(comp.parameter_name,
                                           comp.parameter_type))
         result = _build(comp.result, scope)
         block = scope.bindings_to_block_with_result(result)
         return building_blocks.Lambda(comp.parameter_name,
                                       comp.parameter_type, block)
     elif comp.is_block():
         scope = scope.new_child()
         # Each local is built in the scope already holding the preceding
         # locals, so later locals may refer to earlier ones.
         for (name, value) in comp.locals:
             scope.add_local(name, _build(value, scope))
         return _build(comp.result, scope)
     elif (comp.is_intrinsic() or comp.is_data()
           or comp.is_compiled_computation()):
         # Leaf computations are returned unchanged once checked.
         _disallow_higher_order(comp, global_comp)
         return comp
     elif comp.is_placement():
         raise ValueError(
             f'Found placement {comp} in\n{global_comp}\n'
             'but placements are not allowed in local computations.')
     else:
         raise ValueError(
             f'Unrecognized computation kind\n{comp}\nin\n{global_comp}')
Example #8
0
 def _transform(comp):
     # Replaces a matching computation with `ref`, or with a selection of its
     # position in `intrinsics` when more than one intrinsic is involved.
     # NOTE(review): `intrinsics`, `ref` and `_should_transform` are free
     # variables supplied by the enclosing scope.
     if not _should_transform(comp):
         return comp, False
     if len(intrinsics) <= 1:
         return ref, True
     position = intrinsics.index(comp)
     return building_blocks.Selection(ref, index=position), True
 def test_returns_string_for_selection_with_name(self):
     ref = building_blocks.Reference('a', (('b', tf.int32), ('c', tf.bool)))
     selection = building_blocks.Selection(ref, name='b')
     self.assertEqual(selection.compact_representation(), 'a.b')
     self.assertEqual(selection.formatted_representation(), 'a.b')
     # pyformat: disable
     self.assertEqual(selection.structural_representation(),
                      'Sel(b)\n'
                      '|\n'
                      'Ref(a)')
Example #10
0
 def test_binding_multiple_args_results_in_unique_names(self):
     """Binding two deep CLIENTS subparameters produces unique names."""
     fed_at_clients = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     fed_at_server = computation_types.FederatedType(
         tf.int32, placements.SERVER)
     tuple_of_federated_types = computation_types.StructType(
         [[fed_at_clients], fed_at_server, [fed_at_clients]])

     def _select_nested(outer_index):
         # Builds `x[outer_index][0]` over a fresh reference to the parameter.
         param_ref = building_blocks.Reference('x', tuple_of_federated_types)
         return building_blocks.Selection(
             building_blocks.Selection(param_ref, index=outer_index), index=0)

     lam = building_blocks.Lambda(
         'x', tuple_of_federated_types,
         building_blocks.Struct([_select_nested(0), _select_nested(2)]))
     new_lam = form_utils._as_function_of_some_federated_subparameters(
         lam, [(0, 0), (2, 0)])
     tree_analysis.check_has_unique_names(new_lam)
Example #11
0
 def test_binding_multiple_args_results_in_unique_names(self):
     """Zipping two deep CLIENTS selections produces unique names."""
     fed_at_clients = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     fed_at_server = computation_types.FederatedType(
         tf.int32, placements.SERVER)
     tuple_of_federated_types = computation_types.NamedTupleType(
         [[fed_at_clients], fed_at_server, [fed_at_clients]])
     # Build `x[0][0]` and `x[2][0]`, each over its own reference to `x`.
     nested_selections = [
         building_blocks.Selection(
             building_blocks.Selection(
                 building_blocks.Reference('x', tuple_of_federated_types),
                 index=outer_index),
             index=0) for outer_index in (0, 2)
     ]
     lam = building_blocks.Lambda('x', tuple_of_federated_types,
                                  building_blocks.Tuple(nested_selections))
     extracted = mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
         lam, [[0, 0], [2, 0]])
     tree_analysis.check_has_unique_names(extracted)
    def test_returns_string_for_selection_with_index(self):
        ref = building_blocks.Reference('a', (('b', tf.int32), ('c', tf.bool)))
        selection = building_blocks.Selection(ref, index=0)

        expected_flat = 'a[0]'
        self.assertEqual(selection.compact_representation(), expected_flat)
        self.assertEqual(selection.formatted_representation(), expected_flat)
        # pyformat: disable
        expected_structure = 'Sel(0)\n|\nRef(a)'
        self.assertEqual(selection.structural_representation(),
                         expected_structure)
Example #13
0
def _create_next_with_fake_client_output(tree):
    r"""Creates a next computation with a fake client output.

    This function returns the AST:

    Lambda
    |
    [Comp, Comp, Tuple]
                 |
                 []

    In the AST, the parameter of `Lambda` and the first two `Comp`s in its
    result come from `tree`. The third element is the fake client output: an
    empty `Tuple` placed at `CLIENTS` via
    `building_block_factory.create_federated_value`.

    This function is intended to be used by
    `get_canonical_form_for_iterative_process` to create a next computation
    with a fake client output when no client output is returned by `tree`
    (which represents the `next` function of the `tff.utils.IterativeProcess`).
    As a result, this function does not assert that there is no client output
    in `tree`, nor that `tree` has the expected structure; the caller is
    expected to perform these checks before calling this function.

    Args:
      tree: An instance of `building_blocks.ComputationBuildingBlock`; its
        `result`, `parameter_name` and `parameter_type` are read.

    Returns:
      A new `building_blocks.ComputationBuildingBlock` representing a next
      computation with a fake client output.
    """
    if isinstance(tree.result, building_blocks.Tuple):
        # Reuse the literal elements directly rather than wrapping them in
        # Selections.
        arg_1 = tree.result[0]
        arg_2 = tree.result[1]
    else:
        arg_1 = building_blocks.Selection(tree.result, index=0)
        arg_2 = building_blocks.Selection(tree.result, index=1)

    empty_tuple = building_blocks.Tuple([])
    client_output = building_block_factory.create_federated_value(
        empty_tuple, placements.CLIENTS)
    output = building_blocks.Tuple([arg_1, arg_2, client_output])
    return building_blocks.Lambda(tree.parameter_name, tree.parameter_type,
                                  output)
Example #14
0
 def test_returns_tf_computation_with_functional_type(self):
     """A lambda repeating one parameter selection collapses to compiled TF."""
     param = building_blocks.Reference('x', [('a', tf.int32),
                                             ('b', tf.float32)])
     first_element = building_blocks.Selection(source=param, index=0)
     lam = building_blocks.Lambda(
         param.name, param.type_signature,
         building_blocks.Tuple([first_element] * 3))
     transformed, modified = (
         compiler_transformations.remove_duplicate_called_graphs(lam))
     self.assertTrue(modified)
     self.assertIsInstance(transformed, building_blocks.CompiledComputation)
     self.assertEqual(transformed.type_signature, lam.type_signature)
    def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type_with_names(
            self):
        """Parsing a named-param lambda over a called graph yields equal TF."""
        identity_tf_block_type = computation_types.StructType(
            [tf.int32, tf.bool])
        identity_tf_block = building_block_factory.create_compiled_identity(
            identity_tf_block_type)
        param_type = [('a', tf.int32), ('b', tf.float32), ('c', tf.bool)]
        tuple_ref = building_blocks.Reference('x', param_type)
        selected_int = building_blocks.Selection(tuple_ref, index=0)
        selected_bool = building_blocks.Selection(tuple_ref, index=2)
        created_tuple = building_blocks.Struct([selected_int, selected_bool])
        called_tf_block = building_blocks.Call(identity_tf_block,
                                               created_tuple)
        lambda_wrapper = building_blocks.Lambda('x', param_type,
                                                called_tf_block)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)

        # Both computations must agree when executed on the same input.
        exec_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        exec_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)
        arg = {'a': 9, 'b': 10., 'c': False}
        self.assertEqual(exec_lambda(arg), exec_tf(arg))
 def test_selection(self):
   """Checks `_as_function_of_single_subparameter` on a selection of `x[0]`."""
   param_type = computation_types.StructType([
       computation_types.FederatedType(tf.int32, placements.CLIENTS),
       computation_types.FederatedType(tf.int32, placements.SERVER),
   ])
   first_of_param = building_blocks.Selection(
       building_blocks.Reference('x', param_type), index=0)
   lam = building_blocks.Lambda('x', param_type, first_of_param)
   selected_index = 0
   new_lam = form_utils._as_function_of_single_subparameter(lam, selected_index)
   self.assert_selected_param_to_result_type(lam, new_lam, selected_index)
 def test_raises_on_selections_at_different_placements(self):
   """Binding subparameters placed at both CLIENTS and SERVER raises."""
   param_type = computation_types.StructType([
       computation_types.FederatedType(tf.int32, placements.CLIENTS),
       computation_types.FederatedType(tf.int32, placements.SERVER),
   ])
   param_ref = building_blocks.Reference('x', param_type)
   lam = building_blocks.Lambda(
       'x', param_type, building_blocks.Selection(param_ref, index=0))
   mixed_placement_paths = [(0,), (1,)]
   with self.assertRaises(form_utils._MismatchedSelectionPlacementError):
     form_utils._as_function_of_some_federated_subparameters(
         lam, mixed_placement_paths)
Example #18
0
    def test_single_nested_element_selection(self):
        """Binding nested `x[0][0]` yields a `<int32>@CLIENTS` parameter."""
        fed_at_clients = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        fed_at_server = computation_types.FederatedType(
            tf.int32, placements.SERVER)
        param_type = computation_types.StructType(
            [[fed_at_clients], fed_at_server])
        param_ref = building_blocks.Reference('x', param_type)
        nested_selection = building_blocks.Selection(
            building_blocks.Selection(param_ref, index=0), index=0)
        lam = building_blocks.Lambda('x', param_type, nested_selection)

        new_lam = form_utils._as_function_of_some_federated_subparameters(
            lam, [(0, 0)])

        expected_type = computation_types.FunctionType(
            computation_types.at_clients((tf.int32, )),
            lam.result.type_signature)
        type_test_utils.assert_types_equivalent(new_lam.type_signature,
                                                expected_type)
Example #19
0
 def test_returns_called_tf_computation_with_non_functional_type(self):
     constant_tuple = building_block_factory.create_tensorflow_constant(
         [tf.int32, tf.float32], 1)
     first_element = building_blocks.Selection(source=constant_tuple, index=0)
     tup = building_blocks.Tuple([first_element] * 3)
     transformed, modified = (
         compiler_transformations.remove_duplicate_called_graphs(tup))
     self.assertTrue(modified)
     self.assertEqual(transformed.type_signature, tup.type_signature)
     # Expect a no-argument call of a single compiled computation.
     self.assertIsInstance(transformed, building_blocks.Call)
     self.assertIsInstance(transformed.function,
                           building_blocks.CompiledComputation)
     self.assertIsNone(transformed.argument)
  def test_broadcast_dependent_on_aggregate_fails_well(self):
    """Chaining two `next` calls cannot be reduced to a MapReduceForm."""
    mrf = mapreduce_test_utils.get_temperature_sensor_example()
    it = form_utils.get_iterative_process_for_map_reduce_form(mrf)
    next_comp = it.next.to_building_block()
    param_name = next_comp.parameter_name
    param_type = next_comp.parameter_type
    top_level_param = building_blocks.Reference(param_name, param_type)
    first_result = building_blocks.Call(next_comp, top_level_param)
    # Feed element 0 of the first call's result (presumably the updated state)
    # plus element 1 of the original parameter into a second call.
    middle_param = building_blocks.Struct([
        building_blocks.Selection(first_result, index=0),
        building_blocks.Selection(top_level_param, index=1),
    ])
    second_result = building_blocks.Call(next_comp, middle_param)
    not_reducible = building_blocks.Lambda(param_name, param_type,
                                           second_result)
    not_reducible_it = iterative_process.IterativeProcess(
        it.initialize,
        computation_wrapper_instances.building_block_to_computation(
            not_reducible))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      form_utils.get_map_reduce_form_for_iterative_process(not_reducible_it)
Example #21
0
 def test_raises_on_selections_at_different_placements(self):
   """Zipping selections placed at both CLIENTS and SERVER raises."""
   param_type = computation_types.NamedTupleType([
       computation_types.FederatedType(tf.int32, placements.CLIENTS),
       computation_types.FederatedType(tf.int32, placements.SERVER),
   ])
   param_ref = building_blocks.Reference('x', param_type)
   lam = building_blocks.Lambda(
       'x', param_type, building_blocks.Selection(param_ref, index=0))
   mixed_placement_paths = [[0], [1]]
   with self.assertRaisesRegex(ValueError, 'at the same placement.'):
     mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
         lam, mixed_placement_paths)
Example #22
0
def select_output_from_lambda(comp, indices):
    """Constructs a new function with result of selecting `indices` from `comp`.

    Args:
      comp: Instance of `building_blocks.Lambda` of result type
        `tff.StructType` from which we wish to select `indices`. Notice that
        this named tuple type must have elements of federated type.
      indices: Instance of `int`, `list`, or `tuple`, specifying the indices we
        wish to select from the result of `comp`. If `indices` is an `int`, the
        result of the returned `comp` will be of type at index `indices` in
        `comp.type_signature.result`. If `indices` is a `list` or `tuple`, the
        result type will be a `tff.StructType` wrapping the specified
        selections. An element of such a `list`/`tuple` may itself be a
        `list` or `tuple` of indices, which is applied as a path of nested
        selections.

    Returns:
      A transformed version of `comp` with result value the selection from the
      result of `comp` specified by `indices`.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(comp.type_signature.result,
                            computation_types.StructType)
    py_typecheck.check_type(indices, (int, tuple, list))

    def _select(source, index):
        # Index a literal struct directly (avoids a `Selection` node);
        # otherwise wrap `source` in a `Selection`. Struct-ness is decided per
        # source, so nested paths cannot leak state into sibling selections.
        if source.is_struct():
            return source[index]
        return building_blocks.Selection(source, index=index)

    def _select_path(source, path):
        # Applies each index in `path` as a successive nested selection.
        for index in path:
            source = _select(source, index)
        return source

    result_tuple = comp.result
    if isinstance(indices, (tuple, list)):
        elements = []
        for x in indices:
            if isinstance(x, (tuple, list)):
                elements.append(_select_path(result_tuple, x))
            else:
                elements.append(_select(result_tuple, x))
        result = building_blocks.Struct(elements)
    else:
        result = _select(result_tuple, indices)
    return building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                  result)
Example #23
0
  def test_binds_multiple_args_deep_in_type_tree_to_lower_lambda(self):
    fed_at_clients = computation_types.FederatedType(tf.int32,
                                                     placements.CLIENTS)
    fed_at_server = computation_types.FederatedType(tf.int32, placements.SERVER)
    tuple_of_federated_types = computation_types.NamedTupleType(
        [[fed_at_clients], fed_at_server, [fed_at_clients]])

    def _select_nested(outer_index):
      # Builds `x[outer_index][0]` over a fresh reference to the parameter.
      return building_blocks.Selection(
          building_blocks.Selection(
              building_blocks.Reference('x', tuple_of_federated_types),
              index=outer_index),
          index=0)

    lam = building_blocks.Lambda(
        'x', tuple_of_federated_types,
        building_blocks.Tuple([_select_nested(0), _select_nested(2)]))
    expected_fn_regex = (r'\(_([a-z]{3})2 -> <federated_map\(<\(_(\1)3 -> '
                         r'_(\1)3\[0\]\),_(\1)2>\),federated_map\(<\(_(\1)4 -> '
                         r'_(\1)4\[1\]\),_(\1)2>\)>\)')
    expected_arg_regex = r'federated_zip_at_clients\(<_([a-z]{3})1\[0\]\[0\],_(\1)1\[2\]\[0\]>\)'

    extracted = transformations.zip_selection_as_argument_to_lower_level_lambda(
        lam, [[0, 0], [2, 0]])

    self.assertEqual(extracted.type_signature, lam.type_signature)
    self.assertIsInstance(extracted, building_blocks.Lambda)
    call = extracted.result
    self.assertIsInstance(call, building_blocks.Call)
    self.assertIsInstance(call.function, building_blocks.Lambda)
    self.assertRegexMatch(call.function.compact_representation(),
                          [expected_fn_regex])
    self.assertRegexMatch(call.argument.compact_representation(),
                          [expected_arg_regex])
    def test_replaces_lambda_to_unnamed_tuple_of_called_graphs_with_tf_of_same_type(
            self):
        """Parsing a lambda over a tuple of called graphs yields equal TF."""
        int_tensor_type = computation_types.TensorType(tf.int32)
        int_identity_tf_block = building_block_factory.create_compiled_identity(
            int_tensor_type)
        float_tensor_type = computation_types.TensorType(tf.float32)
        float_identity_tf_block = building_block_factory.create_compiled_identity(
            float_tensor_type)
        tuple_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
        selected_int = building_blocks.Selection(tuple_ref, index=0)
        selected_float = building_blocks.Selection(tuple_ref, index=1)

        called_int_tf_block = building_blocks.Call(int_identity_tf_block,
                                                   selected_int)
        called_float_tf_block = building_blocks.Call(float_identity_tf_block,
                                                     selected_float)
        tuple_of_called_graphs = building_blocks.Struct(
            [called_int_tf_block, called_float_tf_block])
        lambda_wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                                tuple_of_called_graphs)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(
            lambda_wrapper.type_signature)

        # Both computations must agree when executed on the same input.
        exec_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        exec_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)
        self.assertEqual(exec_lambda([11, 12.]), exec_tf([11, 12.]))
 def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self):
     first_broadcast = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     server_data = building_blocks.Data('a',
                                        computation_types.at_server(tf.int32))
     packed_broadcast = building_blocks.Struct([server_data, first_broadcast])
     selected = building_blocks.Selection(packed_broadcast, index=0)
     second_broadcast = building_block_factory.create_federated_broadcast(
         selected)
     comp = building_blocks.Lambda(
         'a', tf.int32, transformations.to_call_dominant(second_broadcast))
     self.assert_splits_on(
         comp, building_block_factory.create_null_federated_broadcast())
 def test_basic_functionality_of_lambda_class(self):
     arg_name = 'arg'
     arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                 ('x', tf.int32)]
     arg = building_blocks.Reference(arg_name, arg_type)
     arg_f = building_blocks.Selection(arg, name='f')
     arg_x = building_blocks.Selection(arg, name='x')
     lam = building_blocks.Lambda(
         arg_name, arg_type,
         building_blocks.Call(arg_f, building_blocks.Call(arg_f, arg_x)))

     # Type signature and accessors.
     self.assertEqual(str(lam.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(lam.parameter_name, arg_name)
     self.assertEqual(str(lam.parameter_type), '<f=(int32 -> int32),x=int32>')

     # Textual representations.
     self.assertEqual(lam.result.compact_representation(),
                      'arg.f(arg.f(arg.x))')
     arg_type_repr = (
         'NamedTupleType(['
         '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
         '(\'x\', TensorType(tf.int32))])')
     expected_repr = (
         'Lambda(\'arg\', {0}, '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
             arg_type_repr))
     self.assertEqual(repr(lam), expected_repr)
     self.assertEqual(lam.compact_representation(),
                      '(arg -> arg.f(arg.f(arg.x)))')

     # Proto serialization.
     lam_proto = lam.proto
     self.assertEqual(type_serialization.deserialize_type(lam_proto.type),
                      lam.type_signature)
     self.assertEqual(lam_proto.WhichOneof('computation'), 'lambda')
     lambda_proto = getattr(lam_proto, 'lambda')
     self.assertEqual(lambda_proto.parameter_name, arg_name)
     self.assertEqual(str(lambda_proto.result), str(lam.result.proto))
     self._serialize_deserialize_roundtrip_test(lam)
 def test_single_element_selection_leaves_no_unbound_references(self):
   param_type = computation_types.StructType([
       computation_types.FederatedType(tf.int32, placements.CLIENTS),
       computation_types.FederatedType(tf.int32, placements.SERVER),
   ])
   first_of_param = building_blocks.Selection(
       building_blocks.Reference('x', param_type), index=0)
   lam = building_blocks.Lambda('x', param_type, first_of_param)
   new_lam = form_utils._as_function_of_some_federated_subparameters(
       lam, [(0,)])
   unbound_by_node = transformation_utils.get_map_of_unbound_references(
       new_lam)
   self.assertEmpty(unbound_by_node[new_lam])
Example #28
0
 def test_binding_single_arg_leaves_no_unbound_references(self):
   param_type = computation_types.NamedTupleType([
       computation_types.FederatedType(tf.int32, placements.CLIENTS),
       computation_types.FederatedType(tf.int32, placements.SERVER),
   ])
   first_of_param = building_blocks.Selection(
       building_blocks.Reference('x', param_type), index=0)
   lam = building_blocks.Lambda('x', param_type, first_of_param)
   extracted = mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
       lam, [[0]])
   unbound_by_node = transformations.get_map_of_unbound_references(extracted)
   self.assertEmpty(unbound_by_node[extracted])
Example #29
0
 def __getattr__(self, name):
     """Returns a `ValueImpl` for the element `name` of this value.

     Federated named tuples go through a federated getattr call. Otherwise the
     element is read directly when the underlying building block is a struct,
     or wrapped in a by-name `Selection`.

     Raises:
       AttributeError: If `name` is not listed by `dir(self.type_signature)`.
         `_check_struct_or_federated_struct` may also raise (NOTE(review):
         its exception type is not visible here).
     """
     py_typecheck.check_type(name, str)
     _check_struct_or_federated_struct(self, name)
     if _is_federated_named_tuple(self):
         return ValueImpl(
             building_block_factory.create_federated_getattr_call(
                 self._comp, name), self._context_stack)
     valid_names = dir(self.type_signature)
     if name not in valid_names:
         raise AttributeError(
             'There is no such attribute \'{}\' in this tuple. Valid attributes: ({})'
             .format(name, ', '.join(valid_names)))
     if self._comp.is_struct():
         selected = getattr(self._comp, name)
     else:
         selected = building_blocks.Selection(self._comp, name=name)
     return ValueImpl(selected, self._context_stack)
Example #30
0
 def _traverse_selection(comp, transform, context_tree, identifier_seq):
     """Helper function holding traversal logic for selection nodes.

     The source is transformed first; if it changed, the selection is rebuilt
     by index before `transform` is applied to the selection itself.
     """
     # Advance the identifier sequence for this node; the value is unused.
     next(identifier_seq)
     new_source, source_modified = (
         _transform_postorder_with_symbol_bindings_switch(
             comp.source, transform, context_tree, identifier_seq))
     if source_modified:
         # Resolve the index against the ORIGINAL source's type signature:
         # the transformed source may no longer carry element names.
         index = comp.index
         if index is None:
             index = structure.name_to_index_map(
                 comp.source.type_signature)[comp.name]
         comp = building_blocks.Selection(new_source, index=index)
     transformed, transform_modified = transform(comp, context_tree)
     return transformed, transform_modified or source_modified