def test_aggregate_with_selection_from_block_by_name_results_in_single_aggregate(
    self):
  """Aggregates over by-name selections from one block merge into one call."""
  clients_int = building_blocks.Reference(
      'a', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  named_pair = building_blocks.Tuple([('a', clients_int), ('b', clients_int)])
  pair_block = building_blocks.Block([], named_pair)
  select_a = building_blocks.Selection(source=pair_block, name='a')
  select_b = building_blocks.Selection(source=pair_block, name='b')

  agg_result = building_blocks.Data('aggregation_result', tf.int32)
  zero = building_blocks.Data('zero', tf.int32)
  accumulate = building_blocks.Lambda('accumulate_param',
                                      [tf.int32, tf.int32], agg_result)
  merge = building_blocks.Lambda('merge_param', [tf.int32, tf.int32],
                                 agg_result)
  report = building_blocks.Lambda('report_param', tf.int32, agg_result)

  # One federated_aggregate per selected element, packed into a single tuple.
  aggregates = building_blocks.Tuple(
      tuple(
          building_block_factory.create_federated_aggregate(
              value, zero, accumulate, merge, report)
          for value in (select_a, select_b)))

  merged, modified = transformations.dedupe_and_merge_tuple_intrinsics(
      aggregates, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  self.assertTrue(modified)

  found_intrinsics = []

  def _collect_federated_aggregate(node):
    # Records the intrinsic of every called federated_aggregate; never
    # modifies the tree.
    is_aggregate_call = (
        isinstance(node, building_blocks.Call) and
        isinstance(node.function, building_blocks.Intrinsic) and
        node.function.uri == intrinsic_defs.FEDERATED_AGGREGATE.uri)
    if is_aggregate_call:
      found_intrinsics.append(node.function)
    return node, False

  transformation_utils.transform_postorder(merged,
                                           _collect_federated_aggregate)
  self.assertLen(found_intrinsics, 1)
  self.assertEqual(
      found_intrinsics[0].type_signature.parameter[0].compact_representation(),
      '{<int32>}@CLIENTS')
def federated_weighted_mean(arg):
  """Builds a federated weighted mean from `arg`.

  Applies `generic_multiply` to `arg`, zips that product with element 1 of
  `arg` (presumably the weight — confirm against callers), sums the zipped
  pair with `federated_sum`, and divides the sums with `generic_divide`.

  Args:
    arg: A `building_blocks.ComputationBuildingBlock`.

  Returns:
    The building block produced by `generic_divide`.
  """
  py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock)
  weight = building_blocks.Selection(arg, index=1)
  weighted_values = generic_multiply(arg)
  pair_to_zip = building_blocks.Struct([(None, weighted_values),
                                        (None, weight)])
  zipped = building_block_factory.create_federated_zip(pair_to_zip)
  return generic_divide(federated_sum(zipped))
def test_propogates_dependence_up_through_selection(self):
  """A selection from a matching intrinsic counts as consuming it."""
  struct_type = computation_types.StructType([tf.int32])
  intrinsic = building_blocks.Intrinsic('whimsy_intrinsic', struct_type)
  first_element = building_blocks.Selection(intrinsic, index=0)
  consumers = tree_analysis.extract_nodes_consuming(
      first_element, whimsy_intrinsic_predicate)
  self.assertIn(first_element, consumers)
def test_returns_called_tf_computation_with_truct(self):
  """A struct repeating one selection from a TF constant compiles to TF."""
  int_float_type = computation_types.StructType([tf.int32, tf.float32])
  constant = building_block_factory.create_tensorflow_constant(
      int_float_type, 1)
  first = building_blocks.Selection(source=constant, index=0)
  repeated = building_blocks.Struct([first] * 3)
  self.assert_compiles_to_tensorflow(repeated)
def test_with_structure_replacing_federated_map(self):
  """Lambdas and blocks are removed when a block binds a (fn, arg) pair."""
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  arg_ref = building_blocks.Reference('arg', [
      int_to_int,
      tf.int32,
  ])
  call_selected = building_blocks.Call(
      building_blocks.Selection(arg_ref, index=0),
      building_blocks.Selection(arg_ref, index=1))

  identity = building_blocks.Lambda('x', tf.int32,
                                    building_blocks.Reference('x', tf.int32))
  value = building_blocks.Data('a', tf.int32)
  bound_args = building_blocks.Tuple([identity, value])
  block = building_blocks.Block([('arg', bound_args)], call_selected)

  result, modified = compiler_transformations.remove_lambdas_and_blocks(block)
  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(result)
def _traverse_selection(comp, transform, context_tree, identifier_seq):
  """Helper function holding traversal logic for selection nodes.

  Consumes one identifier from `identifier_seq`, recursively transforms the
  selection's source, rebuilds the selection if the source changed, and then
  applies `transform` to the (possibly rebuilt) selection.

  Args:
    comp: The selection building block being traversed.
    transform: Callable `(comp, context_tree) -> (comp, modified)`.
    context_tree: Passed through to the recursive traversal and `transform`.
    identifier_seq: Iterator of node identifiers; one is consumed here.

  Returns:
    A `(comp, modified)` tuple, where `modified` is True if either the source
    or the selection itself was changed.
  """
  _ = next(identifier_seq)
  source, source_modified = _transform_postorder_with_symbol_bindings_switch(
      comp.source, transform, context_tree, identifier_seq)
  if source_modified:
    # Rebuild as an index-based selection, resolving any name against the
    # ORIGINAL source's type. The transformed source may have dropped its
    # element names, in which case re-selecting by name would fail.
    if comp.index is not None:
      index = comp.index
    else:
      index = structure.name_to_index_map(
          comp.source.type_signature)[comp.name]
    comp = building_blocks.Selection(source, index=index)
  comp, comp_modified = transform(comp, context_tree)
  return comp, comp_modified or source_modified
def _build(comp, scope): """Transforms `comp` to CDF, possibly adding bindings to `scope`.""" # The structure returned by this function is a generalized version of # call-dominant form. This function may result in the patterns specified in # the top-level function's docstring. if comp.is_reference(): return scope.resolve(comp.name) elif comp.is_selection(): source = _build(comp.source, scope) if source.is_struct(): return source[comp.as_index()] return building_blocks.Selection(source, index=comp.as_index()) elif comp.is_struct(): elements = [] for (name, value) in structure.iter_elements(comp): value = _build(value, scope) elements.append((name, value)) return building_blocks.Struct(elements) elif comp.is_call(): function = _build(comp.function, scope) argument = None if comp.argument is None else _build( comp.argument, scope) if function.is_lambda(): if argument is not None: scope = scope.new_child() scope.add_local(function.parameter_name, argument) return _build(function.result, scope) else: return scope.create_binding( building_blocks.Call(function, argument)) elif comp.is_lambda(): scope = scope.new_child_with_bindings() if comp.parameter_name: scope.add_local( comp.parameter_name, building_blocks.Reference(comp.parameter_name, comp.parameter_type)) result = _build(comp.result, scope) block = scope.bindings_to_block_with_result(result) return building_blocks.Lambda(comp.parameter_name, comp.parameter_type, block) elif comp.is_block(): scope = scope.new_child() for (name, value) in comp.locals: scope.add_local(name, _build(value, scope)) return _build(comp.result, scope) elif (comp.is_intrinsic() or comp.is_data() or comp.is_compiled_computation()): _disallow_higher_order(comp, global_comp) return comp elif comp.is_placement(): raise ValueError( f'Found placement {comp} in\n{global_comp}\n' 'but placements are not allowed in local computations.') else: raise ValueError( f'Unrecognized computation kind\n{comp}\nin\n{global_comp}')
def _transform(comp):
  """Replaces `comp` with `ref` or a selection from `ref` when it matches.

  `_should_transform`, `intrinsics` and `ref` come from an enclosing scope. If
  `intrinsics` holds more than one element, the replacement selects from
  `ref` at the position of `comp` in `intrinsics`; otherwise it is `ref`
  itself.

  Returns:
    A `(comp, modified)` tuple.
  """
  if not _should_transform(comp):
    return comp, False
  if len(intrinsics) <= 1:
    return ref, True
  position = intrinsics.index(comp)
  return building_blocks.Selection(ref, index=position), True
def test_returns_string_for_selection_with_name(self):
  """A by-name selection renders as `a.b` and as a Sel/Ref tree."""
  ref = building_blocks.Reference('a', (('b', tf.int32), ('c', tf.bool)))
  select_b = building_blocks.Selection(ref, name='b')
  self.assertEqual(select_b.compact_representation(), 'a.b')
  self.assertEqual(select_b.formatted_representation(), 'a.b')
  # pyformat: disable
  self.assertEqual(
      select_b.structural_representation(),
      'Sel(b)\n'
      '|\n'
      'Ref(a)'
  )
  # pyformat: enable
def test_binding_multiple_args_results_in_unique_names(self):
  """Rebinding two nested client subparameters keeps names unique."""
  at_clients = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  at_server = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.StructType(
      [[at_clients], at_server, [at_clients]])

  def _select_nested(outer_index):
    # x[outer_index][0], built on a fresh reference each time.
    param_ref = building_blocks.Reference('x', param_type)
    outer = building_blocks.Selection(param_ref, index=outer_index)
    return building_blocks.Selection(outer, index=0)

  lam = building_blocks.Lambda(
      'x', param_type,
      building_blocks.Struct([_select_nested(0), _select_nested(2)]))
  transformed = form_utils._as_function_of_some_federated_subparameters(
      lam, [(0, 0), (2, 0)])
  tree_analysis.check_has_unique_names(transformed)
def test_binding_multiple_args_results_in_unique_names(self):
  """Zipping two nested client selections keeps names unique."""
  at_clients = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  at_server = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.NamedTupleType(
      [[at_clients], at_server, [at_clients]])

  def _select_nested(outer_index):
    # x[outer_index][0], built on a fresh reference each time.
    param_ref = building_blocks.Reference('x', param_type)
    outer = building_blocks.Selection(param_ref, index=outer_index)
    return building_blocks.Selection(outer, index=0)

  lam = building_blocks.Lambda(
      'x', param_type,
      building_blocks.Tuple([_select_nested(0), _select_nested(2)]))
  extracted = mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
      lam, [[0, 0], [2, 0]])
  tree_analysis.check_has_unique_names(extracted)
def test_returns_string_for_selection_with_index(self):
  """A by-index selection renders as `a[0]` and as a Sel/Ref tree."""
  ref = building_blocks.Reference('a', (('b', tf.int32), ('c', tf.bool)))
  first = building_blocks.Selection(ref, index=0)
  compact = first.compact_representation()
  self.assertEqual(compact, 'a[0]')
  formatted = first.formatted_representation()
  self.assertEqual(formatted, 'a[0]')
  # pyformat: disable
  self.assertEqual(
      first.structural_representation(),
      'Sel(0)\n'
      '|\n'
      'Ref(a)'
  )
  # pyformat: enable
def _create_next_with_fake_client_output(tree):
  r"""Creates a next computation with a fake client output.

  This function returns the AST:

  Lambda
  |
  [Comp, Comp, Comp]

  `Lambda` reuses the parameter name and type of `tree`. The first two `Comp`s
  are elements 0 and 1 of `tree.result`. When `tree.result` is a
  `building_blocks.Tuple` they are taken directly; otherwise they are
  `Selection`s from it. The third `Comp` is the fake client output: a
  federated value, placed at `CLIENTS`, wrapping an empty `Tuple`.

  This function is intended to be used by
  `get_canonical_form_for_iterative_process` to create a next computation with
  a fake client output when no client output is returned by `tree` (which
  represents the `next` function of the `tff.utils.IterativeProcess`). As a
  result, this function does not assert that there is no client output in
  `tree`, and it does not assert that `tree` has the expected structure. The
  caller is expected to perform these checks before calling this function.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A new `building_blocks.ComputationBuildingBlock` representing a next
    computation with a fake client output.
  """
  original_result = tree.result
  if isinstance(original_result, building_blocks.Tuple):
    first, second = original_result[0], original_result[1]
  else:
    first = building_blocks.Selection(original_result, index=0)
    second = building_blocks.Selection(original_result, index=1)
  fake_client_output = building_block_factory.create_federated_value(
      building_blocks.Tuple([]), placements.CLIENTS)
  new_result = building_blocks.Tuple([first, second, fake_client_output])
  return building_blocks.Lambda(tree.parameter_name, tree.parameter_type,
                                new_result)
def test_returns_tf_computation_with_functional_type(self):
  """A lambda returning repeated selections collapses to compiled TF."""
  param = building_blocks.Reference('x', [('a', tf.int32), ('b', tf.float32)])
  first = building_blocks.Selection(source=param, index=0)
  lam = building_blocks.Lambda(param.name, param.type_signature,
                               building_blocks.Tuple([first, first, first]))
  transformed, modified = compiler_transformations.remove_duplicate_called_graphs(
      lam)
  self.assertTrue(modified)
  self.assertIsInstance(transformed, building_blocks.CompiledComputation)
  self.assertEqual(transformed.type_signature, lam.type_signature)
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type_with_names(
    self):
  """A lambda feeding named-arg selections to a TF identity parses to TF.

  Checks that the parsed result is a `CompiledComputation` of the same type
  and that it computes the same value as the original lambda.
  """
  identity_tf_block_type = computation_types.StructType([tf.int32, tf.bool])
  identity_tf_block = building_block_factory.create_compiled_identity(
      identity_tf_block_type)
  param_type = [('a', tf.int32), ('b', tf.float32), ('c', tf.bool)]
  tuple_ref = building_blocks.Reference('x', param_type)
  selected_int = building_blocks.Selection(tuple_ref, index=0)
  selected_bool = building_blocks.Selection(tuple_ref, index=2)
  created_tuple = building_blocks.Struct([selected_int, selected_bool])
  called_tf_block = building_blocks.Call(identity_tf_block, created_tuple)
  lambda_wrapper = building_blocks.Lambda('x', param_type, called_tf_block)

  parsed, modified = parse_tff_to_tf(lambda_wrapper)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)

  # Build each executable exactly once, after the structural checks, so a bad
  # parse is reported by the assertions above rather than by the conversion.
  exec_lambda = computation_wrapper_instances.building_block_to_computation(
      lambda_wrapper)
  exec_tf = computation_wrapper_instances.building_block_to_computation(
      parsed)
  self.assertEqual(exec_lambda({
      'a': 9,
      'b': 10.,
      'c': False
  }), exec_tf({
      'a': 9,
      'b': 10.,
      'c': False
  }))
def test_selection(self):
  """Rebinding a lambda on element 0 of a federated parameter pair."""
  clients_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_type = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.StructType([clients_type, server_type])
  param_ref = building_blocks.Reference('x', param_type)
  lam = building_blocks.Lambda(
      'x', param_type, building_blocks.Selection(param_ref, index=0))
  new_lam = form_utils._as_function_of_single_subparameter(lam, 0)
  self.assert_selected_param_to_result_type(lam, new_lam, 0)
def test_raises_on_selections_at_different_placements(self):
  """Selecting subparameters at CLIENTS and SERVER together is rejected."""
  clients_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_type = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.StructType([clients_type, server_type])
  param_ref = building_blocks.Reference('x', param_type)
  lam = building_blocks.Lambda(
      'x', param_type, building_blocks.Selection(param_ref, index=0))
  mixed_placement_paths = [(0,), (1,)]
  with self.assertRaises(form_utils._MismatchedSelectionPlacementError):
    form_utils._as_function_of_some_federated_subparameters(
        lam, mixed_placement_paths)
def test_single_nested_element_selection(self):
  """Rebinding on x[0][0] yields a function of a CLIENTS-placed 1-tuple."""
  clients_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_type = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.StructType([[clients_type], server_type])
  param_ref = building_blocks.Reference('x', param_type)
  nested_selection = building_blocks.Selection(
      building_blocks.Selection(param_ref, index=0), index=0)
  lam = building_blocks.Lambda('x', param_type, nested_selection)

  new_lam = form_utils._as_function_of_some_federated_subparameters(
      lam, [(0, 0)])

  expected_param_type = computation_types.at_clients((tf.int32,))
  expected_fn_type = computation_types.FunctionType(
      expected_param_type, lam.result.type_signature)
  type_test_utils.assert_types_equivalent(new_lam.type_signature,
                                          expected_fn_type)
def test_returns_called_tf_computation_with_non_functional_type(self):
  """Repeated selections from a TF constant become a no-arg call to TF."""
  constant = building_block_factory.create_tensorflow_constant(
      [tf.int32, tf.float32], 1)
  first = building_blocks.Selection(source=constant, index=0)
  tup = building_blocks.Tuple([first, first, first])

  transformed, modified = compiler_transformations.remove_duplicate_called_graphs(
      tup)

  self.assertTrue(modified)
  self.assertEqual(transformed.type_signature, tup.type_signature)
  self.assertIsInstance(transformed, building_blocks.Call)
  self.assertIsInstance(transformed.function,
                        building_blocks.CompiledComputation)
  self.assertIsNone(transformed.argument)
def test_broadcast_dependent_on_aggregate_fails_well(self):
  """Chaining `next` twice is rejected with a clear error message."""
  mrf = mapreduce_test_utils.get_temperature_sensor_example()
  process = form_utils.get_iterative_process_for_map_reduce_form(mrf)
  next_comp = process.next.to_building_block()
  param_ref = building_blocks.Reference(next_comp.parameter_name,
                                        next_comp.parameter_type)
  first_call = building_blocks.Call(next_comp, param_ref)
  # Second round's argument: element 0 of the first round's result, paired
  # with element 1 of the original parameter.
  second_arg = building_blocks.Struct([
      building_blocks.Selection(first_call, index=0),
      building_blocks.Selection(param_ref, index=1),
  ])
  second_call = building_blocks.Call(next_comp, second_arg)
  chained = building_blocks.Lambda(next_comp.parameter_name,
                                   next_comp.parameter_type, second_call)
  chained_process = iterative_process.IterativeProcess(
      process.initialize,
      computation_wrapper_instances.building_block_to_computation(chained))
  with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
    form_utils.get_map_reduce_form_for_iterative_process(chained_process)
def test_raises_on_selections_at_different_placements(self):
  """Zipping selections placed at CLIENTS and SERVER raises ValueError."""
  clients_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_type = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.NamedTupleType([clients_type, server_type])
  param_ref = building_blocks.Reference('x', param_type)
  lam = building_blocks.Lambda(
      'x', param_type, building_blocks.Selection(param_ref, index=0))
  mixed_placement_paths = [[0], [1]]
  with self.assertRaisesRegex(ValueError, 'at the same placement.'):
    mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
        lam, mixed_placement_paths)
def select_output_from_lambda(comp, indices):
  """Constructs a new function with result of selecting `indices` from `comp`.

  Args:
    comp: Instance of `building_blocks.Lambda` of result type `tff.StructType`
      from which we wish to select `indices`. This named tuple type is
      expected to have elements of federated type (not checked here).
    indices: Instance of `int`, `list`, or `tuple`, specifying the indices we
      wish to select from the result of `comp`. If `indices` is an `int`, the
      result of the returned `comp` will be of type at index `indices` in
      `comp.type_signature.result`. If `indices` is a `list` or `tuple`, the
      result type will be a `tff.StructType` wrapping the specified
      selections. An element of such a `list` or `tuple` may itself be a
      `list` or `tuple` of indices, which is applied as a path of nested
      selections.

  Returns:
    A transformed version of `comp` with result value the selection from the
    result of `comp` specified by `indices`.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.type_signature.result,
                          computation_types.StructType)
  py_typecheck.check_type(indices, (int, tuple, list))

  def _select(source, index):
    # Index directly into a literal struct so no redundant Selection node is
    # built; otherwise wrap `source` in a Selection. Checking struct-ness on
    # `source` itself means the decision is always current for this step.
    # (Previously a shared flag could go stale after a nested path.)
    if source.is_struct():
      return source[index]
    return building_blocks.Selection(source, index=index)

  result_tuple = comp.result
  if isinstance(indices, (tuple, list)):
    elements = []
    for index_or_path in indices:
      if isinstance(index_or_path, (tuple, list)):
        # Follow the path one level at a time.
        selected_output = result_tuple
        for index in index_or_path:
          selected_output = _select(selected_output, index)
      else:
        selected_output = _select(result_tuple, index_or_path)
      elements.append(selected_output)
    result = building_blocks.Struct(elements)
  else:
    result = _select(result_tuple, indices)
  return building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                result)
def test_binds_multiple_args_deep_in_type_tree_to_lower_lambda(self):
  """x[0][0] and x[2][0] are zipped and bound to a lower-level lambda."""
  at_clients = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  at_server = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.NamedTupleType(
      [[at_clients], at_server, [at_clients]])

  def _select_nested(outer_index):
    # x[outer_index][0], built on a fresh reference each time.
    param_ref = building_blocks.Reference('x', param_type)
    return building_blocks.Selection(
        building_blocks.Selection(param_ref, index=outer_index), index=0)

  lam = building_blocks.Lambda(
      'x', param_type,
      building_blocks.Tuple([_select_nested(0), _select_nested(2)]))
  expected_fn_regex = (r'\(_([a-z]{3})2 -> <federated_map\(<\(_(\1)3 -> '
                       r'_(\1)3\[0\]\),_(\1)2>\),federated_map\(<\(_(\1)4 -> '
                       r'_(\1)4\[1\]\),_(\1)2>\)>\)')
  expected_arg_regex = r'federated_zip_at_clients\(<_([a-z]{3})1\[0\]\[0\],_(\1)1\[2\]\[0\]>\)'

  extracted = transformations.zip_selection_as_argument_to_lower_level_lambda(
      lam, [[0, 0], [2, 0]])

  self.assertEqual(extracted.type_signature, lam.type_signature)
  self.assertIsInstance(extracted, building_blocks.Lambda)
  call = extracted.result
  self.assertIsInstance(call, building_blocks.Call)
  self.assertIsInstance(call.function, building_blocks.Lambda)
  self.assertRegexMatch(call.function.compact_representation(),
                        [expected_fn_regex])
  self.assertRegexMatch(call.argument.compact_representation(),
                        [expected_arg_regex])
def test_replaces_lambda_to_unnamed_tuple_of_called_graphs_with_tf_of_same_type(
    self):
  """A lambda returning a tuple of called TF identities parses to TF.

  Checks that the parsed result is a `CompiledComputation` of equivalent type
  and that it computes the same value as the original lambda.
  """
  int_tensor_type = computation_types.TensorType(tf.int32)
  int_identity_tf_block = building_block_factory.create_compiled_identity(
      int_tensor_type)
  float_tensor_type = computation_types.TensorType(tf.float32)
  float_identity_tf_block = building_block_factory.create_compiled_identity(
      float_tensor_type)
  tuple_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
  selected_int = building_blocks.Selection(tuple_ref, index=0)
  selected_float = building_blocks.Selection(tuple_ref, index=1)
  called_int_tf_block = building_blocks.Call(int_identity_tf_block,
                                             selected_int)
  called_float_tf_block = building_blocks.Call(float_identity_tf_block,
                                               selected_float)
  tuple_of_called_graphs = building_blocks.Struct(
      [called_int_tf_block, called_float_tf_block])
  lambda_wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                          tuple_of_called_graphs)

  parsed, modified = parse_tff_to_tf(lambda_wrapper)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)

  # Build each executable exactly once, after the structural checks, so a bad
  # parse is reported by the assertions above rather than by the conversion.
  exec_lambda = computation_wrapper_instances.building_block_to_computation(
      lambda_wrapper)
  exec_tf = computation_wrapper_instances.building_block_to_computation(
      parsed)
  self.assertEqual(exec_lambda([11, 12.]), exec_tf([11, 12.]))
def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self):
  """Splitting on broadcast works when one broadcast is nested in a struct."""
  inner_broadcast = building_block_test_utils.create_whimsy_called_federated_broadcast(
  )
  server_data = building_blocks.Data('a',
                                     computation_types.at_server(tf.int32))
  packed = building_blocks.Struct([server_data, inner_broadcast])
  outer_broadcast = building_block_factory.create_federated_broadcast(
      building_blocks.Selection(packed, index=0))
  comp = building_blocks.Lambda(
      'a', tf.int32, transformations.to_call_dominant(outer_broadcast))
  self.assert_splits_on(
      comp, building_block_factory.create_null_federated_broadcast())
def test_basic_functionality_of_lambda_class(self):
  """Checks a Lambda's type, accessors, representations and proto form."""
  arg_name = 'arg'
  arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]
  arg_ref = building_blocks.Reference(arg_name, arg_type)
  select_f = building_blocks.Selection(arg_ref, name='f')
  select_x = building_blocks.Selection(arg_ref, name='x')
  # Body is f(f(x)).
  body = building_blocks.Call(select_f, building_blocks.Call(select_f, select_x))
  lam = building_blocks.Lambda(arg_name, arg_type, body)

  self.assertEqual(
      str(lam.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(lam.parameter_name, arg_name)
  self.assertEqual(str(lam.parameter_type), '<f=(int32 -> int32),x=int32>')
  self.assertEqual(lam.result.compact_representation(), 'arg.f(arg.f(arg.x))')

  arg_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  expected_repr = (
      'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))').format(arg_type_repr)
  self.assertEqual(repr(lam), expected_repr)
  self.assertEqual(lam.compact_representation(),
                   '(arg -> arg.f(arg.f(arg.x)))')

  lam_proto = lam.proto
  self.assertEqual(
      type_serialization.deserialize_type(lam_proto.type), lam.type_signature)
  self.assertEqual(lam_proto.WhichOneof('computation'), 'lambda')
  # `lambda` is a Python keyword, so the proto field needs getattr.
  self.assertEqual(getattr(lam_proto, 'lambda').parameter_name, arg_name)
  self.assertEqual(
      str(getattr(lam_proto, 'lambda').result), str(lam.result.proto))
  self._serialize_deserialize_roundtrip_test(lam)
def test_single_element_selection_leaves_no_unbound_references(self):
  """Rebinding on x[0] leaves the new lambda with no free references."""
  clients_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_type = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.StructType([clients_type, server_type])
  param_ref = building_blocks.Reference('x', param_type)
  lam = building_blocks.Lambda(
      'x', param_type, building_blocks.Selection(param_ref, index=0))
  new_lam = form_utils._as_function_of_some_federated_subparameters(
      lam, [(0,)])
  unbound_by_node = transformation_utils.get_map_of_unbound_references(new_lam)
  self.assertEmpty(unbound_by_node[new_lam])
def test_binding_single_arg_leaves_no_unbound_references(self):
  """Zipping x[0] into a lower lambda leaves no free references."""
  clients_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_type = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.NamedTupleType([clients_type, server_type])
  param_ref = building_blocks.Reference('x', param_type)
  lam = building_blocks.Lambda(
      'x', param_type, building_blocks.Selection(param_ref, index=0))
  extracted = mapreduce_transformations.zip_selection_as_argument_to_lower_level_lambda(
      lam, [[0]])
  unbound_by_node = transformations.get_map_of_unbound_references(extracted)
  self.assertEmpty(unbound_by_node[extracted])
def __getattr__(self, name):
  """Returns a `ValueImpl` for element `name` of this struct-typed value.

  For a federated named tuple, this wraps a federated getattr call. Otherwise
  `name` must be listed by `dir(self.type_signature)`. A literal struct
  building block is indexed directly; any other building block is wrapped in
  a by-name `Selection`.

  Raises:
    AttributeError: If `name` is not an attribute of the non-federated type.
  """
  py_typecheck.check_type(name, str)
  _check_struct_or_federated_struct(self, name)
  if _is_federated_named_tuple(self):
    federated_call = building_block_factory.create_federated_getattr_call(
        self._comp, name)
    return ValueImpl(federated_call, self._context_stack)
  valid_attributes = dir(self.type_signature)
  if name not in valid_attributes:
    raise AttributeError(
        'There is no such attribute \'{}\' in this tuple. Valid attributes: ({})'
        .format(name, ', '.join(valid_attributes)))
  if self._comp.is_struct():
    element = getattr(self._comp, name)
  else:
    element = building_blocks.Selection(self._comp, name=name)
  return ValueImpl(element, self._context_stack)
def _traverse_selection(comp, transform, context_tree, identifier_seq):
  """Helper function holding traversal logic for selection nodes.

  Consumes one identifier, transforms the source post-order, rebuilds the
  selection as an index-based one when the source changed, and finally
  applies `transform` to the selection itself.

  Returns:
    A `(comp, modified)` tuple; `modified` is True if the source or the
    selection was changed.
  """
  next(identifier_seq)  # Consume this node's identifier.
  new_source, source_changed = _transform_postorder_with_symbol_bindings_switch(
      comp.source, transform, context_tree, identifier_seq)
  if source_changed:
    # Resolve the position against the ORIGINAL source's type: the new source
    # may no longer carry element names.
    position = comp.index
    if position is None:
      position = structure.name_to_index_map(
          comp.source.type_signature)[comp.name]
    comp = building_blocks.Selection(new_source, index=position)
  comp, self_changed = transform(comp, context_tree)
  return comp, self_changed or source_changed