Пример #1
0
    def test_constructs_broadcast_of_tuple_with_one_element(self):
        """Checks merging of a tuple that holds the same broadcast call twice.

        After `dedupe_and_merge_tuple_intrinsics` runs, the test expects the
        transformation to report a modification, the input to keep its
        original compact representation, the transformed tree to hold exactly
        one called federated broadcast whose member type has length one, and
        the formatted output to match the expected text exactly.
        """
        broadcast_call = test_utils.create_dummy_called_federated_broadcast()
        original = building_blocks.Tuple((broadcast_call, broadcast_call))
        broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri

        result = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
            original, broadcast_uri)
        transformed_comp, modified = result

        found_broadcasts = []

        def _collect_broadcasts(comp):
            # Pure visitor: records matches and always reports "unchanged".
            is_broadcast = building_block_analysis.is_called_intrinsic(
                comp, broadcast_uri)
            if is_broadcast:
                found_broadcasts.append(comp)
            return comp, False

        transformation_utils.transform_postorder(transformed_comp,
                                                 _collect_broadcasts)

        expected_formatted = (
            '(_var1 -> <\n'
            '  _var1[0],\n'
            '  _var1[0]\n'
            '>)((x -> <\n'
            '  x[0]\n'
            '>)((let\n'
            '  value=federated_broadcast(federated_apply(<\n'
            '    (arg -> <\n'
            '      arg\n'
            '    >),\n'
            '    <\n'
            '      data\n'
            '    >[0]\n'
            '  >))\n'
            ' in <\n'
            '  federated_map_all_equal(<\n'
            '    (arg -> arg[0]),\n'
            '    value\n'
            '  >)\n'
            '>)))'
        )

        self.assertTrue(modified)
        # The input computation itself must still read as two broadcasts.
        self.assertEqual(
            original.compact_representation(),
            '<federated_broadcast(data),federated_broadcast(data)>')

        self.assertLen(found_broadcasts, 1)
        merged_broadcast = found_broadcasts[0]
        self.assertLen(merged_broadcast.type_signature.member, 1)
        self.assertEqual(transformed_comp.formatted_representation(),
                         expected_formatted)
Пример #2
0
  def test_returns_trees_with_two_federated_broadcast(self):
    """Splits a lambda whose result holds the same broadcast call twice.

    Both values returned by `force_align_and_split_by_intrinsics` are expected
    to be lambdas that contain no called federated broadcast.
    """
    broadcast_call = (
        compiler_test_utils.create_dummy_called_federated_broadcast())
    body = building_blocks.Struct([broadcast_call, broadcast_call])
    comp = building_blocks.Lambda('a', tf.int32, body)
    uri = [intrinsic_defs.FEDERATED_BROADCAST.uri]

    split = transformations.force_align_and_split_by_intrinsics(comp, uri)
    before, after = split

    # Same pair of checks for each half, in before/after order.
    for tree in (before, after):
      self.assertIsInstance(tree, building_blocks.Lambda)
      self.assertFalse(tree_analysis.contains_called_intrinsic(tree, uri))
Пример #3
0
 def test_does_not_find_aggregate_dependent_on_broadcast(self):
   """Runs the dependency check on an aggregate whose value is a broadcast.

   The test makes no assertions of its own; it passes as long as
   `check_broadcast_not_dependent_on_aggregate` returns without raising.
   """
   broadcast = test_utils.create_dummy_called_federated_broadcast()
   member = broadcast.type_signature.member

   zero = building_blocks.Data('zero', member)
   # Each functional argument is a lambda returning an opaque data placeholder.
   accumulate = building_blocks.Lambda(
       'accumulate_parameter', [member, member],
       building_blocks.Data('accumulate_result', member))
   merge = building_blocks.Lambda(
       'merge_parameter', [member, member],
       building_blocks.Data('merge_result', member))
   report = building_blocks.Lambda(
       'report_parameter', member,
       building_blocks.Data('report_result', member))

   aggregate = building_block_factory.create_federated_aggregate(
       broadcast, zero, accumulate, merge, report)
   tree_analysis.check_broadcast_not_dependent_on_aggregate(aggregate)
Пример #4
0
  def test_handles_federated_broadcasts_nested_in_tuple(self):
    """Splits a lambda with a broadcast hidden inside a tuple it selects from.

    The outer broadcast is applied to element 0 of a tuple whose second
    element is another broadcast call. Both values returned by
    `force_align_and_split_by_intrinsics` are expected to be lambdas that
    contain no called federated broadcast.
    """
    inner_broadcast = (
        compiler_test_utils.create_dummy_called_federated_broadcast())
    server_int_type = computation_types.FederatedType(
        computation_types.TensorType(tf.int32), placements.SERVER)
    server_data = building_blocks.Data('a', server_int_type)
    packed = building_blocks.Tuple([server_data, inner_broadcast])

    first_element = building_blocks.Selection(packed, index=0)
    outer_broadcast = building_block_factory.create_federated_broadcast(
        first_element)
    comp = building_blocks.Lambda('a', tf.int32, outer_broadcast)
    uri = [intrinsic_defs.FEDERATED_BROADCAST.uri]

    split = transformations.force_align_and_split_by_intrinsics(comp, uri)
    before, after = split

    # Same pair of checks for each half, in before/after order.
    for tree in (before, after):
      self.assertIsInstance(tree, building_blocks.Lambda)
      self.assertFalse(tree_analysis.contains_called_intrinsic(tree, uri))
Пример #5
0
 def test_returns_false_with_unmatched_called_intrinsic(self):
     """A broadcast call is not reported when searching for the map URI."""
     broadcast = test_utils.create_dummy_called_federated_broadcast()
     found = tree_analysis.contains_called_intrinsic(
         broadcast, intrinsic_defs.FEDERATED_MAP.uri)
     self.assertFalse(found)
Пример #6
0
 def test_returns_true_with_matching_uri(self):
     """A broadcast call is reported when searching for the broadcast URI."""
     broadcast = test_utils.create_dummy_called_federated_broadcast()
     found = tree_analysis.contains_called_intrinsic(
         broadcast, intrinsic_defs.FEDERATED_BROADCAST.uri)
     self.assertTrue(found)
Пример #7
0
 def test_returns_true_with_none_uri(self):
     """A broadcast call is reported when no URI argument is supplied."""
     broadcast = test_utils.create_dummy_called_federated_broadcast()
     found = tree_analysis.contains_called_intrinsic(broadcast)
     self.assertTrue(found)
Пример #8
0
  def test_returns_true_for_tuples(self):
    """Two separately built structs of the same int32 data compare equal."""

    def _pair_of_data():
      # Fresh objects on every call so the comparison is structural.
      element = building_blocks.Data('data', tf.int32)
      return building_blocks.Struct([element, element])

    left = _pair_of_data()
    right = _pair_of_data()
    self.assertTrue(tree_analysis.trees_equal(left, right))

  def test_returns_true_for_identical_graphs_with_nans(self):
    """Two results of `_create_tensorflow_graph_with_nan()` compare equal."""
    first, second = (_create_tensorflow_graph_with_nan() for _ in range(2))
    self.assertTrue(tree_analysis.trees_equal(first, second))


# Module-level fixtures built once at import time.
# NOTE(review): presumably consumed by the parameterized test classes that
# follow in this file -- confirm against their usage.

# Struct of two unnamed elements: a dummy called federated broadcast and a
# dummy called federated value placed at CLIENTS.
non_aggregation_intrinsics = building_blocks.Struct([
    (None, test_utils.create_dummy_called_federated_broadcast()),
    (None, test_utils.create_dummy_called_federated_value(placements.CLIENTS))
])

# Empty struct type, used as the value type for every `trivial_*` call below.
unit = computation_types.StructType([])
trivial_aggregate = test_utils.create_dummy_called_federated_aggregate(
    value_type=unit)
trivial_collect = test_utils.create_dummy_called_federated_collect(unit)
trivial_mean = test_utils.create_dummy_called_federated_mean(unit)
trivial_sum = test_utils.create_dummy_called_federated_sum(unit)
# TODO(b/120439632) Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
trivial_secure_sum = test_utils.create_dummy_called_federated_secure_sum(unit)


class ContainsAggregationShared(parameterized.TestCase):