def test_identical_to_merge_tuple_intrinsics_with_different_intrinsics(
    self):
  """Dedupe-and-merge of distinct aggregates equals a plain merge.

  The two called aggregates differ in parameter names and value type and
  compare as not equal, so `dedupe_and_merge_tuple_intrinsics` is expected
  to report a modification and render identically to
  `merge_tuple_intrinsics` applied to the same tuple.
  """
  int_aggregate = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c',
      value_type=tf.int32)
  # These compare as not equal.
  float_aggregate = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='x',
      merge_parameter_name='y',
      report_parameter_name='z',
      value_type=tf.float32)
  comp = building_blocks.Tuple((int_aggregate, float_aggregate))
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri

  deduped_and_merged_comp, deduped_modified = (
      compiler_transformations.dedupe_and_merge_tuple_intrinsics(
          comp, aggregate_uri))
  directly_merged_comp, directly_modified = (
      transformations.merge_tuple_intrinsics(comp, aggregate_uri))

  self.assertTrue(deduped_modified)
  self.assertTrue(directly_modified)
  self.assertEqual(deduped_and_merged_comp.formatted_representation(),
                   directly_merged_comp.formatted_representation())
def test_returns_comps_with_federated_aggregate_no_unbound_references(
    self):
  """Splitting a lambda on federated_aggregate removes every aggregate call.

  The input lambda returns a two-element tuple holding the same called
  aggregate twice. Both halves of the split must be lambdas with no called
  aggregate left. `before` must keep the input's parameter type, and the
  result of `after` must have the input result's type signature.

  NOTE(review): despite the test name, unbound references are not checked
  here; only intrinsic counts and types are asserted.
  """
  aggregate_call = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  result_tuple = building_blocks.Tuple([aggregate_call, aggregate_call])
  comp = building_blocks.Lambda('d', tf.int32, result_tuple)
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri

  before, after = mapreduce_transformations.force_align_and_split_by_intrinsic(
      comp, aggregate_uri)

  def _is_aggregate_call(node):
    return building_block_analysis.is_called_intrinsic(node, aggregate_uri)

  # Sanity check: the input really contains the intrinsic being split out.
  self.assertIsInstance(comp, building_blocks.Lambda)
  self.assertGreater(tree_analysis.count(comp, _is_aggregate_call), 0)
  # `before`: no aggregate calls, same parameter type as the input.
  self.assertIsInstance(before, building_blocks.Lambda)
  self.assertEqual(tree_analysis.count(before, _is_aggregate_call), 0)
  self.assertEqual(before.parameter_type, comp.parameter_type)
  # `after`: no aggregate calls, same result type as the input.
  self.assertIsInstance(after, building_blocks.Lambda)
  self.assertEqual(tree_analysis.count(after, _is_aggregate_call), 0)
  self.assertEqual(after.result.type_signature, comp.result.type_signature)
def test_returns_string_for_federated_aggregate(self):
  """Checks the three string renderings of a called federated_aggregate.

  The dummy aggregate is built with accumulate/merge/report parameter names
  'a', 'b' and 'c'. These appear as `(a -> data)`, `(b -> data)` and
  `(c -> data)` in the expected strings below.
  """
  # NOTE(review): the whitespace inside the expected multi-line literals
  # below looks collapsed. Every nested line carries a single leading space,
  # and the `/` and `|` connectors of the structural tree no longer line up
  # under their nodes. Verify these literals against the real output.
  comp = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')

  # Single-line rendering.
  self.assertEqual(
      comp.compact_representation(),
      'federated_aggregate(<data,data,(a -> data),(b -> data),(c -> data)>)')
  # pyformat: disable
  # Multi-line rendering: one struct element per line.
  self.assertEqual(
      comp.formatted_representation(),
      'federated_aggregate(<\n'
      ' data,\n'
      ' data,\n'
      ' (a -> data),\n'
      ' (b -> data),\n'
      ' (c -> data)\n'
      '>)'
  )
  # ASCII-tree rendering: the Call node with its function and Struct
  # argument, whose three Lambda elements each point to their `data` body.
  self.assertEqual(
      comp.structural_representation(),
      ' Call\n'
      ' / \\\n'
      'federated_aggregate Struct\n'
      ' |\n'
      ' [data, data, Lambda(a), Lambda(b), Lambda(c)]\n'
      ' | | |\n'
      ' data data data'
  )
  # pyformat: enable
def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
  """The dependency error must name the offending example.

  A federated broadcast is built directly on top of a called aggregate. The
  check must raise a ValueError whose message matches 'acc_param'
  (presumably the dummy aggregate's default accumulate parameter name; TODO
  confirm in test_utils).
  """
  broadcast_of_aggregate = building_block_factory.create_federated_broadcast(
      test_utils.create_dummy_called_federated_aggregate())
  with self.assertRaisesRegex(ValueError, 'acc_param'):
    tree_analysis.check_broadcast_not_dependent_on_aggregate(
        broadcast_of_aggregate)
def test_finds_broadcast_dependent_on_aggregate(self):
  """A broadcast whose argument is a called aggregate is rejected.

  The check is expected to raise ValueError; the message is not inspected.
  """
  broadcast_of_aggregate = building_block_factory.create_federated_broadcast(
      test_utils.create_dummy_called_federated_aggregate())
  with self.assertRaises(ValueError):
    tree_analysis.check_broadcast_not_dependent_on_aggregate(
        broadcast_of_aggregate)
def test_returns_trees_with_one_federated_secure_sum_and_two_federated_aggregates(
    self):
  """Splitting on two URIs strips both intrinsics from each half.

  The input lambda returns a struct of one called secure_sum and the same
  called aggregate twice. After splitting on the aggregate and secure_sum
  URIs, neither resulting lambda may still contain a call to either one.
  """
  aggregate = compiler_test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  secure_sum = compiler_test_utils.create_dummy_called_federated_secure_sum()
  comp = building_blocks.Lambda(
      'd', tf.int32,
      building_blocks.Struct([secure_sum, aggregate, aggregate]))
  uris = [
      intrinsic_defs.FEDERATED_AGGREGATE.uri,
      intrinsic_defs.FEDERATED_SECURE_SUM.uri,
  ]

  before, after = transformations.force_align_and_split_by_intrinsics(
      comp, uris)

  # Same checks, same order as listing them out: `before` first, then `after`.
  for split_comp in (before, after):
    self.assertIsInstance(split_comp, building_blocks.Lambda)
    self.assertFalse(tree_analysis.contains_called_intrinsic(split_comp, uris))
def test_raises_value_error_for_expected_uri(self):
  """Splitting on a URI absent from the computation raises ValueError.

  Only federated_aggregate is called in the input lambda. The split is
  requested on both the aggregate and the secure_sum URIs.
  """
  aggregate = compiler_test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  comp = building_blocks.Lambda(
      'd', tf.int32, building_blocks.Struct([aggregate]))
  uris = [
      intrinsic_defs.FEDERATED_AGGREGATE.uri,
      intrinsic_defs.FEDERATED_SECURE_SUM.uri,
  ]

  with self.assertRaises(ValueError):
    transformations.force_align_and_split_by_intrinsics(comp, uris)
tuple_2 = building_blocks.Struct([data_2, data_2]) self.assertTrue(tree_analysis.trees_equal(tuple_1, tuple_2)) def test_returns_true_for_identical_graphs_with_nans(self): tf_comp1 = _create_tensorflow_graph_with_nan() tf_comp2 = _create_tensorflow_graph_with_nan() self.assertTrue(tree_analysis.trees_equal(tf_comp1, tf_comp2)) non_aggregation_intrinsics = building_blocks.Struct([ (None, test_utils.create_dummy_called_federated_broadcast()), (None, test_utils.create_dummy_called_federated_value(placements.CLIENTS)) ]) unit = computation_types.StructType([]) trivial_aggregate = test_utils.create_dummy_called_federated_aggregate( value_type=unit) trivial_collect = test_utils.create_dummy_called_federated_collect(unit) trivial_mean = test_utils.create_dummy_called_federated_mean(unit) trivial_sum = test_utils.create_dummy_called_federated_sum(unit) # TODO(b/120439632) Enable once federated_mean accepts structured weights. # trivial_weighted_mean = ... trivial_secure_sum = test_utils.create_dummy_called_federated_secure_sum(unit) class ContainsAggregationShared(parameterized.TestCase): @parameterized.named_parameters([ ('non_aggregation_intrinsics', non_aggregation_intrinsics), ('trivial_aggregate', trivial_aggregate), ('trivial_collect', trivial_collect), ('trivial_mean', trivial_mean),
def test_constructs_aggregate_of_tuple_with_one_element(self):
  """Dedupes a tuple of two identical aggregates into a single aggregate.

  The same called federated_aggregate is placed in a tuple twice. After
  `dedupe_and_merge_tuple_intrinsics`, exactly one called aggregate should
  remain, and its `type_signature.member` should have length 1. The full
  rewritten computation is also pinned via its formatted representation.
  """
  called_intrinsic = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  # The very same called intrinsic appears twice, so it is a dedupe candidate.
  calls = building_blocks.Tuple((called_intrinsic, called_intrinsic))
  comp = calls
  transformed_comp, modified = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
      comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)

  # Collect every called federated_aggregate left in the transformed tree.
  federated_agg = []

  def _find_federated_aggregate(comp):
    if building_block_analysis.is_called_intrinsic(
        comp, intrinsic_defs.FEDERATED_AGGREGATE.uri):
      federated_agg.append(comp)
    # Return every node as-is and report it as unmodified; the traversal is
    # used here only to collect matches.
    return comp, False

  transformation_utils.transform_postorder(transformed_comp,
                                           _find_federated_aggregate)
  self.assertTrue(modified)
  # Exactly one aggregate survives, and it carries a single member.
  self.assertLen(federated_agg, 1)
  self.assertLen(federated_agg[0].type_signature.member, 1)
  # NOTE(review): the expected literal below looks whitespace-collapsed.
  # Every nested line has a single leading space regardless of how deeply it
  # sits inside `<...>` / `(...)`. Verify it against the real
  # formatted_representation() output before relying on this assertion.
  self.assertEqual(
      transformed_comp.formatted_representation(),
      '(_var1 -> <\n'
      ' _var1[0],\n'
      ' _var1[0]\n'
      '>)((x -> <\n'
      ' x[0]\n'
      '>)((let\n'
      ' value=federated_aggregate(<\n'
      ' federated_map(<\n'
      ' (arg -> <\n'
      ' arg\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >[0]\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >,\n'
      ' (let\n'
      ' _var1=<\n'
      ' (a -> data)\n'
      ' >\n'
      ' in (_var2 -> <\n'
      ' _var1[0](<\n'
      ' <\n'
      ' _var2[0][0],\n'
      ' _var2[1][0]\n'
      ' >\n'
      ' >[0])\n'
      ' >)),\n'
      ' (let\n'
      ' _var3=<\n'
      ' (b -> data)\n'
      ' >\n'
      ' in (_var4 -> <\n'
      ' _var3[0](<\n'
      ' <\n'
      ' _var4[0][0],\n'
      ' _var4[1][0]\n'
      ' >\n'
      ' >[0])\n'
      ' >)),\n'
      ' (let\n'
      ' _var5=<\n'
      ' (c -> data)\n'
      ' >\n'
      ' in (_var6 -> <\n'
      ' _var5[0](_var6[0])\n'
      ' >))\n'
      ' >)\n'
      ' in <\n'
      ' federated_apply(<\n'
      ' (arg -> arg[0]),\n'
      ' value\n'
      ' >)\n'
      '>)))')