state = iterative_process.initialize() # SGD keeps track of a single scalar for the number of iterations. self.assertAllEqual(state.optimizer_state, [0]) self.assertAllClose(list(state.model.trainable), [np.zeros((2, 1)), 0.0]) self.assertAllClose(list(state.model.non_trainable), [0.0]) self.assertEqual(state.delta_aggregate_state, 0) self.assertEqual(state.model_broadcast_state, 0) state, outputs = iterative_process.next(state, federated_ds) self.assertAllClose(list(state.model.trainable), [-np.ones((2, 1)), -1.0 * learning_rate]) self.assertAllClose(list(state.model.non_trainable), [0.0]) self.assertAllEqual(state.optimizer_state, [1]) self.assertEqual(state.delta_aggregate_state, 1) self.assertEqual(state.model_broadcast_state, 1) expected_outputs = collections.OrderedDict( broadcast=3.0, aggregation=collections.OrderedDict(num_clients=3), train={ 'loss': 15.25, 'num_examples': 6, }) self.assertEqual(str(expected_outputs), str(outputs)) if __name__ == '__main__': execution_contexts.set_local_execution_context() test_utils.main()
def test_inner_value_and_weight_sum_factory(self):
  """Tests MeanFactory with one custom inner factory used for both sums.

  A single `SumPlusOneFactory` instance is passed as both
  `value_sum_factory` and `weight_sum_factory`. The test checks three things:
  - Each inner process gets its own entry in the state, and both entries go
    from 0 to 1 after one `next` call.
  - Each inner process reports its own entry in the measurements.
  - The mean result equals the ratio of the two inner sums.
  """
  inner_sum = aggregators_test_utils.SumPlusOneFactory()
  factory = mean_factory.MeanFactory(
      value_sum_factory=inner_sum, weight_sum_factory=inner_sum)
  process = factory.create(computation_types.to_type(tf.float32))

  initial_state = process.initialize()
  self.assertAllEqual(
      collections.OrderedDict(value_sum_process=0, weight_sum_process=0),
      initial_state)

  values, weights = [1.0, 2.0, 3.0], [1.0, 1.0, 1.0]
  # The inner sums evaluate to 7.0 for the weighted values and 4.0 for the
  # weights, so the expected mean is 7 / 4.
  round_output = process.next(initial_state, values, weights)

  expected_state = collections.OrderedDict(
      value_sum_process=1, weight_sum_process=1)
  self.assertAllEqual(expected_state, round_output.state)
  self.assertAllClose(7 / 4, round_output.result)
  expected_measurements = collections.OrderedDict(
      value_sum_process=M_CONST, weight_sum_process=M_CONST)
  self.assertEqual(expected_measurements, round_output.measurements)


if __name__ == '__main__':
  execution_contexts.set_local_execution_context()
  common_libs_test_utils.main()