def test_albumentations_mixup(single_target_csv):
    """Check that a per-batch mixup transform runs and records ``lam`` in the batch metadata."""

    def mixup(batch, alpha=1.0):
        # Blend every sample (and its target) with a randomly permuted partner from the same batch,
        # using a mixing weight drawn from Beta(alpha, alpha).
        inputs = batch["input"]
        labels = batch["target"].float().unsqueeze(1)
        lam = np.random.beta(alpha, alpha)
        shuffled = torch.randperm(inputs.size(0))
        batch["input"] = inputs * lam + inputs[shuffled] * (1 - lam)
        batch["target"] = labels * lam + labels[shuffled] * (1 - lam)
        # Record the mixing weight on each sample's metadata entry.
        for meta in batch["metadata"]:
            meta.update(lam=lam)
        return batch

    # The flip is wrapped in ApplyToKeys("input", ...), so it only touches the images.
    extra_transforms = {
        "post_tensor_transform": ApplyToKeys("input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))),
        "per_batch_transform": mixup,
    }
    # Layer the extra transforms on top of this task's default transforms.
    train_transform = merge_transforms(default_transforms((256, 256)), extra_transforms)

    datamodule = ImageClassificationData.from_csv(
        "image",
        "target",
        train_file=single_target_csv,
        batch_size=2,
        num_workers=0,
        train_transform=train_transform,
    )

    first_batch = next(iter(datamodule.train_dataloader()))
    assert "lam" in first_batch["metadata"][0]
def test_merge_transforms(base_transforms, additional_transforms, expected_result):
    """Check that ``merge_transforms`` produces the expected hook keys and per-hook transforms.

    Mock transforms must match the expected value directly; ``nn.Sequential`` results must match
    the expected ``nn.Sequential`` in length and in the ``func`` of each contained module.
    """
    merged = merge_transforms(base_transforms, additional_transforms)
    assert merged.keys() == expected_result.keys()

    for hook_name, actual in merged.items():
        # Safe to index: the key sets were asserted equal above.
        expected = expected_result[hook_name]

        if actual == _MOCK_TRANSFORM:
            assert expected == _MOCK_TRANSFORM
            continue

        if not isinstance(actual, nn.Sequential):
            continue

        assert isinstance(expected, nn.Sequential)
        assert len(actual) == len(expected)
        for got, want in zip(actual, expected):
            assert got.func == want.func
def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """During training, apply the default transforms plus an additional ``RandomHorizontalFlip``.

    The flip (Kornia ``RandomHorizontalFlip`` with ``p=0.5``) is added as a ``post_tensor_transform``
    and is applied to both ``DefaultDataKeys.INPUT`` and ``DefaultDataKeys.TARGET`` through
    ``KorniaParallelTransforms``. Presumably this makes the input and target receive the same random
    flip; confirm against ``KorniaParallelTransforms``.

    Args:
        image_size: Forwarded to ``default_transforms``.

    Returns:
        The result of merging ``default_transforms(image_size)`` with the flip transform.
    """
    return merge_transforms(
        default_transforms(image_size),
        {
            "post_tensor_transform": nn.Sequential(
                ApplyToKeys(
                    [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
                    KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)),
                ),
            ),
        },
    )
def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """Build the training transforms: the defaults for ``image_size`` plus a random horizontal flip of the input.

    Kornia's flip is used as a ``post_tensor_transform`` when Kornia is available and the
    ``FLASH_TESTING`` environment variable is not ``"1"``. Otherwise ``T.RandomHorizontalFlip`` is used
    as a ``pre_tensor_transform``.
    """
    use_kornia = _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1"
    if not use_kornia:
        flip = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.RandomHorizontalFlip())}
    else:
        # Kornia transforms operate on tensors directly, so the flip can run after tensor conversion.
        flip = {
            "post_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, K.augmentation.RandomHorizontalFlip()),
        }
    return merge_transforms(default_transforms(image_size), flip)
def train_default_transforms(
    spectrogram_size: Tuple[int, int], time_mask_param: Optional[int], freq_mask_param: Optional[int]
) -> Dict[str, Callable]:
    """During training, apply the default transforms with optional ``TimeMasking`` and ``FrequencyMasking``.

    Args:
        spectrogram_size: Forwarded to ``default_transforms``.
        time_mask_param: If not ``None``, ``TAudio.TimeMasking(time_mask_param=time_mask_param)`` is
            applied to ``DefaultDataKeys.INPUT``.
        freq_mask_param: If not ``None``, ``TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)`` is
            applied to ``DefaultDataKeys.INPUT``.

    Returns:
        ``default_transforms(spectrogram_size)`` unchanged when no masking is enabled. Otherwise those
        defaults merged with a ``post_tensor_transform`` ``nn.Sequential`` of the enabled maskings, with
        time masking first and frequency masking second.
    """
    augs = []
    if time_mask_param is not None:
        augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)))
    if freq_mask_param is not None:
        augs.append(
            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
        )

    base = default_transforms(spectrogram_size)
    if not augs:
        # No masking requested: the defaults are used as-is.
        return base
    return merge_transforms(base, {"post_tensor_transform": nn.Sequential(*augs)})
def test_transforms(self, tmpdir):
    """Check that merging a ``OneHotDegree`` pre-tensor transform into every stage widens the node features.

    The same ``TUDataset`` ("KKI") is used for every stage. The train, val and test batches must have
    ``2 * num_features`` node-feature columns and 2 targets.
    """
    tudataset = TUDataset(root=tmpdir, name="KKI")

    def one_hot_degree_transform():
        # Build a fresh merged transform per stage, exactly as one inline call per stage would.
        return merge_transforms(
            GraphClassificationPreprocess.default_transforms(),
            {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)},
        )

    datamodule = GraphClassificationData.from_datasets(
        train_dataset=tudataset,
        val_dataset=tudataset,
        test_dataset=tudataset,
        predict_dataset=tudataset,
        train_transform=one_hot_degree_transform(),
        val_transform=one_hot_degree_transform(),
        test_transform=one_hot_degree_transform(),
        predict_transform=one_hot_degree_transform(),
        batch_size=2,
    )
    assert datamodule is not None

    loader_factories = (
        datamodule.train_dataloader,
        datamodule.val_dataloader,
        datamodule.test_dataloader,
    )
    for make_loader in loader_factories:
        assert make_loader() is not None

    # Inspect the first batch of the train, val and test data, in that order.
    for make_loader in loader_factories:
        batch = next(iter(make_loader()))
        assert list(batch.x.size())[1] == tudataset.num_features * 2
        assert list(batch.y.size()) == [2]