Example #1
    def test_dataset_shard_error_with_unsupported_dataset_format(self):
        """Tests whether the dataset_shard function raises an error when an unsupported
        dataset format is specified."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "__UNSUPPORTED_FORMAT__",
                "paths": self.dset_path,
            },
        }

        with self.assertRaises(ValueError):
            get_dataset_and_shards(config)
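
All of these tests reference `self.dset_path` and expect Ray to be running, but the surrounding fixture is not shown. A minimal sketch of what it could look like, assuming a `unittest.TestCase` subclass and a temporary directory of JSON files (the class name and the `tempfile` usage are illustrative, not from the original tests):

import tempfile
import unittest

import ray
from ray.rllib.offline import IOContext
from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards


class TestDatasetReader(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()
        # Write a small dataset out as JSON files that the tests read back.
        cls.dset_path = tempfile.mkdtemp()
        ray.data.range(100).write_json(cls.dset_path)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()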
Example #2
    def test_dataset_shard_error_with_both_format_and_loader_fn(self):
        """Tests whether the dataset_shard function raises an error when both format
        and loader_fn are specified."""
        dset = ray.data.range(100)
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path,
                "loader_fn": lambda: dset,
            },
        }

        with self.assertRaises(ValueError):
            get_dataset_and_shards(config)
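
For contrast with the error case above: a valid `input_config` uses exactly one of the two styles, sketched below (the path is illustrative).

# Either point at files on disk with format/paths...
config_from_files = {
    "input": "dataset",
    "input_config": {"format": "json", "paths": "/tmp/my_dataset"},
}

# ...or hand over a ready-made Ray Dataset via loader_fn, but never both.
config_from_loader_fn = {
    "input": "dataset",
    "input_config": {"loader_fn": lambda: ray.data.range(100)},
}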
Example #3
    def test_dataset_reader_itr_batches(self):
        """Test that the dataset reader iterates over batches of rows correctly."""
        input_config = {"format": "json", "paths": self.dset_path}
        dataset, _ = get_dataset_and_shards({
            "input": "dataset",
            "input_config": input_config
        })

        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
        reader = DatasetReader(ioctx, dataset)
        # A single batch can contain somewhat more rows than train_batch_size,
        # hence ">=" rather than "==" here.
        assert len(reader.next()) >= 1200
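
A hedged usage sketch building on the reader above, assuming (as RLlib input readers generally allow) that `next()` can be called repeatedly:

# Each call yields a further batch, and the same size bound should hold.
for _ in range(3):
    batch = reader.next()
    assert len(batch) >= 1200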
Example #4
    def test_dataset_shard_with_loader_fn(self):
        """Tests whether the dataset_shard function works correctly with loader_fn."""
        dset = ray.data.range(100)
        config = {
            "input": "dataset",
            "input_config": {
                "loader_fn": lambda: dset
            }
        }

        ret_dataset, _ = get_dataset_and_shards(config)
        assert ret_dataset.count() == dset.count()
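
`loader_fn` is not limited to in-memory datasets; it can build the dataset any way it likes, which is why combining it with `format`/`paths` (Example #2) is rejected as ambiguous. A sketch with an illustrative path:

config = {
    "input": "dataset",
    "input_config": {
        # The loader builds the dataset itself, e.g. from JSON files.
        # "/tmp/my_dataset" is an illustrative path, not from the tests.
        "loader_fn": lambda: ray.data.read_json("/tmp/my_dataset"),
    },
}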
Example #5
    def test_dataset_shard_with_only_local(self):
        """Tests whether the dataset_shard function works correctly for a single shard
        for the local worker."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path
            },
        }

        # With num_workers=0 there are no remote workers, so only the local
        # worker should get a shard.
        _, shards = get_dataset_and_shards(config, num_workers=0)

        assert len(shards) == 1
        assert isinstance(shards[0], ray.data.Dataset)
Example #6
    def test_itr_batches(self):
        """Test that the json reader iterates over batches of rows correctly."""
        rllib_dir = Path(__file__).parent.parent.parent.parent
        print("rllib dir={}".format(rllib_dir))
        data_file = os.path.join(rllib_dir,
                                 "rllib/tests/data/pendulum/large.json")
        print("data_file={} exists={}".format(data_file,
                                              os.path.isfile(data_file)))
        input_config = {"format": "json", "path": data_file}
        dataset, _ = get_dataset_and_shards(
            {
                "input": "dataset",
                "input_config": input_config
            }, 0, True)

        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
        reader = DatasetReader(ioctx, dataset)
        assert len(reader.next()) == 1200
Example #7
    def test_dataset_shard_with_task_parallelization(self):
        """Tests whether the dataset_shard function works correctly with parallelism
        for reading the dataset."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path,
                "parallelism": 10,
            },
        }
        NUM_WORKERS = 4

        _, shards = get_dataset_and_shards(config, num_workers=NUM_WORKERS)

        assert len(shards) == NUM_WORKERS + 1
        assert shards[0] is None
        assert all(
            isinstance(remote_shard, ray.data.Dataset)
            for remote_shard in shards[1:])
Example #8
    def test_dataset_shard_remote_workers_with_local_worker(self):
        """Tests whether the dataset_shard function works correctly for the remote
        workers with a dummy dataset shard for the local worker."""

        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path
            },
        }
        NUM_WORKERS = 4

        _, shards = get_dataset_and_shards(config, num_workers=NUM_WORKERS)

        assert len(shards) == NUM_WORKERS + 1
        assert shards[0] is None
        assert all(
            isinstance(remote_shard, ray.data.Dataset)
            for remote_shard in shards[1:])
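
A standalone sketch of what the sharding in the last two examples amounts to; `ray.data.Dataset.split` is used here for illustration and is not necessarily the exact mechanism inside `get_dataset_and_shards`:

import ray

NUM_WORKERS = 4
ds = ray.data.range(100)

# One dataset shard per remote worker...
remote_shards = ds.split(NUM_WORKERS, equal=True)

# ...plus a placeholder (None) at index 0 for the local worker.
shards = [None] + remote_shards
assert len(shards) == NUM_WORKERS + 1
assert all(isinstance(s, ray.data.Dataset) for s in shards[1:])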