Example #1
0
  def test_type_properties_unweighted(self, value_type):
    """Checks the type signatures of `UnweightedMeanFactory.create`.

    The created process is expected to have an empty (`()`) server-placed
    state, to accept the value placed at clients, to return the value type
    placed at the server, and to report an empty `mean_value` measurement.

    Args:
      value_type: A type spec accepted by `computation_types.to_type`;
        presumably supplied by a parameterization decorator outside this view.
    """
    value_type = computation_types.to_type(value_type)

    factory_ = mean_factory.UnweightedMeanFactory()
    self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
    process = factory_.create(value_type)

    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    param_value_type = computation_types.at_clients(value_type)
    result_value_type = computation_types.at_server(value_type)

    expected_state_type = computation_types.at_server(())
    expected_measurements_type = computation_types.at_server(
        collections.OrderedDict(mean_value=()))

    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    # Put both signatures in the failure message; a bare assertTrue would
    # only report "False is not true".
    actual_initialize_type = process.initialize.type_signature
    self.assertTrue(
        actual_initialize_type.is_equivalent_to(expected_initialize_type),
        msg='initialize type {} is not equivalent to expected {}'.format(
            actual_initialize_type, expected_initialize_type))

    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=expected_state_type, value=param_value_type),
        result=measured_process.MeasuredProcessOutput(
            expected_state_type, result_value_type, expected_measurements_type))
    actual_next_type = process.next.type_signature
    self.assertTrue(
        actual_next_type.is_equivalent_to(expected_next_type),
        msg='next type {} is not equivalent to expected {}'.format(
            actual_next_type, expected_next_type))
Example #2
0
  def test_incorrect_create_type_raises(self, wrong_type):
    """Checks that `create` raises TypeError for an invalid type argument.

    The weighted factory must reject the bad type in either the value or the
    weight position; the unweighted factory must reject it as the value type.

    Args:
      wrong_type: An invalid type spec; presumably supplied by a
        parameterization decorator outside this view.
    """
    weighted_factory = mean_factory.MeanFactory()
    float_type = computation_types.to_type(tf.float32)
    # First the bad type as the value, then as the weight.
    for value_arg, weight_arg in ((wrong_type, float_type),
                                  (float_type, wrong_type)):
      with self.assertRaises(TypeError):
        weighted_factory.create(value_arg, weight_arg)

    unweighted_factory = mean_factory.UnweightedMeanFactory()
    with self.assertRaises(TypeError):
      unweighted_factory.create(wrong_type)
Example #3
0
  def test_structure_value_unweighted(self):
    """Averages a nested structure of floats element-wise over three clients."""
    unweighted = mean_factory.UnweightedMeanFactory()
    struct_type = computation_types.to_type(_test_struct_type)
    process = unweighted.create(struct_type)
    empty_state = ()
    empty_measurements = collections.OrderedDict(mean_value=())

    initial_state = process.initialize()
    self.assertAllEqual(empty_state, initial_state)

    # Each client contributes ((a, b), c). The expected result is the
    # element-wise mean: ((6/3, 7/3), 12/3).
    per_client_values = [
        ((1.0, 2.0), 3.0),
        ((2.0, 5.0), 4.0),
        ((3.0, 0.0), 5.0),
    ]
    step = process.next(initial_state, per_client_values)

    self.assertAllEqual(empty_state, step.state)
    self.assertAllClose(((2.0, 7 / 3), 4.0), step.result)
    self.assertEqual(empty_measurements, step.measurements)
Example #4
0
  def test_inner_value_sum_factory_unweighted(self):
    """Checks that a custom `value_sum_factory` is used to sum the values."""
    inner_sum_factory = aggregators_test_utils.SumPlusOneFactory()
    unweighted = mean_factory.UnweightedMeanFactory(
        value_sum_factory=inner_sum_factory)
    float_type = computation_types.to_type(tf.float32)
    process = unweighted.create(float_type)

    initial_state = process.initialize()
    self.assertAllEqual(0, initial_state)

    # The plain sum of these values is 6.0; the inner factory reports 7.0,
    # so the mean over three clients is 7 / 3.
    step = process.next(initial_state, [1.0, 2.0, 3.0])
    self.assertAllEqual(1, step.state)
    self.assertAllClose(7 / 3, step.result)
    expected_measurements = collections.OrderedDict(mean_value=M_CONST)
    self.assertEqual(expected_measurements, step.measurements)
Example #5
0
  def test_scalar_value_unweighted(self):
    """Averages scalar float32 values over three clients."""
    process = mean_factory.UnweightedMeanFactory().create(
        computation_types.to_type(tf.float32))
    empty_state = ()

    initial_state = process.initialize()
    self.assertAllEqual(empty_state, initial_state)

    # mean(1.0, 2.0, 3.0) == 2.0
    step = process.next(initial_state, [1.0, 2.0, 3.0])
    self.assertAllClose(2.0, step.result)

    self.assertAllEqual(empty_state, step.state)
    self.assertEqual(
        collections.OrderedDict(mean_value=()), step.measurements)