Example #1
0
def test_albumentations_mixup(single_target_csv):
    """Check that a custom per-batch ``mixup`` transform runs in the train dataloader.

    An Albumentations horizontal flip (restricted to the ``input`` key) and a
    batch-level mixup callable are merged into the default 256x256 transforms.
    The test passes when the first training batch carries the ``lam`` entry
    that ``mixup`` writes into each metadata record.
    """

    def mixup(batch, alpha=1.0):
        """Blend every sample with a randomly chosen partner from the same batch."""
        inputs = batch["input"]
        labels = batch["target"].float().unsqueeze(1)

        # One mixing coefficient for the whole batch, drawn from Beta(alpha, alpha).
        lam = np.random.beta(alpha, alpha)
        partner = torch.randperm(inputs.size(0))
        complement = 1 - lam

        batch["input"] = inputs * lam + inputs[partner] * complement
        batch["target"] = labels * lam + labels[partner] * complement
        # Record the coefficient so that the caller can inspect it afterwards.
        for record in batch["metadata"]:
            record.update({"lam": lam})
        return batch

    # ApplyToKeys is given `input`, so the flip is applied only to the images.
    flip = AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))
    custom_transform = {
        "post_tensor_transform": ApplyToKeys("input", flip),
        "per_batch_transform": mixup,
    }
    # Combine the task's default transform with the custom one.
    train_transform = merge_transforms(default_transforms((256, 256)), custom_transform)

    datamodule = ImageClassificationData.from_csv(
        "image",
        "target",
        train_file=single_target_csv,
        batch_size=2,
        num_workers=0,
        train_transform=train_transform,
    )

    first_batch = next(iter(datamodule.train_dataloader()))
    assert "lam" in first_batch["metadata"][0]
def test_merge_transforms(base_transforms, additional_transforms, expected_result):
    """Compare the output of ``merge_transforms`` with ``expected_result``.

    Both mappings must expose the same keys. A value equal to ``_MOCK_TRANSFORM``
    must match the mock on the expected side; an ``nn.Sequential`` value must
    match an expected ``nn.Sequential`` of the same length whose members have
    pairwise-equal ``func`` attributes. Values of any other kind are not
    inspected.
    """
    merged = merge_transforms(base_transforms, additional_transforms)
    assert merged.keys() == expected_result.keys()

    for hook, actual in merged.items():
        expected = expected_result[hook]

        if actual == _MOCK_TRANSFORM:
            assert expected == _MOCK_TRANSFORM
            continue

        if not isinstance(actual, nn.Sequential):
            continue

        assert isinstance(expected, nn.Sequential)
        assert len(actual) == len(expected)
        for module, expected_module in zip(actual, expected):
            assert module.func == expected_module.func
Example #3
0
def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """During training, we apply the default transforms with an additional ``RandomHorizontalFlip``.

    The flip (``p=0.5``) is wrapped in ``KorniaParallelTransforms`` and registered
    under ``post_tensor_transform`` for both the ``INPUT`` and ``TARGET`` keys.

    Args:
        image_size: Size forwarded unchanged to ``default_transforms``.

    Returns:
        The result of merging the default transforms with the training-only flip.
    """
    base_transforms = default_transforms(image_size)

    # NOTE(review): presumably the parallel wrapper keeps input and target
    # flipped consistently -- confirm against KorniaParallelTransforms.
    random_flip = KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5))
    flip_on_keys = ApplyToKeys([DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], random_flip)

    training_extras = {"post_tensor_transform": nn.Sequential(flip_on_keys)}
    return merge_transforms(base_transforms, training_extras)
Example #4
0
def train_default_transforms(
        image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """During training, we apply the default transforms with an additional ``RandomHorizontalFlip``.

    When ``_KORNIA_AVAILABLE`` is truthy and the ``FLASH_TESTING`` environment
    variable is not ``"1"``, ``K.augmentation.RandomHorizontalFlip`` is
    registered under ``post_tensor_transform``. Otherwise
    ``T.RandomHorizontalFlip`` is registered under ``pre_tensor_transform``.
    In both cases the flip is applied to the ``INPUT`` key only.

    Args:
        image_size: Size forwarded unchanged to ``default_transforms``.

    Returns:
        The result of merging the default transforms with the selected flip.
    """
    use_kornia = _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1"

    if not use_kornia:
        hook, flip = "pre_tensor_transform", T.RandomHorizontalFlip()
    else:
        # Better approach, as all transforms are applied on the tensor directly.
        hook, flip = "post_tensor_transform", K.augmentation.RandomHorizontalFlip()

    training_extras = {hook: ApplyToKeys(DefaultDataKeys.INPUT, flip)}
    return merge_transforms(default_transforms(image_size), training_extras)
Example #5
0
def train_default_transforms(
        spectrogram_size: Tuple[int, int], time_mask_param: Optional[int],
        freq_mask_param: Optional[int]) -> Dict[str, Callable]:
    """During training we apply the default transforms with optional ``TimeMasking`` and ``FrequencyMasking``.

    Args:
        spectrogram_size: Size forwarded unchanged to ``default_transforms``.
        time_mask_param: Passed to ``TAudio.TimeMasking``; ``None`` skips time masking.
        freq_mask_param: Passed to ``TAudio.FrequencyMasking``; ``None`` skips frequency masking.

    Returns:
        The default transforms, merged with a ``post_tensor_transform`` sequence
        of the requested maskings (time first, then frequency) when at least
        one was requested. Each masking is applied to the ``INPUT`` key only.
    """
    maskings = []

    if time_mask_param is not None:
        time_masking = TAudio.TimeMasking(time_mask_param=time_mask_param)
        maskings.append(ApplyToKeys(DefaultDataKeys.INPUT, time_masking))

    if freq_mask_param is not None:
        freq_masking = TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)
        maskings.append(ApplyToKeys(DefaultDataKeys.INPUT, freq_masking))

    # Nothing requested: the defaults are returned untouched.
    if not maskings:
        return default_transforms(spectrogram_size)

    training_extras = {"post_tensor_transform": nn.Sequential(*maskings)}
    return merge_transforms(default_transforms(spectrogram_size), training_extras)
Example #6
0
    def test_transforms(self, tmpdir):
        """Build a graph data module with per-stage transforms and check batch shapes.

        The same ``KKI`` ``TUDataset`` backs every stage. Each stage gets its own
        merge of the default transforms with a ``OneHotDegree`` pre-tensor
        transform. The train, val and test batches are then expected to have
        ``2 * num_features`` node-feature columns and two targets. The predict
        dataloader is configured but not inspected here.
        """
        tudataset = TUDataset(root=tmpdir, name="KKI")

        def stage_transform():
            # A separate merged mapping (and OneHotDegree instance) is built on
            # every call, so no transform object is shared between stages.
            return merge_transforms(
                GraphClassificationPreprocess.default_transforms(),
                {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)},
            )

        # instantiate the data module
        dm = GraphClassificationData.from_datasets(
            train_dataset=tudataset,
            val_dataset=tudataset,
            test_dataset=tudataset,
            predict_dataset=tudataset,
            train_transform=stage_transform(),
            val_transform=stage_transform(),
            test_transform=stage_transform(),
            predict_transform=stage_transform(),
            batch_size=2,
        )
        assert dm is not None
        assert dm.train_dataloader() is not None
        assert dm.val_dataloader() is not None
        assert dm.test_dataloader() is not None

        # NOTE(review): the doubled width presumably comes from OneHotDegree(max_degree)
        # appending max_degree + 1 columns -- confirm against the torch_geometric docs.
        expected_width = tudataset.num_features * 2

        # check training, val and test data in turn
        for make_loader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
            data = next(iter(make_loader()))
            assert data.x.size(1) == expected_width
            assert list(data.y.size()) == [2]