Example #1
0
    def test_init_raises_type_error_with_datasets(self, datasets):
        """Checks that the iterator constructor rejects bad `datasets`.

        `datasets` is presumably supplied by a parameterized decorator outside
        this view; each case is expected to make the constructor raise
        `TypeError` when paired with a valid federated sequence type.
        """
        int_sequence_type = computation_types.SequenceType(tf.int32)
        clients_type = computation_types.FederatedType(
            int_sequence_type, placements.CLIENTS)

        # Callable form of assertRaises: only the constructor call is checked.
        self.assertRaises(
            TypeError,
            native_platform.DatasetDataSourceIterator,
            datasets=datasets,
            federated_type=clients_type)
Example #2
0
    def test_select_raises_value_error(self, number_of_clients):
        """Checks that `select` rejects an invalid `number_of_clients`.

        The iterator is built over three references to the same small int
        dataset. `number_of_clients` is presumably supplied by a parameterized
        decorator outside this view, with values `select` must refuse with
        `ValueError`.
        """
        client_dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
        datasets = [client_dataset, client_dataset, client_dataset]
        int_sequence_type = computation_types.SequenceType(tf.int32)
        clients_type = computation_types.FederatedType(
            int_sequence_type, placements.CLIENTS)
        iterator = native_platform.DatasetDataSourceIterator(
            datasets=datasets, federated_type=clients_type)

        self.assertRaises(ValueError, iterator.select, number_of_clients)
Example #3
0
    def test_init_does_not_raise_type_error(self):
        """Checks that well-formed arguments construct an iterator cleanly.

        Three references to one int dataset are paired with a matching
        `SequenceType(tf.int32)` placed at CLIENTS; a `TypeError` from the
        constructor is reported as an explicit test failure.
        """
        client_dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
        datasets = [client_dataset, client_dataset, client_dataset]
        int_sequence_type = computation_types.SequenceType(tf.int32)
        clients_type = computation_types.FederatedType(
            int_sequence_type, placements.CLIENTS)

        try:
            native_platform.DatasetDataSourceIterator(
                datasets=datasets, federated_type=clients_type)
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')
Example #4
0
    def test_init_raises_value_error_with_datasets_different_types(self):
        """Checks that datasets with mismatched element types are rejected.

        One dataset yields ints and the other yields strings; constructing the
        iterator from both is expected to raise `ValueError`.
        """
        int_dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
        str_dataset = tf.data.Dataset.from_tensor_slices(['a', 'b', 'c'])
        datasets = [int_dataset, str_dataset]
        int_sequence_type = computation_types.SequenceType(tf.int32)
        clients_type = computation_types.FederatedType(
            int_sequence_type, placements.CLIENTS)

        self.assertRaises(
            ValueError,
            native_platform.DatasetDataSourceIterator,
            datasets=datasets,
            federated_type=clients_type)
Example #5
0
    def test_select_returns_data(self, number_of_clients):
        """Checks that `select` returns the requested number of datasets.

        Every underlying dataset yields `[1, 2, 3]`, so each selected dataset
        must yield exactly those elements, in that order.
        `number_of_clients` is presumably supplied by a parameterized decorator
        outside this view.
        """
        datasets = [tf.data.Dataset.from_tensor_slices([1, 2, 3])] * 3
        federated_type = computation_types.FederatedType(
            computation_types.SequenceType(tf.int32), placements.CLIENTS)
        iterator = native_platform.DatasetDataSourceIterator(
            datasets=datasets, federated_type=federated_type)

        data = iterator.select(number_of_clients)

        self.assertLen(data, number_of_clients)
        # Compare the materialized element values, not the Dataset objects.
        # Iterating a Dataset yields eager tensors, which are unhashable.
        # `assertSameElements` also ignores order and multiplicity, so it
        # cannot pin the dataset contents.
        expected_elements = [1, 2, 3]
        for actual_dataset in data:
            self.assertEqual(
                list(actual_dataset.as_numpy_iterator()), expected_elements)
Example #6
0
    def test_init_raises_type_error_with_federated_type(self, federated_type):
        """Checks that the iterator constructor rejects a bad `federated_type`.

        The datasets are valid (three references to one int dataset).
        `federated_type` is presumably supplied by a parameterized decorator
        outside this view, and each case is expected to raise `TypeError`.
        """
        client_dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
        datasets = [client_dataset, client_dataset, client_dataset]

        self.assertRaises(
            TypeError,
            native_platform.DatasetDataSourceIterator,
            datasets=datasets,
            federated_type=federated_type)