Example #1
def _wif(worker_id):
    set_num_threads(1)
    info = get_worker_info()
    ds = info.dataset.d
    ds.num_workers, ds.offs = info.num_workers, info.id
    set_seed(info.seed)
    ds.wif()
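A worker init function like this is passed to the DataLoader through its worker_init_fn argument and runs once inside each worker process before loading starts. A minimal registration sketch, assuming a dataset wrapper that exposes the d, num_workers, offs and wif attributes the hook touches (names are illustrative):

from torch.utils.data import DataLoader

loader = DataLoader(wrapped_dataset, batch_size=64, num_workers=4, worker_init_fn=_wif)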
Example #2
def _capture_metadata_collate(samples: List, dataset: Dataset,
                              collate_fn: Callable,
                              fault_tolerant_mode: _FaultTolerantMode) -> Any:
    """A collate_fn function that adds the state dict of a :class:`CaptureIterableDataset` or
    :class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker
    processes. The structure will be:

    .. code-block:: python

        {
            "data": ...,  # data returned by Dataset
            "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
        }
    """
    data = collate_fn(samples)
    metadata = None
    if fault_tolerant_mode.is_automatic:
        metadata = dataset.state_dict()
    else:
        state_dict_fn = getattr(dataset, "state_dict", None)
        info = get_worker_info()
        worker_id = info.id if info else 0
        if state_dict_fn is not None:
            metadata = state_dict_fn()
            if worker_id not in metadata:
                if info and info.num_workers > 1:
                    raise MisconfigurationException(
                        f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys."
                    )
                metadata = {0: metadata}
        if metadata is None:
            metadata = {worker_id: {}}

    return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}
Example #3
 def shuffle(self, randomize=True, seed=42):
     worker_info = get_worker_info()
     worker_id = 0 if worker_info is None else worker_info.id
     if worker_id == 0:
         # Shuffle only in worker 0 (or in the main process when num_workers == 0).
         for ds in self.datasets:
             print(f"shuffling {ds}")
             ds.shuffle(randomize, seed)
Example #4
    def __iter__(self):
        worker_info = data.get_worker_info()
        if worker_info is None:
            seed = self.seed
        else:
            seed = (self.seed * (worker_info.id + 1)) % (2**32 - 1)
        rng = np.random.RandomState(seed)

        class InternalIterator:
            def __init__(self, parent: Dataset):
                self.parent = parent
                self.obj_prob = \
                    [float(n) for n in range(self.parent.min_object,
                                             self.parent.max_object + 1)]
                self.obj_prob = \
                    [p / sum(self.obj_prob) - 1e-5 for p in self.obj_prob]

            def __next__(self) -> Any:
                n_object = rng.multinomial(1, self.obj_prob).nonzero()[0] + 1
                ast = self.parent.sample_ast(rng, n_object)
                retval = Environment({"ground_truth": ast},
                                     set(["ground_truth"]))
                return retval

        return InternalIterator(self)
Example #5
    def __iter__(self):
        worker_info = get_worker_info()

        # Only divide up batches when using multiple worker processes
        if worker_info is not None:
            batches = list(self.dataset.to_batches())
            worker_load = len(batches) // worker_info.num_workers

            # If more workers than batches exist, some won't be used
            if worker_load == 0:
                if worker_info.id < len(batches):
                    self.batches = [batches[worker_info.id]]
                else:
                    return
            else:
                start = worker_load * worker_info.id
                end = min(start + worker_load, len(batches))
                self.batches = batches[start:end]
        else:
            self.batches = self.dataset.to_batches()

        # Process and yield each batch
        for batch in self.batches:
            batch = batch.to_pydict()
            batch.update(self.process_func(batch))

            yield batch
Example #6
 def __iter__(self):
     worker_info = get_worker_info()
     if worker_info is not None:  # multi-process case
         split_sample_ids = np.array_split(self.sample_ids,
                                           worker_info.num_workers)
         self.sample_ids = split_sample_ids[worker_info.id]
     return iter(self.one_epoch())
Example #7
    def __init__(self,
                 archive,
                 transform=to_tensor,
                 extensions=('.png', '.jpg', '.jpeg'),
                 is_valid_file=None):
        if not isinstance(archive, TarDataset):
            # open tar file. in a multiprocessing setting (e.g. DataLoader workers), we
            # have to open one file handle per worker (stored as the tar_obj dict), since
            # when the multiprocessing method is 'fork', the workers share this TarDataset.
            # we want one file handle per worker because TarFile is not thread-safe.
            worker = get_worker_info()
            worker = worker.id if worker else None
            self.tar_obj = {worker: tarfile.open(archive)}
            self.archive = archive

            # store headers of all files and folders by name
            members = sorted(self.tar_obj[worker].getmembers(),
                             key=lambda m: m.name)
            self.members_by_name = {m.name: m for m in members}
        else:
            # passed a TarDataset into the constructor, reuse the same tar contents.
            # no need to copy explicitly since this dict will not be modified again.
            self.members_by_name = archive.members_by_name
            self.archive = archive.archive  # the original path to the Tar file
            self.tar_obj = {}  # will get filled by get_file on first access

        # also store references to the iterated samples (a subset of the above)
        self.filter_samples(is_valid_file, extensions)

        self.transform = transform
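The constructor above notes that self.tar_obj "will get filled by get_file on first access"; a possible sketch of that lookup, which lazily opens one TarFile handle per DataLoader worker (hypothetical, not necessarily the library's exact implementation):

    def get_file(self, name):
        # Open a dedicated TarFile handle for this worker on first use,
        # since a single TarFile object cannot be shared safely across workers.
        worker = get_worker_info()
        worker = worker.id if worker else None
        if worker not in self.tar_obj:
            self.tar_obj[worker] = tarfile.open(self.archive)
        return self.tar_obj[worker].extractfile(self.members_by_name[name])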
Example #8
    def __iter__(self):
        lines = it.chain.from_iterable(
            Path(x).read_text(encoding="utf-8").splitlines()
            for x in self.file_paths)

        # Split data when there are multiple workers running in parallel
        worker_info = get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
            lines = it.islice(lines, worker_id, None, num_workers)

        chunks = batch(lines, self.chunk_size)
        tokenized_chunks = (
            t for chunk in chunks
            for tokens in self.tokenizer.encode_batch(list(chunk))
            for t in tokens.ids[1:-1])
        blocks = (list(x) for x in batch(tokenized_chunks, self.block_size))

        # Add class and separation tokens to each block
        cls_token, sep_token = self.tokenizer.encode('').ids
        blocks_with_special_tokens = ([cls_token] + block + [sep_token]
                                      for block in blocks)
        return (torch.tensor(x, dtype=torch.long)
                for x in blocks_with_special_tokens)
Example #9
 def __iter__(self):
     worker_info = data.get_worker_info()
     if worker_info is not None:
         if worker_info.num_workers > 1:
             raise ValueError(
                 'Patches must be sequential for the saver to reconstruct the image hence num_workers must be 0 or 1')
     self.patch_index = 0
     return self
Example #10
def get_worker():
    info = get_worker_info()
    if not info:
        worker = 0
    else:
        worker = info.id
    assert worker >= 0
    return worker
Example #11
def _shard_iterator_dataloader_worker(iterable):
    # Shard the iterable if we're currently inside pytorch dataloader worker.
    worker_info = data.get_worker_info()
    if worker_info is None or worker_info.num_workers == 1:
        # do nothing
        yield from iterable
    else:
        yield from itertools.islice(iterable, worker_info.id, None, worker_info.num_workers)
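A helper like this is typically called from an IterableDataset.__iter__ so that each DataLoader worker yields a disjoint, interleaved slice of the stream. A minimal sketch (class and attribute names are illustrative):

class ShardedStream(data.IterableDataset):
    def __init__(self, items):
        self.items = items

    def __iter__(self):
        # Each worker keeps only every num_workers-th item, offset by its id.
        yield from _shard_iterator_dataloader_worker(iter(self.items))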
Example #12
def worker_init_fn(worker_id):
    worker_info = data.get_worker_info()
    dataset = worker_info.dataset
    sample_i = int(ceil(dataset.num_samples / float(worker_info.num_workers)))
    if worker_id == (worker_info.num_workers - 1):
        sample_i = int(dataset.num_samples - worker_id * sample_i)
    dataset.num_samples = sample_i
    dataset.set_seed(worker_id + dataset.seed)
Example #13
def setValidationWorker(worker_id):
    worker_info = data.get_worker_info()
    if worker_info:
        dataset = worker_info.dataset
        dataset.num_workers = worker_info.num_workers
        dataset.cur_file_num = worker_info.id % dataset.patient_len
        dataset.cur_pkl, dataset.cur_pkl_shape = dataset.loadFile(
            dataset.cur_file_num)
Example #14
    def __iter__(self) -> Iterable[Sample]:
        worker_info = tdata.get_worker_info()
        if worker_info:
            step_size = worker_info.num_workers
            worker_id = worker_info.id
        else:
            step_size = 1
            worker_id = 0

        # Split trajectories across workers
        trajectories = self._shuffle_trajectories()
        for i in range(worker_id, len(trajectories), step_size):
            traj_dir = trajectories[i]
            observations = self._load_observations(traj_dir)
            if self._skip_failed and \
                    not self._check_success(observations['reward']):
                continue

            sampled_steps = self._get_steps(observations['reward'])

            a_pad, actions = self._parse_actions(observations)
            n_steps = actions.shape[0]
            i_pad, inv = self._standardize_inv(self._parse_inv(observations))
            assert n_steps == inv.shape[0], traj_dir
            f_pad, frames = self._load_frames(traj_dir)
            if frames.shape[0] != n_steps:
                m = 'Num frames: {} num steps: {} trajectory: {}'
                logging.debug(m.format(frames.shape[0], n_steps, traj_dir))
            # Sometimes an episode starts after a number of frames have been
            # recorded.
            frames = frames[-n_steps:]

            msg = ('Worker {} loaded trajectory: {}. Total steps: {} '
                   'sampled steps: {}')
            logging.debug(
                msg.format(worker_id, traj_dir.name, n_steps,
                           sampled_steps.size))

            for j in sampled_steps:
                s_inv = {}
                s_actions = {}
                s_frames = {}
                for ri in self._relative_indices:
                    if ri > j:
                        s_inv[ri] = i_pad
                        s_actions[ri] = a_pad
                        s_frames[ri] = f_pad
                    else:
                        s_inv[ri] = inv[j - ri]
                        s_actions[ri] = actions[j - ri]
                        s_frames[ri] = frames[j - ri]
                yield Sample(inv=s_inv,
                             actions=s_actions,
                             frames=s_frames,
                             trajectory=traj_dir.name)

            gc.collect()
        self._epoch += 1
Example #15
 def __iter__(self):
     worker_info = get_worker_info()
     assert worker_info is not None
     per_worker = len(self.seeds) // worker_info.num_workers
     worker_id = worker_info.id
     st = worker_id * per_worker
     en = min((worker_id + 1) * per_worker, len(self.seeds))
     yield from self.post(
         self.fn(self.graph, self.seeds[st:en], **self.kwargs))
Example #16
 def __iter__(self):
     worker_info = get_worker_info()
     if worker_info is None:  # single-process data loading, return the full iterator
         size = self.size
     else:  # in a worker process
         # split workload
         size = int(self.size / float(worker_info.num_workers))
     datalist = self.generate_n_data(size)
     return iter(datalist)  # ensure an iterator is returned even if generate_n_data gives a list
Example #17
 def __iter__(self):
     worker_info = get_worker_info()
     buffer = []
     for i, element in enumerate(self.elements):
         if worker_info is not None:
             if worker_info.id != i % worker_info.num_workers:
                 continue
         buffer.append(element)
     yield from self.gen_record(buffer)
Example #18
def cat_into_shared_memory(batch: List[torch.Tensor]):
    out = None
    elem = batch[0]
    if get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum([x.numel() for x in batch])
        storage = elem.storage()._new_shared(numel)
        out = elem.new(storage)
    return torch.stack(batch, 0, out=out)
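A function like this is typically installed as the DataLoader's collate_fn so that batches assembled in worker processes land directly in shared memory. A minimal sketch, assuming each dataset sample is already a single tensor:

loader = torch.utils.data.DataLoader(
    tensor_dataset, batch_size=64, num_workers=4,
    collate_fn=cat_into_shared_memory)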
Example #19
 def __getitem__(self, item):
     pos = self.positions[item]
     worker_info = data.get_worker_info()
     if worker_info is None:
         f = self.f
     else:
         f = self.files[worker_info.id]
     f.seek(pos, 0)
     size = int.from_bytes(f.read(4), 'little', signed=False)
     return Image.open(io.BytesIO(f.read(size)))
Example #20
 def __iter__(self):
     items = copy(self.items)
     if self.max_epochs:
         shuffle(items)
         items = items[:self.max_epochs]
     info = get_worker_info()
     items = copy(
         items if info is None else items[info.id::info.num_workers])
     shuffle(items)
     return VideoIterator(items, self.directory, self.aug)
Example #21
    def _get_generator(self):
        worker_info: Optional[Any] = get_worker_info()

        if worker_info is None:
            return self._dataset.stream_split(self._split, 0, 1)
        else:
            num_workers: int = worker_info.num_workers
            worker_id: int = worker_info.id

            return self._dataset.stream_split(self._split, worker_id,
                                              num_workers)
Example #22
 def __iter__(self):
     worker_info = get_worker_info()
     if worker_info is None:
         fileslist = self.fileslist
     else:
         fnum = len(self.fileslist) // worker_info.num_workers
         worker_id = worker_info.id
         fileslist = self.fileslist[fnum * worker_id:fnum * (worker_id + 1)]
     return chain.from_iterable((self.dataset_class(*names,
                                                    **self.dataset_params)
                                 for names in fileslist))
Example #23
    def __iter__(self):
        info = get_worker_info()
        num_workers = info.num_workers if info is not None else 1
        worker_id = info.id if info is not None else 0

        self.source = iter(self.data)
        for i, item in enumerate(self.source):
            if i % num_workers == worker_id:
                if self.transform is not None:
                    item = apply_transform(self.transform, item)
                yield item
Example #24
 def __iter__(self):
     worker_info = get_worker_info()
     worker_id = 0 if worker_info is None else worker_info.id
     self.iterables = []
     for ds_length, dataset in zip(self.ds_lengths, self.datasets):
         start = (worker_id * self.per_worker) % ds_length
         self.iterables.append(
             cycle(
                 chain(islice(iter(dataset), start, None),
                       islice(iter(dataset), start))))
     return self
Example #25
def worker_init_fn(worker_id):
    worker_info = get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    worker_id = worker_info.id
    overall_start = 0
    overall_end = len(dataset.file_list)

    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)
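This mirrors the sharding pattern from the PyTorch IterableDataset documentation: the init function assigns per-worker start/end bounds, and the dataset's __iter__ walks only its own slice. A minimal sketch of such a dataset (attribute names follow the init function above):

class FileListDataset(torch.utils.data.IterableDataset):
    def __init__(self, file_list):
        self.file_list = file_list
        self.start, self.end = 0, len(file_list)  # overwritten per worker by worker_init_fn

    def __iter__(self):
        # Each worker iterates only over its assigned slice of the file list.
        for path in self.file_list[self.start:self.end]:
            yield path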
Example #26
 def _rand(self) -> np.random.RandomState:
     # Random state should be initialized in worker process
     if self._rand_h is None:
         worker_info = tdata.get_worker_info()
         if worker_info:
             # NumPy seeds must fit in 32 bits, but PyTorch worker seeds are 64-bit integers
             seed = worker_info.seed % 1000003
         else:
             seed = self._conf_seed
         self._rand_h = np.random.RandomState(seed)
     return self._rand_h
Example #27
 def _init_worker(self):
     worker_info = get_worker_info()
     if worker_info is None:
         num_workers_per_rank = 1
         worker_rank = 0
     else:
         num_workers_per_rank = worker_info.num_workers
         worker_rank = worker_info.id
     assert (len(self._files) %
             (self._world_size * num_workers_per_rank) == 0)
     self._logger.init_for_worker(worker_rank)
     return worker_rank, num_workers_per_rank
Example #28
        def fn(worker_id):
            worker_info = get_worker_info()

            if worker_info is None:
                raise RuntimeError(
                    "Custom initialization should be used for multiprocessing "
                    "only.")

            # pylint: disable=no-member
            dataset = worker_info.dataset
            dataset._consumer = cls.new_consumer(*args, **kwargs)
            dataset._worker_id = worker_id
Example #29
def allennlp_worker_init_fn(worker_id):
    """
    The default worker init function used by [`PyTorchDataLoader`](#pytorchdataloader).

    This is needed when using `num_workers > 0` so that each worker process knows which
    instances it's responsible for.
    """
    worker_info = data.get_worker_info()
    dataset = worker_info.dataset
    if isinstance(dataset, AllennlpLazyDataset):
        dataset.reader._set_worker_info(
            WorkerInfo(worker_info.num_workers, worker_id))
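As the docstring notes, this hook only matters when num_workers > 0; registration follows the usual worker_init_fn pattern. A minimal sketch, assuming an AllennlpLazyDataset instance named lazy_dataset:

loader = data.DataLoader(lazy_dataset, batch_size=32, num_workers=2,
                         worker_init_fn=allennlp_worker_init_fn)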
Example #30
    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

            split_size = (len(self.query_paper_ids) // num_workers) + 1
            self.length //= num_workers
            self.query_paper_ids = self.query_paper_ids[worker_id * split_size: (worker_id + 1) * split_size]
        
        self.triplet_generator = self._build_triplet_generator()
        return self