def gen_dataset_from_path(
    self,
    path: str,
    rank: int = 0,
    world_size: int = 1,
    include_label_fields: bool = True,
    use_cache: bool = True,
) -> textdata.Dataset:
    """
    Generate a dataset from a file, with optional caching and sharding.

    Returns:
        dataset (torchtext.data.Dataset)
    """
    # Only serve from / populate the cache for single-process runs, so a
    # cached full dataset is never returned to a sharded worker.
    if use_cache and path in self._data_cache and rank == 0 and world_size == 1:
        return self._data_cache[path]
    shard_range = (
        dist_utils.get_shard_range(
            self.metadata.dataset_sizes[path], rank, world_size
        )
        if world_size > 1
        else None
    )
    res = self.gen_dataset(
        self.read_from_file(path, self.raw_columns),
        include_label_fields,
        shard_range,
    )
    if rank == 0 and world_size == 1:
        self._data_cache[path] = res
    return res
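# Usage sketch (illustration only; `handler` and `train_path` are
# hypothetical names, and torch.distributed is assumed to be initialized):
# each worker process builds its own shard, while the cache above is only
# consulted in the single-process case.
#
#     import torch.distributed as dist
#     train_set = handler.gen_dataset_from_path(
#         train_path,
#         rank=dist.get_rank(),
#         world_size=dist.get_world_size(),
#     )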
def read_from_file(
    self,
    file_name: str,
    columns_to_use: Union[Dict[str, int], List[str]],
    rank: int = 0,
    world_size: int = 1,
) -> List[Dict[str, Any]]:
    """
    Read data from a tab-separated (TSV) file.

    Args:
        file_name (str): TSV file name
        columns_to_use (Union[Dict[str, int], List[str]]): either a list of
            column names or a dict of column name -> column index in the file
        rank (int): process rank, used to select this worker's shard
        world_size (int): total number of processes reading the file
    """
    print("reading data from {}".format(file_name))
    if isinstance(columns_to_use, list):
        columns_to_use = {name: idx for idx, name in enumerate(columns_to_use)}
    shard_range = (
        dist_utils.get_shard_range(
            self.metadata.dataset_sizes[file_name], rank, world_size
        )
        if world_size > 1
        else None
    )
    with open(file_name, "r", encoding="utf-8", errors="replace") as f_handle:
        csv_reader = csv.reader(f_handle, delimiter="\t", quoting=csv.QUOTE_NONE)
        data = []
        i, row_idx = 0, 0
        while True:
            i += 1
            try:
                row = next(csv_reader)
            except csv.Error:
                print("ignoring line {}".format(i))
                continue
            except StopIteration:
                break
            # shard_range is an inclusive (start, end) pair, so the upper
            # bound check uses <= (see test_get_shard_range below).
            if not shard_range or shard_range[0] <= row_idx <= shard_range[1]:
                data.append(
                    {
                        name: row[index] if index < len(row) else ""
                        for name, index in columns_to_use.items()
                    }
                )
            row_idx += 1
        # Some shards might have 1 less example when dataset_size is not
        # divisible by world_size; pad so all shards have the same size.
        dist_utils.pad_shard_data(data, row_idx, world_size)
        return data
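# A minimal sketch of what dist_utils.pad_shard_data is assumed to do here
# (the real implementation may differ): grow the shard in place to the
# common shard size by repeating its last example, so every rank ends up
# iterating over the same number of rows.
def _sketch_pad_shard_data(
    data: List[Dict[str, Any]], dataset_size: int, world_size: int
) -> None:
    if world_size <= 1 or not data:
        return
    # Largest shard size: ceil(dataset_size / world_size).
    target_len = (dataset_size + world_size - 1) // world_size
    while len(data) < target_len:
        data.append(data[-1])  # duplicate the last example as padding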
def get_train_iter_from_raw_data(
    self,
    train_data: List[Dict[str, Any]],
    batch_size: int,
    rank: int = 0,
    world_size: int = 1,
) -> BatchIterator:
    """
    Build a training batch iterator over this rank's shard of raw data.
    """
    shard_range = dist_utils.get_shard_range(len(train_data), rank, world_size)
    # get_shard_range returns an inclusive (start, end) pair, so add 1 to
    # the slice's upper bound (see test_get_shard_range below).
    shard_data = train_data[shard_range[0] : shard_range[1] + 1]
    # Some shards might have 1 less example due to data_size % world_size;
    # pad the shard so all shard datasets have the same size.
    dist_utils.pad_shard_data(shard_data, len(train_data), world_size)
    return self._get_train_iter(
        self.gen_dataset(shard_data), batch_size, world_size
    )
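# Illustration with the hypothetical values used in the test below: with
# len(train_data) == 21 and world_size == 8, rank 5 gets the inclusive
# range (14, 16), i.e. train_data[14:17] after the +1 slice adjustment;
# row 14 also belongs to rank 4's shard (12, 14) and serves as padding so
# every rank sees 3 examples.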
def test_get_shard_range(self):
    # With dataset_size=21 and world_size=8, the first 5 ranks take 3
    # examples each; the last 3 ranks only have 2 fresh examples, so each
    # of their shards is padded with the previous example to keep all
    # shards the same size. Ranges are inclusive on both ends.
    dataset_size, world_size = 21, 8
    expected = [
        (0, (0, 2)),
        (1, (3, 5)),
        (2, (6, 8)),
        (3, (9, 11)),
        (4, (12, 14)),
        (5, (14, 16)),
        (6, (16, 18)),
        (7, (18, 20)),
    ]
    for rank, expected_range in expected:
        shard_range = get_shard_range(dataset_size, rank, world_size)
        self.assertEqual(shard_range, expected_range)

    # When dataset_size divides evenly, the shards are disjoint and no
    # padding is needed.
    dataset_size, world_size = 16, 4
    expected = [(0, (0, 3)), (1, (4, 7)), (2, (8, 11)), (3, (12, 15))]
    for rank, expected_range in expected:
        shard_range = get_shard_range(dataset_size, rank, world_size)
        self.assertEqual(shard_range, expected_range)
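# For reference, a minimal sketch of a shard-range computation consistent
# with the expectations asserted above; this is an illustration, not the
# actual dist_utils.get_shard_range implementation, which may differ.
def _sketch_get_shard_range(dataset_size: int, rank: int, world_size: int):
    base_len, remainder = divmod(dataset_size, world_size)
    if remainder == 0:
        # Evenly divisible: disjoint inclusive ranges of base_len rows.
        start = rank * base_len
        return start, start + base_len - 1
    if rank < remainder:
        # The first `remainder` ranks take base_len + 1 fresh rows each.
        start = rank * (base_len + 1)
    else:
        # Later ranks only have base_len fresh rows left, so their start
        # shifts one row left, re-using the previous shard's last example
        # as padding; every shard then spans base_len + 1 rows.
        start = remainder * (base_len + 1) + (rank - remainder) * base_len - 1
    return start, start + base_len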