コード例 #1
0
ファイル: backend_test.py プロジェクト: xiaoral2/federated
    def test_get_size_info(self):
        """Checks the broadcast/aggregate size info recorded by the executor.

        Runs `mean_over_threshold` over ten identical client datasets, then
        reads the size info from the current context's executor factory and
        compares its histories and bit counts against the expected values.
        """
        num_clients = 10

        def to_float(value):
            return tf.cast(value, tf.float32)

        client_dataset = tf.data.Dataset.range(10).map(to_float)
        temperatures = [client_dataset] * num_clients
        threshold = 15.0

        # The return value is unused; the call is presumably made only so the
        # executor records the traffic inspected below.
        temperature_sensor_example.mean_over_threshold(temperatures, threshold)
        current_context = tff.framework.get_context_stack().current
        size_info = current_context.executor_factory.get_size_info()

        # Each client receives one tf.float32 and uploads two tf.float32 values.
        float_bits = 32
        clients_key = (('CLIENTS', num_clients),)
        scalar_entry = [1, tf.float32]

        expected_broadcast_history = {
            clients_key: [scalar_entry] * num_clients,
        }
        expected_aggregate_history = {
            clients_key: [scalar_entry] * (num_clients * 2),
        }
        expected_broadcast_bits = [float_bits * num_clients]
        expected_aggregate_bits = [float_bits * num_clients * 2]

        self.assertEqual(size_info.broadcast_history,
                         expected_broadcast_history)
        self.assertEqual(size_info.aggregate_history,
                         expected_aggregate_history)
        self.assertEqual(size_info.broadcast_bits, expected_broadcast_bits)
        self.assertEqual(size_info.aggregate_bits, expected_aggregate_bits)
コード例 #2
0
ファイル: backend_test.py プロジェクト: xiaoral2/federated
    def test_temperature_sensor_example(self):
        """Checks `mean_over_threshold` with two tf.data client datasets.

        The clients hold the readings 0..19 and 0..29 (cast to tf.float32).
        With a threshold of 10.0 the computation is expected to return 15.0.
        """
        def to_float(value):
            return tf.cast(value, tf.float32)

        temperatures = [
            tf.data.Dataset.range(size).map(to_float) for size in (20, 30)
        ]
        threshold = 10.0

        result = temperature_sensor_example.mean_over_threshold(
            temperatures, threshold)
        self.assertEqual(result, 15.)
コード例 #3
0
ファイル: backend_test.py プロジェクト: tensorflow/federated
    def test_temperature_sensor_example_with_clients_lists(self):
        """Checks `mean_over_threshold` when clients are plain Python lists.

        The three clients hold the float readings 0..9, 0..19 and 0..29.
        With a threshold of 10.0 the computation is expected to return 12.5.
        """
        temperatures = [list(map(float, range(n))) for n in (10, 20, 30)]
        threshold = 10.0

        result = temperature_sensor_example.mean_over_threshold(
            temperatures, threshold)
        self.assertEqual(result, 12.5)