def gen_dataset_from_path(
    self,
    path: str,
    rank: int = 0,
    world_size: int = 1,
    include_label_fields: bool = True,
    use_cache: bool = True,
) -> textdata.Dataset:
    """
    Generate a dataset from a file.

    Returns:
        dataset (TorchText.Dataset)
    """
    # Serve from the cache only in single-worker runs; with multiple
    # workers each rank holds a different shard of the same path.
    if use_cache and path in self._data_cache and rank == 0 and world_size == 1:
        return self._data_cache[path]
    shard_range = (
        distributed.get_shard_range(
            self.metadata.dataset_sizes[path], rank, world_size
        )
        if world_size > 1
        else None
    )
    res = self.gen_dataset(
        self.read_from_file(path, self.raw_columns),
        include_label_fields,
        shard_range,
    )
    if rank == 0 and world_size == 1:
        self._data_cache[path] = res
    return res
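# Why the cache is restricted to rank == 0 and world_size == 1: in a
# distributed run each rank reads a different shard of the same path, so
# a cache keyed only by path would hand one rank's shard to another.
# Illustration using the values asserted in test_get_shard_range below:
#
#   distributed.get_shard_range(21, rank=0, world_size=8)  # -> (0, 2)
#   distributed.get_shard_range(21, rank=5, world_size=8)  # -> (14, 16)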
def get_train_iter_from_raw_data(
    self,
    train_data: List[Dict[str, Any]],
    batch_size: int,
    rank: int = 0,
    world_size: int = 1,
) -> BatchIterator:
    # Shard the in-memory training data across workers before batching.
    shard_range = distributed.get_shard_range(len(train_data), rank, world_size)
    return self._get_train_iter(
        self.gen_dataset(train_data, shard_range=shard_range),
        batch_size,
        world_size,
    )
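# The (begin, end) pair returned by get_shard_range is inclusive on both
# ends (see the expected values in the test below), so converting it to
# a Python slice needs end + 1. A minimal sketch of the per-worker
# slicing this implies; an assumption about gen_dataset's contract, not
# code from this module:
def _slice_shard_sketch(examples, shard_range):
    begin, end = shard_range
    return examples[begin : end + 1]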
def test_get_shard_range(self):
    # With 21 examples over 8 workers, the first 5 ranks get 3 examples
    # each and the last 3 ranks have only 2. To keep every shard the
    # same size, the short shards are padded by re-reading the last
    # example of the previous shard, hence the overlapping ranges.
    dataset_size, world_size = 21, 8
    expected = [
        (0, (0, 2)),
        (1, (3, 5)),
        (2, (6, 8)),
        (3, (9, 11)),
        (4, (12, 14)),
        (5, (14, 16)),
        (6, (16, 18)),
        (7, (18, 20)),
    ]
    for rank, expected_range in expected:
        shard_range = get_shard_range(dataset_size, rank, world_size)
        self.assertEqual(shard_range, expected_range)

    # When the dataset size divides evenly, no padding is needed.
    dataset_size, world_size = 16, 4
    expected = [(0, (0, 3)), (1, (4, 7)), (2, (8, 11)), (3, (12, 15))]
    for rank, expected_range in expected:
        shard_range = get_shard_range(dataset_size, rank, world_size)
        self.assertEqual(shard_range, expected_range)
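# A minimal sketch of the sharding rule the test above encodes. This is
# an illustration that reproduces the expected ranges, not necessarily
# how distributed.get_shard_range is actually implemented: every rank
# gets ceil(dataset_size / world_size) examples; ranks past the
# remainder anchor at their natural end and pad backwards into the
# previous shard.
import math

def get_shard_range_sketch(dataset_size, rank, world_size):
    shard_len = math.ceil(dataset_size / world_size)  # padded shard size
    remainder = dataset_size % world_size
    if remainder == 0 or rank < remainder:
        # Full shards: no padding needed.
        begin = rank * shard_len
        return (begin, begin + shard_len - 1)
    # Short shards hold shard_len - 1 fresh examples; fix the end at the
    # natural boundary and extend the start back by one to pad.
    end = remainder * shard_len + (rank - remainder + 1) * (shard_len - 1) - 1
    return (end - shard_len + 1, end)

# get_shard_range_sketch(21, 5, 8) -> (14, 16), matching the test above.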