def test_albumentations_mixup(single_target_csv):
    """Check that a per-batch mixup transform runs and records ``lam`` in metadata.

    Builds a training transform that randomly flips images horizontally
    (only the ``"input"`` key) and applies mixup to every batch. It merges that
    with the task defaults, loads one training batch, and checks that the
    first sample's metadata carries the mixing coefficient.
    """

    def mixup(batch, alpha=1.0):
        # Mixing coefficient drawn from Beta(alpha, alpha); sampled before the
        # permutation so the RNG call order matches the original test.
        ratio = np.random.beta(alpha, alpha)
        x = batch["input"]
        y = batch["target"].float().unsqueeze(1)
        # Random pairing of samples within the batch.
        order = torch.randperm(x.size(0))

        def blend(t):
            # Convex combination of each sample with its randomly paired partner.
            return t * ratio + t[order] * (1 - ratio)

        batch["input"] = blend(x)
        batch["target"] = blend(y)
        # Record the coefficient on every sample so downstream code can see it.
        for meta in batch["metadata"]:
            meta.update({"lam": ratio})
        return batch

    flip = AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))
    custom = {
        # ApplyToKeys targets only the "input" key, so targets are not flipped.
        "post_tensor_transform": ApplyToKeys("input", flip),
        "per_batch_transform": mixup,
    }
    # Layer the custom transforms on top of the task's default transforms.
    train_transform = merge_transforms(default_transforms((256, 256)), custom)

    datamodule = ImageClassificationData.from_csv(
        "image",
        "target",
        train_file=single_target_csv,
        batch_size=2,
        num_workers=0,
        train_transform=train_transform,
    )

    first_batch = next(iter(datamodule.train_dataloader()))
    assert "lam" in first_batch["metadata"][0]
Example #2
0
 def default_transforms(self) -> Optional[Dict[str, Callable]]:
     """Return the default transforms built for this object's ``image_size``.

     Delegates to the module-level ``default_transforms`` helper.
     """
     size = self.image_size
     return default_transforms(size)