def test_get_size_info(self):
    """Verifies the broadcast/aggregate size info recorded for one run.

    Runs ``temperature_sensor_example.mean_over_threshold`` over identical
    client datasets. It then reads the size info exposed by the executor
    factory of the current context on the TFF context stack and compares
    it against the expected per-client traffic.
    """
    num_clients = 10

    def to_float(value):
        return tf.cast(value, tf.float32)

    # Every client shares the same dataset object, as in a list repetition.
    client_dataset = tf.data.Dataset.range(10).map(to_float)
    temperatures = [client_dataset] * num_clients
    threshold = 15.0
    temperature_sensor_example.mean_over_threshold(temperatures, threshold)

    current_context = tff.framework.get_context_stack().current
    size_info = current_context.executor_factory.get_size_info()

    # Each client receives a tf.float32 and uploads two tf.float32 values.
    float_bits = 32
    clients_key = (('CLIENTS', num_clients),)
    one_float = [1, tf.float32]

    expected_broadcast_history = {clients_key: [one_float] * num_clients}
    expected_aggregate_history = {
        clients_key: [one_float] * (2 * num_clients),
    }
    expected_broadcast_bits = [num_clients * float_bits]
    expected_aggregate_bits = [2 * num_clients * float_bits]

    self.assertEqual(size_info.broadcast_history, expected_broadcast_history)
    self.assertEqual(size_info.aggregate_history, expected_aggregate_history)
    self.assertEqual(size_info.broadcast_bits, expected_broadcast_bits)
    self.assertEqual(size_info.aggregate_bits, expected_aggregate_bits)
def test_temperature_sensor_example(self):
    """Checks ``mean_over_threshold`` on two ``tf.data`` client datasets.

    The clients hold 20 and 30 float readings (0, 1, 2, ...), and the
    threshold is 10.0.
    """

    def to_float(value):
        return tf.cast(value, tf.float32)

    temperatures = [
        tf.data.Dataset.range(num_readings).map(to_float)
        for num_readings in (20, 30)
    ]
    threshold = 10.0

    result = temperature_sensor_example.mean_over_threshold(
        temperatures, threshold)

    self.assertEqual(result, 15.0)
def test_temperature_sensor_example_with_clients_lists(self):
    """Checks ``mean_over_threshold`` when client data are plain lists.

    Three clients hold 10, 20 and 30 float readings (0.0, 1.0, ...), and
    the threshold is 10.0.
    """
    temperatures = [list(map(float, range(size))) for size in (10, 20, 30)]
    threshold = 10.0

    result = temperature_sensor_example.mean_over_threshold(
        temperatures, threshold)

    self.assertEqual(result, 12.5)