예제 #1
0
  def test_returns_trees_with_one_federated_secure_sum_and_two_federated_aggregates(
      self):
    """Splits a lambda whose result holds one secure sum and two aggregates.

    The lambda body is a struct of three called intrinsics. After calling
    `force_align_and_split_by_intrinsics`, both returned trees must be lambdas
    and neither may still contain a called intrinsic with a requested URI.
    """
    aggregate_call = (
        compiler_test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c'))
    secure_sum_call = (
        compiler_test_utils.create_dummy_called_federated_secure_sum())
    # The same aggregate call appears twice, matching the test name.
    body = building_blocks.Struct(
        [secure_sum_call, aggregate_call, aggregate_call])
    comp = building_blocks.Lambda('d', tf.int32, body)
    split_uris = [
        intrinsic_defs.FEDERATED_AGGREGATE.uri,
        intrinsic_defs.FEDERATED_SECURE_SUM.uri,
    ]

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, split_uris)

    # Identical expectations apply to each half, checked in before/after order.
    for tree in (before, after):
      self.assertIsInstance(tree, building_blocks.Lambda)
      self.assertFalse(
          tree_analysis.contains_called_intrinsic(tree, split_uris))
예제 #2
0
  def test_returns_trees_with_one_federated_secure_sum(self):
    """Splits a lambda whose result is a struct holding a single secure sum.

    After calling `force_align_and_split_by_intrinsics`, both returned trees
    must be lambdas and neither may still contain a called intrinsic with the
    secure-sum URI.
    """
    secure_sum_call = (
        compiler_test_utils.create_dummy_called_federated_secure_sum())
    body = building_blocks.Struct([secure_sum_call])
    comp = building_blocks.Lambda('a', tf.int32, body)
    split_uris = [intrinsic_defs.FEDERATED_SECURE_SUM.uri]

    before, after = transformations.force_align_and_split_by_intrinsics(
        comp, split_uris)

    # Identical expectations apply to each half, checked in before/after order.
    for tree in (before, after):
      self.assertIsInstance(tree, building_blocks.Lambda)
      self.assertFalse(
          tree_analysis.contains_called_intrinsic(tree, split_uris))
예제 #3
0
 def test_returns_str_on_nested_secure_aggregation(self):
   """Runs `assert_one_aggregation` on a dummy secure sum built from a pair.

   The dummy secure-sum call is created from `(tf.int32, tf.int32)` and then
   handed to the `assert_one_aggregation` helper on this test case.
   """
   nested_spec = (tf.int32, tf.int32)
   secure_sum_call = test_utils.create_dummy_called_federated_secure_sum(
       nested_spec)
   self.assert_one_aggregation(secure_sum_call)
예제 #4
0

# Module-level fixtures shared by the parameterized test cases below.

# Struct of two unnamed called intrinsics that are treated as the
# "non-aggregation" case: a federated broadcast and a federated value created
# with the CLIENTS placement.
non_aggregation_intrinsics = building_blocks.Struct([
    (None, test_utils.create_dummy_called_federated_broadcast()),
    (None, test_utils.create_dummy_called_federated_value(placements.CLIENTS))
])

# Empty struct type; it is passed to every dummy aggregation constructor below,
# which is why the resulting fixtures are named `trivial_*`.
unit = computation_types.StructType([])
# One dummy called intrinsic per aggregation kind, each built from `unit`.
trivial_aggregate = test_utils.create_dummy_called_federated_aggregate(
    value_type=unit)
trivial_collect = test_utils.create_dummy_called_federated_collect(unit)
trivial_mean = test_utils.create_dummy_called_federated_mean(unit)
trivial_sum = test_utils.create_dummy_called_federated_sum(unit)
# TODO(b/120439632) Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
trivial_secure_sum = test_utils.create_dummy_called_federated_secure_sum(unit)


class ContainsAggregationShared(parameterized.TestCase):

  @parameterized.named_parameters([
      ('non_aggregation_intrinsics', non_aggregation_intrinsics),
      ('trivial_aggregate', trivial_aggregate),
      ('trivial_collect', trivial_collect),
      ('trivial_mean', trivial_mean),
      ('trivial_sum', trivial_sum),
      # TODO(b/120439632) Enable once federated_mean accepts structured weight.
      # ('trivial_weighted_mean', trivial_weighted_mean),
      ('trivial_secure_sum', trivial_secure_sum),
  ])
  def test_returns_none(self, comp):