Example #1
    def process_predict_dataset(
        self,
        dataset: InputBase,
        batch_size: int,
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Optional[Sampler] = None,
        persistent_workers: bool = False,
        input_transform: Optional[InputTransform] = None,
        trainer: Optional["flash.Trainer"] = None,
    ) -> DataLoader:
        """Build the prediction ``DataLoader`` through the backend's ``infer_dl`` factory.

        The backend module named by ``self.model_type`` is imported and asked to
        create the loader. Its collate function is wrapped by
        ``self._wrap_collate_fn`` and, when an input transform is available
        (the argument, falling back to ``self.input_transform``), replaced by a
        worker-side transform processor for the predicting stage.
        """
        backend = import_module(self.model_type)
        loader = backend.infer_dl(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
            drop_last=drop_last,
            sampler=sampler,
            persistent_workers=persistent_workers,
        )

        # Wrap whatever collate the backend configured.
        loader.collate_fn = functools.partial(self._wrap_collate_fn, loader.collate_fn)

        transform = input_transform or self.input_transform
        if transform is not None:
            # Hand the wrapped collate to the transform, then take over collation.
            transform.inject_collate_fn(loader.collate_fn)
            loader.collate_fn = create_worker_input_transform_processor(
                RunningStage.PREDICTING, transform)
        return loader
Example #2
    def process_predict_dataset(
        self,
        dataset: InputBase,
        batch_size: int,
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Optional[Sampler] = None,
        persistent_workers: bool = False,
        input_transform: Optional[InputTransform] = None,
        trainer: Optional["flash.Trainer"] = None,
    ) -> DataLoader:
        """Construct the prediction ``DataLoader`` directly from ``dataset``.

        When an input transform is available (the argument, falling back to
        ``self.input_transform``), ``self.collate_fn`` is injected into it and a
        worker-side transform processor collates the batches; otherwise the plain
        ``self.collate_fn`` is used as-is.
        """
        transform = input_transform or self.input_transform

        if transform is None:
            collate = self.collate_fn
        else:
            # Make the model's collate available to the transform, then wrap it.
            transform.inject_collate_fn(self.collate_fn)
            collate = create_worker_input_transform_processor(
                RunningStage.PREDICTING, transform)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            collate_fn=collate,
            persistent_workers=persistent_workers,
        )
Example #3
    def _predict_dataloader(self) -> DataLoader:
        """Create the ``DataLoader`` for the predicting stage.

        For map-style datasets the batch size is clamped to the dataset length
        (minimum 1) so a small prediction set does not request an oversized
        batch; iterable datasets have no length, so ``self.batch_size`` is used
        unchanged. If a connected ``pl.Trainer``'s module provides
        ``process_predict_dataset``, loader construction is delegated to it;
        otherwise a plain ``DataLoader`` is built with a worker-side
        input-transform collate function for ``RunningStage.PREDICTING``.
        """
        predict_ds: Input = self._predict_input

        input_transform = self._resolve_input_transform()

        if isinstance(predict_ds, IterableDataset):
            # No length available to clamp against.
            batch_size = self.batch_size
        else:
            batch_size = min(self.batch_size,
                             len(predict_ds) if len(predict_ds) > 0 else 1)

        if isinstance(getattr(self, "trainer", None), pl.Trainer) and hasattr(
                self.trainer.lightning_module, "process_predict_dataset"):
            dataloader = self.trainer.lightning_module.process_predict_dataset(
                predict_ds,
                # Fix: pass the clamped batch size (was `self.batch_size`),
                # matching the plain-DataLoader branch below.
                batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                persistent_workers=self.persistent_workers,
                input_transform=input_transform,
                trainer=self.trainer,
            )
        else:
            dataloader = DataLoader(
                predict_ds,
                batch_size=batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                collate_fn=create_worker_input_transform_processor(
                    RunningStage.PREDICTING, input_transform),
                persistent_workers=self.persistent_workers,
            )

        # NOTE(review): clearing appears to invalidate cached
        # on-after-batch-transfer hooks so they are re-resolved — confirm.
        self._on_after_batch_transfer_fns = None
        return dataloader
Example #4
    def _test_dataloader(self) -> DataLoader:
        """Create the ``DataLoader`` for the testing stage.

        Construction is delegated to the connected ``pl.Trainer``'s module when
        it exposes ``process_test_dataset``; otherwise a plain ``DataLoader`` is
        built with a worker-side input-transform collate function for
        ``RunningStage.TESTING``.
        """
        test_ds: Input = self._test_input
        input_transform = self._resolve_input_transform()

        trainer = getattr(self, "trainer", None)
        can_delegate = isinstance(trainer, pl.Trainer) and hasattr(
            trainer.lightning_module, "process_test_dataset")

        if can_delegate:
            dataloader = trainer.lightning_module.process_test_dataset(
                test_ds,
                self.batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                persistent_workers=self.persistent_workers,
                input_transform=input_transform,
                trainer=trainer,
            )
        else:
            collate = create_worker_input_transform_processor(
                RunningStage.TESTING, input_transform)
            dataloader = DataLoader(
                test_ds,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                collate_fn=collate,
                persistent_workers=self.persistent_workers,
            )

        # NOTE(review): clearing appears to invalidate cached
        # on-after-batch-transfer hooks — confirm.
        self._on_after_batch_transfer_fns = None
        return dataloader
Example #5
 def predict_dataloader(self) -> "DataLoader":
     """Return a prediction ``DataLoader`` over the unlabelled pool.

     NOTE(review): the collate processor is built for ``RunningStage.TRAINING``
     even though this is a predict loader — presumably so pool samples receive
     the training-time transforms; confirm against the transform contract.
     """
     pool = self.filter_unlabelled_data(self._dataset.pool)
     self.labelled._predict_input = pool
     loader = self.labelled._predict_dataloader()
     loader.collate_fn = create_worker_input_transform_processor(
         RunningStage.TRAINING, self.labelled.input_transform)
     return loader
Example #6
 def _val_dataloader(self) -> "DataLoader":
     """Return a validation ``DataLoader`` over the held-out split.

     NOTE(review): the collate processor uses ``RunningStage.TRAINING`` rather
     than a validating stage — confirm this is intentional.
     """
     split = train_val_split(self._dataset, self.val_split)
     self.labelled._val_input = split[1]
     loader = self.labelled._val_dataloader()
     loader.collate_fn = create_worker_input_transform_processor(
         RunningStage.TRAINING, self.labelled.input_transform)
     return loader
Example #7
    def _train_dataloader(self) -> DataLoader:
        """Create the ``DataLoader`` for the training stage.

        Shuffling is enabled only when no sampler is configured and the dataset
        is map-style; a callable ``self.sampler`` is invoked with the dataset to
        build the sampler instance. Construction is delegated to the connected
        ``pl.Trainer``'s module when it exposes ``process_train_dataset``;
        otherwise a plain ``DataLoader`` is built with a worker-side
        input-transform collate function for ``RunningStage.TRAINING``.
        """
        train_ds: Input = self._train_input
        input_transform = self._resolve_input_transform()

        # Drop the trailing partial batch only for sized datasets that span
        # more than one batch; iterable datasets never drop.
        drop_last = (not isinstance(train_ds, IterableDataset)
                     and len(train_ds) > self.batch_size)

        sampler = self.sampler
        if sampler is None:
            shuffle = not isinstance(train_ds, IterableDataset)
        else:
            shuffle = False
            if callable(sampler):
                sampler = sampler(train_ds)

        trainer = getattr(self, "trainer", None)
        if isinstance(trainer, pl.Trainer) and hasattr(
                trainer.lightning_module, "process_train_dataset"):
            dataloader = trainer.lightning_module.process_train_dataset(
                train_ds,
                self.batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                shuffle=shuffle,
                drop_last=drop_last,
                sampler=sampler,
                persistent_workers=self.persistent_workers,
                input_transform=input_transform,
                trainer=trainer,
            )
        else:
            collate = create_worker_input_transform_processor(
                RunningStage.TRAINING, input_transform)
            dataloader = DataLoader(
                train_ds,
                batch_size=self.batch_size,
                shuffle=shuffle,
                sampler=sampler,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                drop_last=drop_last,
                collate_fn=collate,
                persistent_workers=self.persistent_workers,
            )

        # NOTE(review): clearing appears to invalidate cached
        # on-after-batch-transfer hooks — confirm.
        self._on_after_batch_transfer_fns = None
        return dataloader