示例#1
0
        state = iterative_process.initialize()
        # SGD keeps track of a single scalar for the number of iterations.
        self.assertAllEqual(state.optimizer_state, [0])
        self.assertAllClose(list(state.model.trainable),
                            [np.zeros((2, 1)), 0.0])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertEqual(state.delta_aggregate_state, 0)
        self.assertEqual(state.model_broadcast_state, 0)

        state, outputs = iterative_process.next(state, federated_ds)
        self.assertAllClose(list(state.model.trainable),
                            [-np.ones((2, 1)), -1.0 * learning_rate])
        self.assertAllClose(list(state.model.non_trainable), [0.0])
        self.assertAllEqual(state.optimizer_state, [1])
        self.assertEqual(state.delta_aggregate_state, 1)
        self.assertEqual(state.model_broadcast_state, 1)

        expected_outputs = collections.OrderedDict(
            broadcast=3.0,
            aggregation=collections.OrderedDict(num_clients=3),
            train={
                'loss': 15.25,
                'num_examples': 6,
            })
        self.assertEqual(str(expected_outputs), str(outputs))


if __name__ == '__main__':
    # Install the local execution context first; the test entry point below
    # is only invoked after the context has been set.
    execution_contexts.set_local_execution_context()
    # NOTE(review): presumably test_utils.main() discovers and runs the test
    # cases defined in this module — confirm against test_utils.
    test_utils.main()
示例#2
0
    def test_inner_value_and_weight_sum_factory(self):
        """Checks MeanFactory when both inner sums use SumPlusOneFactory.

        The same `aggregators_test_utils.SumPlusOneFactory` instance is passed
        as both `value_sum_factory` and `weight_sum_factory`. The test then
        verifies:
        - the initial state;
        - the state after one `next` call;
        - the resulting mean;
        - the reported measurements.
        """
        inner_sum = aggregators_test_utils.SumPlusOneFactory()
        factory = mean_factory.MeanFactory(
            value_sum_factory=inner_sum, weight_sum_factory=inner_sum)
        agg_process = factory.create(computation_types.to_type(tf.float32))

        initial_state = agg_process.initialize()
        self.assertAllEqual(
            collections.OrderedDict(value_sum_process=0, weight_sum_process=0),
            initial_state)

        values = [1.0, 2.0, 3.0]
        value_weights = [1.0] * 3
        # The weighted values are expected to sum to 7.0 (1 + 2 + 3, plus one)
        # and the weights to 4.0 (three ones, plus one), so the mean is 7 / 4.
        step = agg_process.next(initial_state, values, value_weights)

        expected_state = collections.OrderedDict(
            value_sum_process=1, weight_sum_process=1)
        self.assertAllEqual(expected_state, step.state)
        self.assertAllClose(7 / 4, step.result)

        expected_measurements = collections.OrderedDict(
            value_sum_process=M_CONST, weight_sum_process=M_CONST)
        self.assertEqual(expected_measurements, step.measurements)


if __name__ == '__main__':
    # Install the local execution context first; the test entry point below
    # is only invoked after the context has been set.
    execution_contexts.set_local_execution_context()
    # NOTE(review): presumably common_libs_test_utils.main() discovers and runs
    # the test cases defined in this module — confirm against that helper.
    common_libs_test_utils.main()