Example 1
    def get_loader(self,
                   mode: str,
                   epoch: int = 1,
                   shuffle: Optional[bool] = None,
                   output_keys: Optional[Set[str]] = None) -> Union[DataLoader, tf.data.Dataset]:
        """Get a data loader from the Pipeline for a given `mode` and `epoch`.

        Args:
            mode: The execution mode for the loader. This can be 'train', 'eval' or 'test'.
            epoch: The epoch index for the loader. Note that epoch indices are 1-indexed.
            shuffle: Whether to shuffle the data. If None, shuffling is decided by the `mode` (training data is
                shuffled, other modes are not). NOTE: This argument is only used with FastEstimator Datasets.
            output_keys: Which keys can be produced by the pipeline. If None, all keys will be considered.

        Returns:
            A data loader for the given `mode` and `epoch`.
        """
        data = self.data[mode]
        if isinstance(data, Scheduler):
            data = data.get_current_value(epoch)
        if isinstance(data, Dataset):
            # resolve the batch size for this mode and epoch
            batch_size = self.batch_size
            if isinstance(batch_size, Scheduler):
                batch_size = batch_size.get_current_value(epoch)
            if isinstance(batch_size, dict):
                batch_size = batch_size[mode]
            # a BatchDataset handles its own batching, so forward the configured pad value to it
            if isinstance(data, BatchDataset):
                data.pad_value = self.pad_value
            # default: shuffle only during training, and only when the loader is actually batching
            if shuffle is None:
                shuffle = mode == "train" and batch_size is not None
            # fall back to the padding collate function when a pad_value is configured
            collate_fn = self.collate_fn
            if collate_fn is None and self.pad_value is not None:
                collate_fn = self._pad_batch_collate
            # Results will be immediately converted to tensors, so deep_remainder is not needed
            op_dataset = OpDataset(data,
                                   get_current_items(self.ops, mode, epoch),
                                   mode,
                                   output_keys,
                                   deep_remainder=False)
            # A BatchDataset already produces batches, so the DataLoader itself should not batch again
            batch_size = None if isinstance(data, BatchDataset) else batch_size
            data = DataLoader(op_dataset,
                              batch_size=batch_size,
                              shuffle=False if isinstance(data, BatchDataset) else shuffle,
                              sampler=RandomSampler(op_dataset) if isinstance(data, BatchDataset) and shuffle else None,
                              num_workers=self.num_process,
                              drop_last=False if batch_size is None else self.drop_last,
                              # re-seed NumPy per worker so augmentation randomness differs across workers
                              worker_init_fn=lambda _: np.random.seed(random.randint(0, 2**32 - 1)),
                              collate_fn=collate_fn)
        return data
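For orientation, here is a minimal usage sketch of `get_loader`. The `Pipeline` construction is illustrative (`train_data` is a placeholder dataset, not part of the source above); the per-mode dict form of `batch_size` mirrors the `isinstance(batch_size, dict)` branch at the top of the function.

```python
import fastestimator as fe

# A minimal sketch, assuming `train_data` is a FastEstimator Dataset built elsewhere.
pipe = fe.Pipeline(train_data=train_data,
                   batch_size={"train": 32, "eval": 64})  # per-mode dict, resolved by get_loader

# shuffle defaults to True here: mode is 'train' and a batch size is set
loader = pipe.get_loader(mode="train", epoch=1)
for batch in loader:
    # each batch is a dictionary of tensors keyed by the pipeline's output keys
    print({key: value.shape for key, value in batch.items()})
    break
```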
Example 2
    def __enter__(self) -> Union[DataLoader, tf.data.Dataset]:
        """Get a data loader from the Pipeline for the current epoch and mode.

        A given pipeline can only provide one loader at a time. This helps to prevent issues with multi-threading.

        ```python
        pipe = Pipeline(...)
        with pipe(mode='eval', epoch=2) as loader:
            for batch in loader:
                print(batch)
        ```

        Returns:
            A data loader for the current `mode` and `epoch`.

        Raises:
            ValueError: If called while the pipeline already has an active loader.
        """
        acquired = self.ctx_lock.acquire(blocking=False)
        if not acquired:
            raise ValueError(
                "You cannot generate a new loader from this Pipeline before closing its other loader."
            )
        # Release the lock if arguments are invalid so that people in Jupyter / debug consoles don't get stuck
        if self.ctx_mode not in self.data:
            self.ctx_lock.release()
            raise KeyError(f"Pipeline has no data for mode '{self.ctx_mode}'")
        if self.ctx_ds_id not in self.data[self.ctx_mode]:
            self.ctx_lock.release()
            raise KeyError(
                f"The dataset id '{self.ctx_ds_id}' is not present in {self.ctx_mode} mode"
            )
        data = self.data[self.ctx_mode][self.ctx_ds_id]
        if isinstance(data, Scheduler):
            data = data.get_current_value(self.ctx_epoch)
        if isinstance(data, Dataset):
            # Results will be immediately converted to tensors, so don't need deep_remainder
            op_dataset = OpDataset(
                data,
                self.ctx_ops,
                self.ctx_mode,
                (self.ctx_output_keys | self.ctx_batch_input_keys) if self.ctx_output_keys else None,
                deep_remainder=False)
            # check whether to batch the data
            batch_size = None if op_dataset.fe_batch else self.ctx_batch_size
            # Figure out whether a postprocessing function is needed (for batched ops)
            postprocess_fn = None
            if self.ctx_batch_ops:
                postprocess_fn = functools.partial(
                    _batch_postprocess,
                    ops=self.ctx_batch_ops,
                    output_keys=self.ctx_output_keys,
                    mode=self.ctx_mode)
            try:
                data = FEDataLoader(op_dataset,
                                    postprocess_fn=postprocess_fn,
                                    batch_size=batch_size,
                                    shuffle=self.ctx_shuffle,
                                    steps_per_epoch=self.ctx_steps_per_epoch,
                                    num_workers=self.num_process,
                                    drop_last=self.ctx_batch_info.drop_last,
                                    collate_fn=self.ctx_batch_info.collate_fn)
            except ValueError:
                # release the lock before propagating so the Pipeline isn't left unusable
                self.ctx_lock.release()
                raise
            self.ctx_loader = data
        return data
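As the docstring and the `ctx_lock` logic above make explicit, a `Pipeline` hands out only one loader at a time. A short sketch of that guarantee in practice, again assuming a configured `pipe` instance:

```python
# A minimal sketch, assuming `pipe` is a configured fe.Pipeline instance.
with pipe(mode="train", epoch=1) as loader:
    try:
        # __enter__ acquires ctx_lock without blocking, so opening a second
        # loader before the first one closes fails immediately.
        with pipe(mode="eval", epoch=1) as second_loader:
            pass
    except ValueError as err:
        print(err)  # "You cannot generate a new loader from this Pipeline ..."
```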