def get_loader(self,
               mode: str,
               epoch: int = 1,
               shuffle: Optional[bool] = None,
               output_keys: Optional[Set[str]] = None) -> Union[DataLoader, tf.data.Dataset]:
    """Get a data loader from the Pipeline for a given `mode` and `epoch`.

    Args:
        mode: The execution mode for the loader. This can be 'train', 'eval' or 'test'.
        epoch: The epoch index for the loader. Note that epoch indices are 1-indexed.
        shuffle: Whether to shuffle the data. If None, the value for shuffle is based on the mode. NOTE: This
            argument is only used with FastEstimator Datasets.
        output_keys: What keys can be produced from the pipeline. If None, all keys will be considered.

    Returns:
        A data loader for the given `mode` and `epoch`.
    """
    data = self.data[mode]
    if isinstance(data, Scheduler):
        data = data.get_current_value(epoch)
    if isinstance(data, Dataset):
        # batch size
        batch_size = self.batch_size
        if isinstance(batch_size, Scheduler):
            batch_size = batch_size.get_current_value(epoch)
        if isinstance(batch_size, dict):
            batch_size = batch_size[mode]
        # batch dataset
        if isinstance(data, BatchDataset):
            data.pad_value = self.pad_value
        # shuffle
        if shuffle is None:
            shuffle = mode == "train" and batch_size is not None
        # collate_fn
        collate_fn = self.collate_fn
        if collate_fn is None and self.pad_value is not None:
            collate_fn = self._pad_batch_collate
        # Results will be immediately converted to tensors, so don't need deep_remainder
        op_dataset = OpDataset(data,
                               get_current_items(self.ops, mode, epoch),
                               mode,
                               output_keys,
                               deep_remainder=False)
        # A BatchDataset yields whole batches itself, so the DataLoader should not batch (or shuffle) again
        batch_size = None if isinstance(data, BatchDataset) else batch_size
        data = DataLoader(op_dataset,
                          batch_size=batch_size,
                          shuffle=False if isinstance(data, BatchDataset) else shuffle,
                          sampler=RandomSampler(op_dataset) if isinstance(data, BatchDataset) and shuffle else None,
                          num_workers=self.num_process,
                          drop_last=False if batch_size is None else self.drop_last,
                          worker_init_fn=lambda _: np.random.seed(random.randint(0, 2**32 - 1)),
                          collate_fn=collate_fn)
    return data
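# A minimal usage sketch for `get_loader` (illustrative only; the datasets, ops, and batch size shown here are
# placeholder assumptions, not part of this module):
#
#     pipe = Pipeline(train_data=train_ds, eval_data=eval_ds, batch_size=32, ops=[...])
#     loader = pipe.get_loader(mode="train", epoch=1)
#     for batch in loader:
#         ...  # each batch is a dictionary of tensors keyed by the pipeline's output keys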
def __enter__(self) -> Union[DataLoader, tf.data.Dataset]:
    """Get a data loader from the Pipeline for the current epoch and mode.

    A given pipeline can only provide one loader at a time. This helps to prevent issues with multi-threading.

    ```python
    pipe = Pipeline(...)
    with pipe(mode='eval', epoch=2) as loader:
        for batch in loader:
            print(batch)
    ```

    Returns:
        A data loader for the current `mode` and `epoch`.

    Raises:
        ValueError: If called while the pipeline already has an active loader.
        KeyError: If the requested mode or dataset id is not present in the pipeline.
    """
    acquired = self.ctx_lock.acquire(blocking=False)
    if not acquired:
        raise ValueError("You cannot generate a new loader from this Pipeline before closing its other loader.")
    # Release the lock if arguments are invalid so that people in Jupyter / debug consoles don't get stuck
    if self.ctx_mode not in self.data:
        self.ctx_lock.release()
        raise KeyError(f"Pipeline has no data for mode '{self.ctx_mode}'")
    if self.ctx_ds_id not in self.data[self.ctx_mode]:
        self.ctx_lock.release()
        raise KeyError(f"The dataset id '{self.ctx_ds_id}' is not present in {self.ctx_mode} mode")
    data = self.data[self.ctx_mode][self.ctx_ds_id]
    if isinstance(data, Scheduler):
        data = data.get_current_value(self.ctx_epoch)
    if isinstance(data, Dataset):
        # Results will be immediately converted to tensors, so don't need deep_remainder
        op_dataset = OpDataset(data,
                               self.ctx_ops,
                               self.ctx_mode,
                               self.ctx_output_keys | self.ctx_batch_input_keys if self.ctx_output_keys else None,
                               deep_remainder=False)
        # check whether to batch the data
        batch_size = None if op_dataset.fe_batch else self.ctx_batch_size
        # Figure out whether a postprocessing function is needed (for batched ops)
        postprocess_fn = None
        if self.ctx_batch_ops:
            postprocess_fn = functools.partial(_batch_postprocess,
                                               ops=self.ctx_batch_ops,
                                               output_keys=self.ctx_output_keys,
                                               mode=self.ctx_mode)
        try:
            data = FEDataLoader(op_dataset,
                                postprocess_fn=postprocess_fn,
                                batch_size=batch_size,
                                shuffle=self.ctx_shuffle,
                                steps_per_epoch=self.ctx_steps_per_epoch,
                                num_workers=self.num_process,
                                drop_last=self.ctx_batch_info.drop_last,
                                collate_fn=self.ctx_batch_info.collate_fn)
        except ValueError as err:
            # Release the lock before propagating so that the Pipeline is not left permanently locked
            self.ctx_lock.release()
            raise err
        self.ctx_loader = data
    return data
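# A minimal sketch of the single-loader guarantee enforced by `ctx_lock` above (the Pipeline constructor
# arguments are placeholders):
#
#     pipe = Pipeline(...)
#     with pipe(mode='train', epoch=1) as loader:
#         for batch in loader:
#             ...
#         with pipe(mode='eval', epoch=1) as loader2:  # raises ValueError: the first loader is still open
#             ...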