def process_predict_dataset(
    self,
    dataset: InputBase,
    batch_size: int,
    num_workers: int = 0,
    pin_memory: bool = False,
    shuffle: bool = False,
    drop_last: bool = False,
    sampler: Optional[Sampler] = None,
    persistent_workers: bool = False,
    input_transform: Optional[InputTransform] = None,
    trainer: Optional["flash.Trainer"] = None,
) -> DataLoader:
    """Build the prediction dataloader through the backend's ``infer_dl`` factory.

    The collate function produced by ``infer_dl`` is wrapped with
    ``self._wrap_collate_fn``; when an input transform is available (passed in
    or taken from ``self.input_transform``), the wrapped collate function is
    injected into it and the loader's collate function is replaced by a
    worker-side input-transform processor for the PREDICTING stage.
    """
    backend = import_module(self.model_type)
    loader = backend.infer_dl(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
        drop_last=drop_last,
        sampler=sampler,
        persistent_workers=persistent_workers,
    )
    # Wrap the backend-provided collate function so samples pass through
    # `self._wrap_collate_fn` first.
    loader.collate_fn = functools.partial(self._wrap_collate_fn, loader.collate_fn)

    transform = input_transform or self.input_transform
    if transform is not None:
        transform.inject_collate_fn(loader.collate_fn)
        loader.collate_fn = create_worker_input_transform_processor(RunningStage.PREDICTING, transform)
    return loader
def process_predict_dataset(
    self,
    dataset: InputBase,
    batch_size: int,
    num_workers: int = 0,
    pin_memory: bool = False,
    shuffle: bool = False,
    drop_last: bool = False,
    sampler: Optional[Sampler] = None,
    persistent_workers: bool = False,
    input_transform: Optional[InputTransform] = None,
    trainer: Optional["flash.Trainer"] = None,
) -> DataLoader:
    """Create a prediction ``DataLoader``, wiring in the input transform.

    When an input transform is available (passed in or taken from
    ``self.input_transform``), ``self.collate_fn`` is injected into it and the
    collate function becomes a worker-side input-transform processor for the
    PREDICTING stage; otherwise ``self.collate_fn`` is used as-is.
    """
    transform = input_transform or self.input_transform
    if transform is None:
        collate = self.collate_fn
    else:
        # Inject the `self.collate_fn`
        transform.inject_collate_fn(self.collate_fn)
        collate = create_worker_input_transform_processor(RunningStage.PREDICTING, transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
        drop_last=drop_last,
        sampler=sampler,
        collate_fn=collate,
        persistent_workers=persistent_workers,
    )
def _predict_dataloader(self) -> DataLoader:
    """Create the dataloader for the predict stage.

    The batch size is clamped to the dataset length for map-style datasets so
    that small datasets still yield at least one batch. If an attached
    ``pl.Trainer``'s lightning module exposes ``process_predict_dataset``, the
    dataloader is delegated to it; otherwise a plain ``DataLoader`` with a
    worker-side input-transform processor is built.
    """
    predict_ds: Input = self._predict_input
    input_transform = self._resolve_input_transform()
    if isinstance(predict_ds, IterableDataset):
        # Iterable datasets have no length; use the configured batch size.
        batch_size = self.batch_size
    else:
        # Clamp so a dataset smaller than one batch still produces a batch.
        batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)
    if isinstance(getattr(self, "trainer", None), pl.Trainer) and hasattr(
        self.trainer.lightning_module, "process_predict_dataset"):
        dataloader = self.trainer.lightning_module.process_predict_dataset(
            predict_ds,
            # Bug fix: previously passed `self.batch_size`, silently discarding
            # the clamped value computed above for map-style datasets.
            batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            input_transform=input_transform,
            trainer=self.trainer,
        )
    else:
        dataloader = DataLoader(
            predict_ds,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=create_worker_input_transform_processor(
                RunningStage.PREDICTING, input_transform),
            persistent_workers=self.persistent_workers,
        )
    # Reset cached per-batch transfer hooks; they are resolved lazily later.
    self._on_after_batch_transfer_fns = None
    return dataloader
def _test_dataloader(self) -> DataLoader:
    """Create the dataloader for the test stage.

    Delegates to the lightning module's ``process_test_dataset`` when an
    attached ``pl.Trainer`` exposes one; otherwise builds a plain
    ``DataLoader`` whose collate function is a worker-side input-transform
    processor for the TESTING stage.
    """
    test_ds: Input = self._test_input
    input_transform = self._resolve_input_transform()

    trainer = getattr(self, "trainer", None)
    delegate = isinstance(trainer, pl.Trainer) and hasattr(
        self.trainer.lightning_module, "process_test_dataset")

    if delegate:
        dataloader = self.trainer.lightning_module.process_test_dataset(
            test_ds,
            self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            input_transform=input_transform,
            trainer=self.trainer,
        )
    else:
        collate = create_worker_input_transform_processor(RunningStage.TESTING, input_transform)
        dataloader = DataLoader(
            test_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate,
            persistent_workers=self.persistent_workers,
        )

    # Reset cached per-batch transfer hooks; they are resolved lazily later.
    self._on_after_batch_transfer_fns = None
    return dataloader
def predict_dataloader(self) -> "DataLoader":
    """Return a dataloader over the unlabelled pool.

    The pool from the underlying dataset is filtered through
    ``filter_unlabelled_data`` and served via the labelled datamodule's
    predict dataloader, with its collate function replaced.
    """
    self.labelled._predict_input = self.filter_unlabelled_data(self._dataset.pool)
    dataloader = self.labelled._predict_dataloader()
    # NOTE(review): the collate processor is built with RunningStage.TRAINING
    # rather than PREDICTING — presumably so the pool receives the training
    # transforms; confirm this is intended.
    dataloader.collate_fn = create_worker_input_transform_processor(
        RunningStage.TRAINING, self.labelled.input_transform
    )
    return dataloader
def _val_dataloader(self) -> "DataLoader":
    """Return the validation dataloader for the active-learning loop.

    The validation split is taken as the second element of
    ``train_val_split`` applied to the underlying dataset, then served via the
    labelled datamodule's validation dataloader with its collate function
    replaced.
    """
    self.labelled._val_input = train_val_split(self._dataset, self.val_split)[1]
    dataloader = self.labelled._val_dataloader()
    # NOTE(review): the collate processor uses RunningStage.TRAINING rather
    # than VALIDATING — consistent with `predict_dataloader` above, so it
    # appears deliberate (training transforms applied); confirm.
    dataloader.collate_fn = create_worker_input_transform_processor(
        RunningStage.TRAINING, self.labelled.input_transform
    )
    return dataloader
def _train_dataloader(self) -> DataLoader:
    """Create the dataloader for the train stage.

    Resolves shuffle/drop_last/sampler settings from the dataset kind and
    ``self.sampler``, then either delegates to the lightning module's
    ``process_train_dataset`` (when an attached ``pl.Trainer`` exposes one) or
    builds a plain ``DataLoader`` with a worker-side input-transform processor
    for the TRAINING stage.
    """
    train_ds: Input = self._train_input
    input_transform = self._resolve_input_transform()

    is_iterable = isinstance(train_ds, IterableDataset)
    # Drop the last incomplete batch only for map-style datasets that are
    # larger than a single batch.
    drop_last = (not is_iterable) and len(train_ds) > self.batch_size

    if self.sampler is None:
        sampler = None
        # Shuffle only map-style datasets; iterable datasets cannot be shuffled.
        shuffle = not is_iterable
    else:
        # A callable sampler is treated as a factory over the dataset.
        sampler = self.sampler(train_ds) if callable(self.sampler) else self.sampler
        shuffle = False

    if isinstance(getattr(self, "trainer", None), pl.Trainer) and hasattr(
        self.trainer.lightning_module, "process_train_dataset"):
        dataloader = self.trainer.lightning_module.process_train_dataset(
            train_ds,
            self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            drop_last=drop_last,
            sampler=sampler,
            persistent_workers=self.persistent_workers,
            input_transform=input_transform,
            trainer=self.trainer,
        )
    else:
        dataloader = DataLoader(
            train_ds,
            batch_size=self.batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=drop_last,
            collate_fn=create_worker_input_transform_processor(
                RunningStage.TRAINING, input_transform),
            persistent_workers=self.persistent_workers,
        )

    # Reset cached per-batch transfer hooks; they are resolved lazily later.
    self._on_after_batch_transfer_fns = None
    return dataloader