Esempio n. 1
0
    def test_simple_sum(self):
        """EncodedSumFactory with the identity encoder sums float32 client values.

        The process is run for three rounds, feeding each round's state into
        the next, and every round must produce the same sum with no
        measurements.
        """
        sum_factory = encoded.EncodedSumFactory(_identity_encoder_fn)
        scalar_type = computation_types.to_type(tf.float32)
        process = sum_factory.create(scalar_type)

        state = process.initialize()

        values = [1.0, 2.0, 3.0]
        # Thread the server state through several consecutive rounds.
        for _ in range(3):
            round_output = process.next(state, values)
            self.assertAllClose(6.0, round_output.result)
            self.assertEqual((), round_output.measurements)
            state = round_output.state
Esempio n. 2
0
    def test_structure_sum(self):
        """EncodedSumFactory sums a nested (vector, scalar) structure element-wise.

        Each client contributes a pair of a length-2 float32 vector and a
        float32 scalar. The aggregated result must equal the element-wise sum
        on every one of three consecutive rounds.
        """
        sum_factory = encoded.EncodedSumFactory(_identity_encoder_fn)
        value_type = computation_types.to_type(((tf.float32, (2,)), tf.float32))
        process = sum_factory.create(value_type)

        state = process.initialize()

        client_data = [
            [[1.0, -1.0], 2],
            [[2.0, 4.0], 3],
            [[3.0, 5.0], 5],
        ]
        expected_sum = [[6.0, 8.0], 10]
        for _ in range(3):
            round_output = process.next(state, client_data)
            self.assertAllClose(expected_sum, round_output.result)
            self.assertEqual((), round_output.measurements)
            state = round_output.state
Esempio n. 3
0
    def test_type_properties(self, encoder_fn):
        """Checks the factory/process classes and the process type signatures.

        Args:
          encoder_fn: Encoder-building callable handed to EncodedSumFactory;
            presumably supplied by a parameterization decorator outside this
            view.
        """
        sum_factory = encoded.EncodedSumFactory(encoder_fn)
        self.assertIsInstance(sum_factory, factory.UnweightedAggregationFactory)

        process = sum_factory.create(_test_struct_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        init_signature = process.initialize.type_signature
        self.assertIsNone(init_signature.parameter)
        server_state_type = init_signature.result
        # The state holds one element per aggregated tensor.
        self.assertLen(server_state_type.member, 2)
        self.assertEqual(placements.SERVER, server_state_type.placement)

        client_value_type = computation_types.at_clients(_test_struct_type)
        server_result_type = computation_types.at_server(_test_struct_type)
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=server_state_type, value=client_value_type),
            result=measured_process.MeasuredProcessOutput(
                state=server_state_type,
                result=server_result_type,
                measurements=computation_types.at_server(())))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
Esempio n. 4
0
 def test_encoder_fn_not_callable_raises(self):
     """Passing an already-built encoder (not a callable) raises TypeError."""
     identity_encoder = te.encoders.identity()
     gather_encoder = te.encoders.as_gather_encoder(
         identity_encoder, tf.TensorSpec((), tf.float32))
     with self.assertRaises(TypeError):
         encoded.EncodedSumFactory(gather_encoder)