Exemple #1
0
    def gen_dataset_from_path(
        self,
        path: str,
        rank: int = 0,
        world_size: int = 1,
        include_label_fields: bool = True,
        use_cache: bool = True,
    ) -> textdata.Dataset:
        """
        Generate a dataset from file

        Args:
            path (str): file to read; also the key used in self._data_cache
            rank (int): index of this worker
            world_size (int): total number of workers; when > 1 a shard range
                is computed from self.metadata.dataset_sizes[path] and passed
                to gen_dataset
            include_label_fields (bool): forwarded to gen_dataset
            use_cache (bool): when True, a previously cached dataset for
                ``path`` may be returned instead of re-reading the file

        Returns:
            dataset (TorchText.Dataset)
        """
        # The cache is keyed by path only, while include_label_fields is
        # forwarded to gen_dataset and may change the result. Restrict caching
        # to the default include_label_fields=True, unsharded case so that a
        # label-free dataset is never served to a caller that asked for labels
        # (or the reverse).
        cacheable = include_label_fields and rank == 0 and world_size == 1
        if use_cache and cacheable and path in self._data_cache:
            return self._data_cache[path]

        shard_range = (
            dist_utils.get_shard_range(
                self.metadata.dataset_sizes[path], rank, world_size
            )
            if world_size > 1
            else None
        )
        res = self.gen_dataset(
            self.read_from_file(path, self.raw_columns),
            include_label_fields,
            shard_range,
        )
        if cacheable:
            # Written even when use_cache is False, which refreshes the entry.
            self._data_cache[path] = res
        return res
Exemple #2
0
    def read_from_file(
        self,
        file_name: str,
        columns_to_use: Union[Dict[str, int], List[str]],
        rank: int = 0,
        world_size: int = 1,
    ) -> List[Dict[str, Any]]:
        """
        Read data from csv file. Input file format is required to be
        tab-separated columns

        Args:
            file_name (str): csv file name
            columns_to_use (Union[Dict[str, int], List[str]]): either a list of
                column names or a dict of column name -> column index in the file
            rank (int): index of this worker; when world_size > 1 only the rows
                inside this worker's shard are kept
            world_size (int): total number of workers; 1 disables sharding

        Returns:
            List[Dict[str, Any]]: one dict per kept row, mapping each column
                name to the raw field string ("" when the row is too short)
        """
        print("reading data from {}".format(file_name))
        if isinstance(columns_to_use, list):
            # A plain list means: column i of the file is named columns_to_use[i].
            columns_to_use = {name: idx for idx, name in enumerate(columns_to_use)}

        # dist_utils.get_shard_range returns an inclusive (first, last) row
        # index pair (e.g. 16 rows over 4 ranks -> (0, 3), (4, 7), ...), so the
        # upper bound must be tested with <= or each shard loses its last row.
        shard_range = (
            dist_utils.get_shard_range(
                self.metadata.dataset_sizes[file_name], rank, world_size
            )
            if world_size > 1
            else None
        )

        with open(file_name, "r", encoding="utf-8", errors="replace") as f_handle:
            csv_reader = csv.reader(
                f_handle, delimiter="\t", quoting=csv.QUOTE_NONE
            )
            data = []
            # i counts lines attempted (for error messages); row_idx counts
            # rows successfully parsed (used for shard membership).
            i, row_idx = 0, 0
            while True:
                i += 1
                try:
                    row = next(csv_reader)
                except csv.Error:
                    # Skip malformed lines rather than aborting the whole read.
                    print("ignoring line {}".format(i))
                    continue
                except StopIteration:
                    break

                if shard_range is None or shard_range[0] <= row_idx <= shard_range[1]:
                    data.append({
                        name: row[index] if index < len(row) else ""
                        for name, index in columns_to_use.items()
                    })
                row_idx += 1

        # some shard might have 1 less example due to data_size % world_size
        # pad the shard to make sure all shard dataset have same size
        dist_utils.pad_shard_data(data, row_idx, world_size)
        return data
Exemple #3
0
 def get_train_iter_from_raw_data(
     self,
     train_data: List[Dict[str, Any]],
     batch_size: int,
     rank: int = 0,
     world_size: int = 1,
 ) -> BatchIterator:
     """Build a training batch iterator over this rank's shard of train_data.

     Args:
         train_data: full list of raw examples, one dict per example.
         batch_size: batch size forwarded to self._get_train_iter.
         rank: index of this worker.
         world_size: total number of workers.

     Returns:
         BatchIterator produced by self._get_train_iter for this shard.
     """
     first, last = dist_utils.get_shard_range(len(train_data), rank, world_size)
     # get_shard_range's end index is inclusive (e.g. 16 examples over 4 ranks
     # -> (0, 3), (4, 7), ...), so slice up to last + 1; slicing to `last`
     # would drop the final example of every shard.
     shard_data = train_data[first:last + 1]
     # NOTE(review): with full-length shards this padding is presumably a
     # no-op; kept for compatibility with pad_shard_data's contract.
     dist_utils.pad_shard_data(shard_data, len(train_data), world_size)
     return self._get_train_iter(
         self.gen_dataset(shard_data), batch_size, world_size
     )
Exemple #4
0
    def test_get_shard_range(self):
        """Check the inclusive (start, end) bounds get_shard_range gives each rank.

        21 examples over 8 ranks: ranks 0-4 each own 3 examples; ranks 5-7
        would naturally own only 2, so each starts one example early (reusing
        the previous example) and every shard ends up the same size.
        16 examples over 4 ranks divides evenly: 4 examples per rank.
        """
        cases = (
            (
                21,
                8,
                [
                    (0, 2),
                    (3, 5),
                    (6, 8),
                    (9, 11),
                    (12, 14),
                    (14, 16),
                    (16, 18),
                    (18, 20),
                ],
            ),
            (16, 4, [(0, 3), (4, 7), (8, 11), (12, 15)]),
        )
        for total, n_ranks, per_rank in cases:
            for rank_id, want in enumerate(per_rank):
                got = get_shard_range(total, rank_id, n_ranks)
                self.assertEqual(got, want)