@classmethod
def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_data: Any) -> "CustomDataModule":
    preprocess = DefaultPreprocess()

    return cls.from_data_source(
        "default",
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        predict_data=predict_data,
        preprocess=preprocess,
        batch_size=5,
    )
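# Usage sketch (illustrative, not part of the suite): assuming `CustomDataModule` subclasses
# Flash's `DataModule` and that in-memory sequences are valid inputs for the "default" data
# source, the helper above can be called directly on the class; every split then shares the
# same `DefaultPreprocess` and a batch size of 5.
#
#     datamodule = CustomDataModule.from_inputs(
#         train_data=range(10), val_data=range(10), test_data=range(10), predict_data=range(10)
#     )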
def test_classificationtask_task_predict():
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    task = ClassificationTask(model, preprocess=DefaultPreprocess())
    ds = DummyDataset()
    expected = list(range(10))
    # single item
    x0, _ = ds[0]
    pred0 = task.predict(x0)
    assert pred0[0] in expected
    # list
    x1, _ = ds[1]
    pred1 = task.predict([x0, x1])
    assert all(c in expected for c in pred1)
    assert pred0[0] == pred1[0]
def __init__(
    self,
    data_source: Optional[DataSource] = None,
    preprocess: Optional[Preprocess] = None,
    postprocess: Optional[Postprocess] = None,
    deserializer: Optional[Deserializer] = None,
    serializer: Optional[Serializer] = None,
) -> None:
    self.data_source = data_source
    self._preprocess_pipeline = preprocess or DefaultPreprocess()
    self._postprocess_pipeline = postprocess or Postprocess()
    self._serializer = serializer or Serializer()
    self._deserializer = deserializer or Deserializer()
    self._running_stage = None
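# Minimal sketch of how the defaults above behave (assuming this is `DataPipeline.__init__`,
# consistent with its use in `test_preprocess_transforms` below): any component left as
# ``None`` is replaced by its default counterpart.
#
#     pipeline = DataPipeline(preprocess=DefaultPreprocess())
#     assert isinstance(pipeline._postprocess_pipeline, Postprocess)
#     assert isinstance(pipeline._serializer, Serializer)
#     assert isinstance(pipeline._deserializer, Deserializer)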
def test_multicrop_input_transform():
    batch_size = 8
    total_crops = 6
    num_crops = [2, 4]
    size_crops = [160, 96]
    crop_scales = [[0.4, 1], [0.05, 0.4]]

    multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](total_crops, num_crops, size_crops, crop_scales)
    to_tensor_transform = ApplyToKeys(
        DefaultDataKeys.INPUT,
        multi_crop_transform,
    )
    preprocess = DefaultPreprocess(
        train_transform={
            "to_tensor_transform": to_tensor_transform,
            "collate": vissl_collate_fn,
        }
    )

    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(),
        preprocess=preprocess,
        batch_size=batch_size,
    )

    train_dataloader = datamodule._train_dataloader()
    batch = next(iter(train_dataloader))

    assert len(batch[DefaultDataKeys.INPUT]) == total_crops
    assert batch[DefaultDataKeys.INPUT][0].shape == (batch_size, 3, size_crops[0], size_crops[0])
    assert batch[DefaultDataKeys.INPUT][-1].shape == (batch_size, 3, size_crops[-1], size_crops[-1])
    assert list(batch[DefaultDataKeys.TARGET].shape) == [batch_size]
def test_preprocess_transforms(tmpdir):
    """Check that when a preprocess is given transforms as dictionaries, they are validated
    properly and ``collate_in_worker_from_transform`` is extracted correctly."""
    with pytest.raises(MisconfigurationException, match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        preprocess = DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1),
            }
        )

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
    )
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == preprocess.collate
    assert val_preprocessor.collate_fn.func == preprocess.collate
    assert test_preprocessor.collate_fn.func == preprocess.collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):

        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
    )
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)
    train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING)

    with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        val_preprocessor = data_pipeline.worker_preprocessor(RunningStage.VALIDATING)

    with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
        test_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TESTING)

    predict_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == preprocess.collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity
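# Summary sketch of the rule exercised above: within a single stage's transform dict,
# `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive,
# presumably because a batch collated in the worker cannot also be transformed per-sample
# on the device (note the matching `collate_in_worker_from_transform` flags: True for the
# former, False for the latter). Declaring both for the same stage raises an error:
#
#     DefaultPreprocess(train_transform={
#         "per_batch_transform": torch.nn.Linear(1, 1),
#         "per_sample_transform_on_device": torch.nn.Linear(1, 1),
#     })  # raises MisconfigurationException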
def test_flash_callback(_, __, tmpdir):
    """Test the callback hook system for fit."""
    callback_mock = MagicMock()

    inputs = [[torch.rand(1), torch.rand(1)]]
    dm = DataModule.from_data_source(
        "default",
        inputs,
        inputs,
        inputs,
        None,
        preprocess=DefaultPreprocess(),
        batch_size=1,
        num_workers=0,
    )
    dm.preprocess.callbacks += [callback_mock]

    _ = next(iter(dm.train_dataloader()))

    assert callback_mock.method_calls == [
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
    ]

    class CustomModel(Task):

        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=1,
        progress_bar_refresh_rate=0,
    )
    dm = DataModule.from_data_source(
        "default",
        inputs,
        inputs,
        inputs,
        None,
        preprocess=DefaultPreprocess(),
        batch_size=1,
        num_workers=0,
    )
    dm.preprocess.callbacks += [callback_mock]
    trainer.fit(CustomModel(), datamodule=dm)

    assert callback_mock.method_calls == [
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
        call.on_load_sample(ANY, RunningStage.VALIDATING),
        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_collate(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING),
        call.on_load_sample(ANY, RunningStage.VALIDATING),
        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_collate(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
    ]
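# Hedged reading of the expectations above: iterating the dataloader directly only exercises
# the worker-side hooks, ending at `on_per_batch_transform`; `on_per_batch_transform_on_device`
# only appears once `trainer.fit` is running and batches are moved to the device. The mock is
# reused across both phases, so the second assertion list starts with the six calls already
# recorded by the first iteration.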