コード例 #1
0
    def test_identical_to_merge_tuple_intrinsics_with_different_intrinsics(
            self):
        """Dedupe-and-merge matches a plain merge for two distinct aggregates.

        Builds a tuple of two called federated_aggregates that differ in
        parameter names and value type. It then checks three things: both
        `dedupe_and_merge_tuple_intrinsics` and `merge_tuple_intrinsics`
        report a modification, and both produce the same formatted
        representation.
        """
        aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
        int_aggregate = test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c',
            value_type=tf.int32)
        # Different parameter names and value type: these two calls do not
        # compare as equal.
        float_aggregate = test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='x',
            merge_parameter_name='y',
            report_parameter_name='z',
            value_type=tf.float32)
        comp = building_blocks.Tuple((int_aggregate, float_aggregate))

        deduped_result = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
            comp, aggregate_uri)
        merged_result = transformations.merge_tuple_intrinsics(
            comp, aggregate_uri)
        deduped_comp, deduped_modified = deduped_result
        merged_comp, merged_modified = merged_result

        self.assertTrue(deduped_modified)
        self.assertTrue(merged_modified)
        self.assertEqual(deduped_comp.formatted_representation(),
                         merged_comp.formatted_representation())
コード例 #2
0
    def test_returns_comps_with_federated_aggregate_no_unbound_references(
            self):
        """Splitting on federated_aggregate yields two aggregate-free lambdas.

        Wraps a tuple holding the same called aggregate twice in a lambda and
        splits it with `force_align_and_split_by_intrinsic`. The input lambda
        must contain at least one aggregate call. Neither returned lambda may
        contain one. `before` keeps the input's parameter type, and `after`
        keeps the input's result type.
        """
        aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
        called_aggregate = test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
        comp = building_blocks.Lambda(
            'd', tf.int32,
            building_blocks.Tuple([called_aggregate, called_aggregate]))

        before, after = mapreduce_transformations.force_align_and_split_by_intrinsic(
            comp, aggregate_uri)

        def _is_aggregate(node):
            return building_block_analysis.is_called_intrinsic(
                node, aggregate_uri)

        def _num_aggregates(tree):
            # Number of called federated_aggregate nodes inside `tree`.
            return tree_analysis.count(tree, _is_aggregate)

        self.assertIsInstance(comp, building_blocks.Lambda)
        self.assertGreater(_num_aggregates(comp), 0)
        self.assertIsInstance(before, building_blocks.Lambda)
        self.assertEqual(_num_aggregates(before), 0)
        self.assertEqual(before.parameter_type, comp.parameter_type)
        self.assertIsInstance(after, building_blocks.Lambda)
        self.assertEqual(_num_aggregates(after), 0)
        self.assertEqual(after.result.type_signature,
                         comp.result.type_signature)
コード例 #3
0
  def test_returns_string_for_federated_aggregate(self):
    """Checks the string representations of a called federated_aggregate.

    The dummy aggregate uses accumulate/merge/report parameter names 'a', 'b'
    and 'c'. Its compact, formatted and structural representations are each
    compared verbatim against a golden string.
    """
    comp = test_utils.create_dummy_called_federated_aggregate(
        accumulate_parameter_name='a',
        merge_parameter_name='b',
        report_parameter_name='c')

    # Single-line form.
    self.assertEqual(
        comp.compact_representation(),
        'federated_aggregate(<data,data,(a -> data),(b -> data),(c -> data)>)')
    # Directive for the pyformat tool so the multi-line golden strings below
    # keep their hand-aligned layout.
    # NOTE(review): no matching `# pyformat: enable` is visible in this
    # snippet; confirm one follows in the full file.
    # pyformat: disable
    # Multi-line indented form: one tuple element per line.
    self.assertEqual(
        comp.formatted_representation(),
        'federated_aggregate(<\n'
        '  data,\n'
        '  data,\n'
        '  (a -> data),\n'
        '  (b -> data),\n'
        '  (c -> data)\n'
        '>)'
    )
    # ASCII tree form; column alignment in these strings is significant.
    self.assertEqual(
        comp.structural_representation(),
        '                    Call\n'
        '                   /    \\\n'
        'federated_aggregate      Struct\n'
        '                         |\n'
        '                         [data, data, Lambda(a), Lambda(b), Lambda(c)]\n'
        '                                      |          |          |\n'
        '                                      data       data       data'
    )
コード例 #4
0
 def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
   """Broadcasting an aggregate's result fails with a message naming 'acc_param'.

   The raised ValueError's message must match the regex 'acc_param'.
   NOTE(review): presumably this is the dummy aggregate's default accumulate
   parameter name; confirm in test_utils.
   """
   called_aggregate = test_utils.create_dummy_called_federated_aggregate()
   broadcast_of_aggregate = (
       building_block_factory.create_federated_broadcast(called_aggregate))
   with self.assertRaisesRegex(ValueError, 'acc_param'):
     tree_analysis.check_broadcast_not_dependent_on_aggregate(
         broadcast_of_aggregate)
コード例 #5
0
 def test_finds_broadcast_dependent_on_aggregate(self):
   """A broadcast fed directly by an aggregate is rejected with ValueError."""
   source = test_utils.create_dummy_called_federated_aggregate()
   comp = building_block_factory.create_federated_broadcast(source)
   with self.assertRaises(ValueError):
     tree_analysis.check_broadcast_not_dependent_on_aggregate(comp)
コード例 #6
0
  def test_returns_trees_with_one_federated_secure_sum_and_two_federated_aggregates(
      self):
    """Splits a lambda holding one secure sum and two aggregates.

    Both lambdas returned by `force_align_and_split_by_intrinsics` must be
    free of calls to federated_aggregate and federated_secure_sum.
    """
    uris = [
        intrinsic_defs.FEDERATED_AGGREGATE.uri,
        intrinsic_defs.FEDERATED_SECURE_SUM.uri,
    ]
    aggregate = compiler_test_utils.create_dummy_called_federated_aggregate(
        accumulate_parameter_name='a',
        merge_parameter_name='b',
        report_parameter_name='c')
    secure_sum = compiler_test_utils.create_dummy_called_federated_secure_sum()
    comp = building_blocks.Lambda(
        'd', tf.int32,
        building_blocks.Struct([secure_sum, aggregate, aggregate]))

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, uris)

    # Check `before` fully, then `after`, matching the original assertion
    # order.
    for half in (before, after):
      self.assertIsInstance(half, building_blocks.Lambda)
      self.assertFalse(tree_analysis.contains_called_intrinsic(half, uris))
コード例 #7
0
  def test_raises_value_error_for_expected_uri(self):
    """Splitting raises ValueError when a requested intrinsic is not called.

    The computation only calls federated_aggregate, but the requested URIs
    also include federated_secure_sum.
    NOTE(review): presumably the missing secure-sum call is what triggers the
    error; confirm in force_align_and_split_by_intrinsics.
    """
    requested_uris = [
        intrinsic_defs.FEDERATED_AGGREGATE.uri,
        intrinsic_defs.FEDERATED_SECURE_SUM.uri,
    ]
    only_aggregate = building_blocks.Struct([
        compiler_test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
    ])
    comp = building_blocks.Lambda('d', tf.int32, only_aggregate)

    with self.assertRaises(ValueError):
      transformations.force_align_and_split_by_intrinsics(
          comp, requested_uris)
コード例 #8
0
    tuple_2 = building_blocks.Struct([data_2, data_2])
    self.assertTrue(tree_analysis.trees_equal(tuple_1, tuple_2))

  def test_returns_true_for_identical_graphs_with_nans(self):
    """Two separately built `_create_tensorflow_graph_with_nan` graphs are equal."""
    left, right = [_create_tensorflow_graph_with_nan() for _ in range(2)]
    self.assertTrue(tree_analysis.trees_equal(left, right))


# Struct of called intrinsics that are not aggregations: a federated_broadcast
# and a federated_value placed at CLIENTS. The (None, ...) pairs presumably
# leave both elements unnamed.
non_aggregation_intrinsics = building_blocks.Struct([
    (None, test_utils.create_dummy_called_federated_broadcast()),
    (None, test_utils.create_dummy_called_federated_value(placements.CLIENTS))
])

# The empty struct type, used as the value type of the "trivial" aggregation
# calls below.
unit = computation_types.StructType([])
# Called aggregation intrinsics over `unit`. They are module-level so that the
# parameterized test class that follows can reference them by name.
trivial_aggregate = test_utils.create_dummy_called_federated_aggregate(
    value_type=unit)
trivial_collect = test_utils.create_dummy_called_federated_collect(unit)
trivial_mean = test_utils.create_dummy_called_federated_mean(unit)
trivial_sum = test_utils.create_dummy_called_federated_sum(unit)
# TODO(b/120439632) Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
trivial_secure_sum = test_utils.create_dummy_called_federated_secure_sum(unit)


class ContainsAggregationShared(parameterized.TestCase):

  @parameterized.named_parameters([
      ('non_aggregation_intrinsics', non_aggregation_intrinsics),
      ('trivial_aggregate', trivial_aggregate),
      ('trivial_collect', trivial_collect),
      ('trivial_mean', trivial_mean),
コード例 #9
0
    def test_constructs_aggregate_of_tuple_with_one_element(self):
        """Dedupes two identical aggregates into one single-element aggregate.

        The same called federated_aggregate appears twice in a tuple. After
        `dedupe_and_merge_tuple_intrinsics`, the transform must report a
        modification and exactly one federated_aggregate call must remain.
        That call's result member type must have length 1, and the formatted
        output must match the golden string below verbatim.
        """
        called_intrinsic = test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
        # The identical call appears twice, so there is a duplicate to remove.
        calls = building_blocks.Tuple((called_intrinsic, called_intrinsic))
        comp = calls

        transformed_comp, modified = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
            comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)

        # Every called federated_aggregate found in the transformed tree.
        federated_agg = []

        def _find_federated_aggregate(comp):
            # Postorder visitor: records aggregate calls and returns each node
            # unchanged. The False flag presumably reports "not modified" to
            # transform_postorder.
            if building_block_analysis.is_called_intrinsic(
                    comp, intrinsic_defs.FEDERATED_AGGREGATE.uri):
                federated_agg.append(comp)
            return comp, False

        transformation_utils.transform_postorder(transformed_comp,
                                                 _find_federated_aggregate)
        self.assertTrue(modified)
        self.assertLen(federated_agg, 1)
        self.assertLen(federated_agg[0].type_signature.member, 1)
        # Golden formatted output. Any change to variable naming or layout in
        # the transformation will show up here.
        self.assertEqual(
            transformed_comp.formatted_representation(), '(_var1 -> <\n'
            '  _var1[0],\n'
            '  _var1[0]\n'
            '>)((x -> <\n'
            '  x[0]\n'
            '>)((let\n'
            '  value=federated_aggregate(<\n'
            '    federated_map(<\n'
            '      (arg -> <\n'
            '        arg\n'
            '      >),\n'
            '      <\n'
            '        data\n'
            '      >[0]\n'
            '    >),\n'
            '    <\n'
            '      data\n'
            '    >,\n'
            '    (let\n'
            '      _var1=<\n'
            '        (a -> data)\n'
            '      >\n'
            '     in (_var2 -> <\n'
            '      _var1[0](<\n'
            '        <\n'
            '          _var2[0][0],\n'
            '          _var2[1][0]\n'
            '        >\n'
            '      >[0])\n'
            '    >)),\n'
            '    (let\n'
            '      _var3=<\n'
            '        (b -> data)\n'
            '      >\n'
            '     in (_var4 -> <\n'
            '      _var3[0](<\n'
            '        <\n'
            '          _var4[0][0],\n'
            '          _var4[1][0]\n'
            '        >\n'
            '      >[0])\n'
            '    >)),\n'
            '    (let\n'
            '      _var5=<\n'
            '        (c -> data)\n'
            '      >\n'
            '     in (_var6 -> <\n'
            '      _var5[0](_var6[0])\n'
            '    >))\n'
            '  >)\n'
            ' in <\n'
            '  federated_apply(<\n'
            '    (arg -> arg[0]),\n'
            '    value\n'
            '  >)\n'
            '>)))')