def test_constructs_broadcast_of_tuple_with_one_element(self):
  """Dedupes a tuple of two identical broadcasts into one broadcast.

  The input tuple holds the same called `federated_broadcast` twice. After
  `dedupe_and_merge_tuple_intrinsics`, the transformed tree must contain
  exactly one called broadcast, and its member type must have one element.
  """
  broadcast_call = test_utils.create_dummy_called_federated_broadcast()
  comp = building_blocks.Tuple((broadcast_call, broadcast_call))

  transformed_comp, modified = (
      compiler_transformations.dedupe_and_merge_tuple_intrinsics(
          comp, intrinsic_defs.FEDERATED_BROADCAST.uri))

  # Gather every called federated_broadcast remaining after the transform.
  remaining_broadcasts = []

  def _record_broadcast(node):
    if building_block_analysis.is_called_intrinsic(
        node, intrinsic_defs.FEDERATED_BROADCAST.uri):
      remaining_broadcasts.append(node)
    # Leave the node as-is; the False flag presumably signals "not modified".
    return node, False

  transformation_utils.transform_postorder(transformed_comp,
                                           _record_broadcast)

  self.assertTrue(modified)
  # The input tree still shows both (un-deduped) broadcasts.
  self.assertEqual(
      comp.compact_representation(),
      '<federated_broadcast(data),federated_broadcast(data)>')
  self.assertLen(remaining_broadcasts, 1)
  self.assertLen(remaining_broadcasts[0].type_signature.member, 1)
  # NOTE(review): the leading whitespace inside these expected fragments
  # looks collapsed (e.g. ' _var1[0],'); confirm against the actual
  # formatted_representation() output.
  self.assertEqual(
      transformed_comp.formatted_representation(),
      '(_var1 -> <\n'
      ' _var1[0],\n'
      ' _var1[0]\n'
      '>)((x -> <\n'
      ' x[0]\n'
      '>)((let\n'
      ' value=federated_broadcast(federated_apply(<\n'
      ' (arg -> <\n'
      ' arg\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >[0]\n'
      ' >))\n'
      ' in <\n'
      ' federated_map_all_equal(<\n'
      ' (arg -> arg[0]),\n'
      ' value\n'
      ' >)\n'
      '>)))')
def test_returns_trees_with_two_federated_broadcast(self):
  """Splits a lambda over a struct of two broadcasts into broadcast-free halves.

  Both halves returned by `force_align_and_split_by_intrinsics` must be
  lambdas that no longer contain a called federated_broadcast.
  """
  broadcast = compiler_test_utils.create_dummy_called_federated_broadcast()
  comp = building_blocks.Lambda(
      'a', tf.int32, building_blocks.Struct([broadcast, broadcast]))
  uri = [intrinsic_defs.FEDERATED_BROADCAST.uri]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, uri)

  # Same checks, in the same order, for each half of the split.
  for half in (before, after):
    self.assertIsInstance(half, building_blocks.Lambda)
    self.assertFalse(tree_analysis.contains_called_intrinsic(half, uri))
def test_does_not_find_aggregate_dependent_on_broadcast(self):
  """An aggregate whose input is a broadcast passes the dependency check.

  The test builds a federated_aggregate over the broadcast's output. It then
  calls `check_broadcast_not_dependent_on_aggregate`, and passes as long as
  that call does not raise. There is no explicit assertion.
  """
  broadcast = test_utils.create_dummy_called_federated_broadcast()
  member = broadcast.type_signature.member

  zero = building_blocks.Data('zero', member)
  # Each operator takes its own parameter-type list, as in the original
  # construction.
  accumulate = building_blocks.Lambda(
      'accumulate_parameter', [member, member],
      building_blocks.Data('accumulate_result', member))
  merge = building_blocks.Lambda(
      'merge_parameter', [member, member],
      building_blocks.Data('merge_result', member))
  report = building_blocks.Lambda(
      'report_parameter', member,
      building_blocks.Data('report_result', member))

  aggregate = building_block_factory.create_federated_aggregate(
      broadcast, zero, accumulate, merge, report)

  tree_analysis.check_broadcast_not_dependent_on_aggregate(aggregate)
def test_handles_federated_broadcasts_nested_in_tuple(self):
  """Splits out broadcasts even when one sits inside a struct.

  The first broadcast is packed into a struct next to a server-placed `a`.
  A second broadcast is built over element 0 of that struct. The lambda
  wrapping it must split into two halves that contain no called
  federated_broadcast.
  """
  first_broadcast = (
      compiler_test_utils.create_dummy_called_federated_broadcast())
  server_int = computation_types.FederatedType(
      computation_types.TensorType(tf.int32), placements.SERVER)
  # Build the struct with `building_blocks.Struct`, the same constructor the
  # other force_align_and_split_by_intrinsics test uses, instead of `Tuple`.
  packed_broadcast = building_blocks.Struct([
      building_blocks.Data('a', server_int),
      first_broadcast,
  ])
  sel = building_blocks.Selection(packed_broadcast, index=0)
  second_broadcast = building_block_factory.create_federated_broadcast(sel)
  comp = building_blocks.Lambda('a', tf.int32, second_broadcast)
  uri = [intrinsic_defs.FEDERATED_BROADCAST.uri]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, uri)

  self.assertIsInstance(before, building_blocks.Lambda)
  self.assertFalse(tree_analysis.contains_called_intrinsic(before, uri))
  self.assertIsInstance(after, building_blocks.Lambda)
  self.assertFalse(tree_analysis.contains_called_intrinsic(after, uri))
def test_returns_false_with_unmatched_called_intrinsic(self):
  """A broadcast call is not reported when searching for federated_map."""
  self.assertFalse(
      tree_analysis.contains_called_intrinsic(
          test_utils.create_dummy_called_federated_broadcast(),
          intrinsic_defs.FEDERATED_MAP.uri))
def test_returns_true_with_matching_uri(self):
  """A broadcast call is reported when searching for federated_broadcast."""
  self.assertTrue(
      tree_analysis.contains_called_intrinsic(
          test_utils.create_dummy_called_federated_broadcast(),
          intrinsic_defs.FEDERATED_BROADCAST.uri))
def test_returns_true_with_none_uri(self):
  """Called without a URI argument, the check still reports the broadcast."""
  broadcast_call = test_utils.create_dummy_called_federated_broadcast()
  found = tree_analysis.contains_called_intrinsic(broadcast_call)
  self.assertTrue(found)
def test_returns_true_for_tuples(self): data_1 = building_blocks.Data('data', tf.int32) tuple_1 = building_blocks.Struct([data_1, data_1]) data_2 = building_blocks.Data('data', tf.int32) tuple_2 = building_blocks.Struct([data_2, data_2]) self.assertTrue(tree_analysis.trees_equal(tuple_1, tuple_2)) def test_returns_true_for_identical_graphs_with_nans(self): tf_comp1 = _create_tensorflow_graph_with_nan() tf_comp2 = _create_tensorflow_graph_with_nan() self.assertTrue(tree_analysis.trees_equal(tf_comp1, tf_comp2)) non_aggregation_intrinsics = building_blocks.Struct([ (None, test_utils.create_dummy_called_federated_broadcast()), (None, test_utils.create_dummy_called_federated_value(placements.CLIENTS)) ]) unit = computation_types.StructType([]) trivial_aggregate = test_utils.create_dummy_called_federated_aggregate( value_type=unit) trivial_collect = test_utils.create_dummy_called_federated_collect(unit) trivial_mean = test_utils.create_dummy_called_federated_mean(unit) trivial_sum = test_utils.create_dummy_called_federated_sum(unit) # TODO(b/120439632) Enable once federated_mean accepts structured weights. # trivial_weighted_mean = ... trivial_secure_sum = test_utils.create_dummy_called_federated_secure_sum(unit) class ContainsAggregationShared(parameterized.TestCase):