def test_removes_selected_intrinsic_leaving_remaining_intrinsic(self):
   """Splitting on aggregate only must leave secure_sum_bitwidth behind.

   `comp` calls both federated_aggregate and federated_secure_sum_bitwidth, but
   only the null aggregate is passed to the split. Afterwards the aggregate URI
   must be absent from both halves, while the bitwidth URI must still appear in
   at least one of them.
   """
   aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
       accumulate_parameter_name='a',
       merge_parameter_name='b',
       report_parameter_name='c')
   bitwidth_call = (
       compiler_test_utils.create_whimsy_called_federated_secure_sum_bitwidth())
   comp = building_blocks.Lambda(
       'd', tf.int32, building_blocks.Struct([aggregate_call, bitwidth_call]))
   null_aggregate = building_block_factory.create_null_federated_aggregate()
   bitwidth_uri = bitwidth_call.function.uri
   aggregate_uri = null_aggregate.function.uri

   before, after = transformations.force_align_and_split_by_intrinsics(
       comp, [null_aggregate])

   # Sanity check: the input really contains both intrinsics, otherwise the
   # assertions below would be vacuous.
   for expected_uri in (bitwidth_uri, aggregate_uri):
       self.assertTrue(
           tree_analysis.contains_called_intrinsic(comp, expected_uri))
   # The selected intrinsic has been removed from both halves...
   for part in (before, after):
       self.assertFalse(
           tree_analysis.contains_called_intrinsic(part, aggregate_uri))
   # ...while the unselected one survives in at least one of them.
   self.assertTrue(
       any(tree_analysis.contains_called_intrinsic(part, bitwidth_uri)
           for part in (before, after)))
Example no. 2
0
    def test_returns_trees_with_one_federated_secure_sum_and_two_federated_aggregates(
            self):
        """Splitting on both URIs yields two Lambdas free of those intrinsics.

        The input struct holds one secure-sum call and the same aggregate call
        twice. Both resulting halves must be Lambdas that no longer call either
        of the requested URIs.
        """
        aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
        secure_sum_call = (
            compiler_test_utils.create_whimsy_called_federated_secure_sum())
        comp = building_blocks.Lambda(
            'd', tf.int32,
            building_blocks.Struct(
                [secure_sum_call, aggregate_call, aggregate_call]))
        split_uris = [
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
            intrinsic_defs.FEDERATED_SECURE_SUM.uri,
        ]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, split_uris)

        for part in (before, after):
            self.assertIsInstance(part, building_blocks.Lambda)
            self.assertFalse(
                tree_analysis.contains_called_intrinsic(part, split_uris))
    def test_returns_string_for_federated_aggregate(self):
        """Checks the compact, formatted and structural string forms.

        The whimsy aggregate names its accumulate/merge/report lambda
        parameters 'a', 'b' and 'c', so those names appear in every
        representation.
        """
        comp = test_utils.create_whimsy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')

        self.assertEqual(
            comp.compact_representation(),
            'federated_aggregate(<data,data,(a -> data),(b -> data),(c -> data)>)'
        )
        # The expected multi-line strings below are laid out to mirror the
        # rendered output, so keep the auto-formatter away from them.
        # pyformat: disable
        self.assertEqual(
            comp.formatted_representation(), 'federated_aggregate(<\n'
            '  data,\n'
            '  data,\n'
            '  (a -> data),\n'
            '  (b -> data),\n'
            '  (c -> data)\n'
            '>)')
        self.assertEqual(
            comp.structural_representation(), '                    Call\n'
            '                   /    \\\n'
            'federated_aggregate      Struct\n'
            '                         |\n'
            '                         [data, data, Lambda(a), Lambda(b), Lambda(c)]\n'
            '                                      |          |          |\n'
            '                                      data       data       data')
        # Re-enable formatting so the directive does not leak into the rest of
        # the file.
        # pyformat: enable
 def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
     """Broadcasting an aggregate's result raises a ValueError.

     The error message is expected to match the regex 'acc_param'.
     """
     broadcasted_aggregate = building_block_factory.create_federated_broadcast(
         test_utils.create_whimsy_called_federated_aggregate())
     with self.assertRaisesRegex(ValueError, 'acc_param'):
         tree_analysis.check_broadcast_not_dependent_on_aggregate(
             broadcasted_aggregate)
 def test_finds_broadcast_dependent_on_aggregate(self):
     """A broadcast fed by an aggregate is rejected with ValueError."""
     source_aggregate = test_utils.create_whimsy_called_federated_aggregate()
     broadcast_of_aggregate = (
         building_block_factory.create_federated_broadcast(source_aggregate))
     with self.assertRaises(ValueError):
         tree_analysis.check_broadcast_not_dependent_on_aggregate(
             broadcast_of_aggregate)
 def test_splits_on_selected_intrinsic_aggregate(self):
   """A lone federated_aggregate is split out via its null counterpart."""
   aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
       accumulate_parameter_name='a',
       merge_parameter_name='b',
       report_parameter_name='c')
   comp = building_blocks.Lambda(
       'd', tf.int32, building_blocks.Struct([aggregate_call]))
   self.assert_splits_on(
       comp, building_block_factory.create_null_federated_aggregate())
 def test_splits_even_when_selected_intrinsic_is_not_present(self):
   """Requesting a split on an intrinsic absent from `comp` must not raise.

   `comp` only calls federated_aggregate, yet federated_secure_sum_bitwidth is
   requested too. The test passes as long as the call completes; it does not
   inspect the returned computations.
   """
   aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
       accumulate_parameter_name='a',
       merge_parameter_name='b',
       report_parameter_name='c')
   comp = building_blocks.Lambda(
       'd', tf.int32, building_blocks.Struct([aggregate_call]))
   requested_calls = [
       building_block_factory.create_null_federated_aggregate(),
       building_block_factory.create_null_federated_secure_sum_bitwidth(),
   ]
   transformations.force_align_and_split_by_intrinsics(comp, requested_calls)
Example no. 8
0
    def test_raises_value_error_for_expected_uri(self):
        """Requesting a URI that `comp` never calls raises ValueError.

        `comp` contains only a federated_aggregate call, while both the
        aggregate and the secure-sum URIs are requested.
        """
        aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
        comp = building_blocks.Lambda(
            'd', tf.int32, building_blocks.Struct([aggregate_call]))
        requested_uris = [
            intrinsic_defs.FEDERATED_AGGREGATE.uri,
            intrinsic_defs.FEDERATED_SECURE_SUM.uri,
        ]

        with self.assertRaises(ValueError):
            transformations.force_align_and_split_by_intrinsics(
                comp, requested_uris)
 def test_splits_on_two_intrinsics(self):
   """Both intrinsics present in `comp` are split out in a single pass."""
   aggregate_call = compiler_test_utils.create_whimsy_called_federated_aggregate(
       accumulate_parameter_name='a',
       merge_parameter_name='b',
       report_parameter_name='c')
   bitwidth_call = (
       compiler_test_utils.create_whimsy_called_federated_secure_sum_bitwidth())
   comp = building_blocks.Lambda(
       'd', tf.int32, building_blocks.Struct([aggregate_call, bitwidth_call]))
   null_calls = [
       building_block_factory.create_null_federated_aggregate(),
       building_block_factory.create_null_federated_secure_sum_bitwidth(),
   ]
   self.assert_splits_on(comp, null_calls)
        tuple_2 = building_blocks.Struct([data_2, data_2])
        self.assertTrue(tree_analysis.trees_equal(tuple_1, tuple_2))

    def test_returns_true_for_identical_graphs_with_nans(self):
        """Two separately built `_create_tensorflow_graph_with_nan` graphs
        compare equal under `trees_equal`."""
        first_graph, second_graph = [
            _create_tensorflow_graph_with_nan() for _ in range(2)
        ]
        self.assertTrue(tree_analysis.trees_equal(first_graph, second_graph))


# Module-level fixtures used as named parameters by the parameterized
# ContainsAggregationShared tests that follow.

# A struct of two unnamed called intrinsics: a federated broadcast and a
# federated_value placed at CLIENTS. Presumably a case that should NOT be
# detected as an aggregation — confirm against the test bodies.
non_aggregation_intrinsics = building_blocks.Struct([
    (None, test_utils.create_whimsy_called_federated_broadcast()),
    (None, test_utils.create_whimsy_called_federated_value(placements.CLIENTS))
])

# The empty struct type; used as the value type for every "trivial_*" call
# below.
unit = computation_types.StructType([])
trivial_aggregate = test_utils.create_whimsy_called_federated_aggregate(
    value_type=unit)
trivial_collect = test_utils.create_whimsy_called_federated_collect(unit)
trivial_mean = test_utils.create_whimsy_called_federated_mean(unit)
trivial_sum = test_utils.create_whimsy_called_federated_sum(unit)
# TODO(b/120439632) Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
# NOTE(review): despite its name, this fixture is a
# federated_secure_sum_bitwidth call, not a plain federated_secure_sum.
trivial_secure_sum = test_utils.create_whimsy_called_federated_secure_sum_bitwidth(
    unit)


class ContainsAggregationShared(parameterized.TestCase):
    @parameterized.named_parameters([
        ('non_aggregation_intrinsics', non_aggregation_intrinsics),
        ('trivial_aggregate', trivial_aggregate),
        ('trivial_collect', trivial_collect),
        ('trivial_mean', trivial_mean),