Example #1
0
    def gen_dataset_from_path(
        self,
        path: str,
        rank: int = 0,
        world_size: int = 1,
        include_label_fields: bool = True,
        use_cache: bool = True,
    ) -> textdata.Dataset:
        """
        Build a dataset from the file at ``path``.

        When ``world_size`` is greater than 1, a shard range is computed for
        ``rank`` from the recorded size of the dataset and handed to
        ``gen_dataset``; otherwise no shard range is passed.

        Only the single-process case (``rank == 0`` and ``world_size == 1``)
        interacts with ``self._data_cache``. ``use_cache`` controls whether a
        cached entry may be returned; a freshly built dataset is stored in
        the cache in the single-process case regardless of ``use_cache``.

        Returns:
            dataset (TorchText.Dataset)
        """
        single_process = rank == 0 and world_size == 1

        if use_cache and path in self._data_cache and single_process:
            return self._data_cache[path]

        if world_size > 1:
            shard_bounds = distributed.get_shard_range(
                self.metadata.dataset_sizes[path], rank, world_size
            )
        else:
            # Unsharded run: gen_dataset receives no range at all.
            shard_bounds = None

        raw_examples = self.read_from_file(path, self.raw_columns)
        dataset = self.gen_dataset(raw_examples, include_label_fields, shard_bounds)

        if single_process:
            self._data_cache[path] = dataset
        return dataset
Example #2
0
 def get_train_iter_from_raw_data(
     self,
     train_data: List[Dict[str, Any]],
     batch_size: int,
     rank: int = 0,
     world_size: int = 1,
 ) -> BatchIterator:
     """
     Build a training batch iterator from in-memory raw examples.

     A shard range is computed for ``rank`` from the number of examples in
     ``train_data`` and ``world_size``, the examples are turned into a
     dataset restricted to that range, and the dataset is wrapped by
     ``_get_train_iter`` using ``batch_size`` and ``world_size``.
     """
     num_examples = len(train_data)
     shard_bounds = distributed.get_shard_range(num_examples, rank, world_size)
     sharded_dataset = self.gen_dataset(train_data, shard_range=shard_bounds)
     return self._get_train_iter(sharded_dataset, batch_size, world_size)
Example #3
0
    def test_get_shard_range(self):
        """Check the (start, end) range get_shard_range returns for each rank."""
        cases = [
            # 21 examples over 8 ranks: the first 5 ranks take 3 examples
            # each. The last 3 ranks would take only 2, so to keep every
            # shard the same size each of them is padded with the previous
            # example (its range starts on the preceding shard's last index).
            (
                21,
                8,
                [
                    (0, 2),
                    (3, 5),
                    (6, 8),
                    (9, 11),
                    (12, 14),
                    (14, 16),
                    (16, 18),
                    (18, 20),
                ],
            ),
            # 16 examples over 4 ranks: divides evenly, no padding needed.
            (16, 4, [(0, 3), (4, 7), (8, 11), (12, 15)]),
        ]
        for dataset_size, world_size, ranges_by_rank in cases:
            for rank, expected_range in enumerate(ranges_by_rank):
                actual_range = get_shard_range(dataset_size, rank, world_size)
                self.assertEqual(actual_range, expected_range)