Example #1
def kornia_list(MAGN: int = 4):
    """
    Returns standard list of kornia transforms, each with magnitude `MAGN`.
    
    Args:
        MAGN (int): Magnitude of each transform in the returned list.
    """
    transform_list = [
        # SPATIAL
        K.RandomHorizontalFlip(p=1),
        K.RandomRotation(degrees=90., p=1),
        K.RandomAffine(degrees=MAGN * 5.,
                       shear=MAGN / 5,
                       translate=MAGN / 20,
                       p=1),
        K.RandomPerspective(distortion_scale=MAGN / 25, p=1),

        # PIXEL-LEVEL
        K.ColorJitter(brightness=MAGN / 30, p=1),  # brightness
        K.ColorJitter(saturation=MAGN / 30, p=1),  # saturation
        K.ColorJitter(contrast=MAGN / 30, p=1),  # contrast
        K.ColorJitter(hue=MAGN / 30, p=1),  # hue
        K.ColorJitter(p=0),  # identity
        K.RandomMotionBlur(kernel_size=2 * (MAGN // 3) + 1,
                           angle=MAGN,
                           direction=1.,
                           p=1),
        K.RandomErasing(scale=(MAGN / 100, MAGN / 50),
                        ratio=(MAGN / 20, MAGN),
                        p=1),
    ]
    return transform_list
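A minimal usage sketch (not part of the original example, assuming kornia.augmentation is imported as K): sample a few transforms from the returned list and compose them into a single batched pipeline.

import torch
import torch.nn as nn

transforms = kornia_list(MAGN=4)
# pick three transforms at random and chain them into one module
idxs = torch.randperm(len(transforms))[:3]
pipeline = nn.Sequential(*[transforms[i] for i in idxs])

batch = torch.rand(8, 3, 224, 224)  # B x C x H x W, values in [0, 1]
augmented = pipeline(batch)         # same shape as the input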
Example #2
def generate_kornia_transforms(image_size=224, resize=256, mean=[], std=[], include_jitter=False):
    mean = torch.tensor(mean) if mean else torch.tensor([0.5, 0.5, 0.5])
    std = torch.tensor(std) if std else torch.tensor([0.1, 0.1, 0.1])
    if torch.cuda.is_available():
        mean = mean.cuda()
        std = std.cuda()
    train_transforms = [G.Resize((resize, resize))]
    if include_jitter:
        train_transforms.append(K.ColorJitter(brightness=0.4, contrast=0.4,
                                              saturation=0.4, hue=0.1))
    train_transforms.extend([K.RandomHorizontalFlip(p=0.5),
                             K.RandomVerticalFlip(p=0.5),
                             K.RandomRotation(90),
                             K.RandomResizedCrop((image_size, image_size)),
                             K.Normalize(mean, std)])
    val_transforms = [G.Resize((resize, resize)),
                      K.CenterCrop((image_size, image_size)),
                      K.Normalize(mean, std)]
    transforms = dict(train=nn.Sequential(*train_transforms),
                      val=nn.Sequential(*val_transforms))
    if torch.cuda.is_available():
        for k in transforms:
            transforms[k] = transforms[k].cuda()
    return transforms
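A hypothetical call, assuming the function above is in scope: build the dict once, then push batches through the 'train' or 'val' pipeline (on GPU when available).

transforms = generate_kornia_transforms(image_size=224, resize=256)
images = torch.rand(16, 3, 300, 300)
if torch.cuda.is_available():
    images = images.cuda()
train_batch = transforms['train'](images)  # 16 x 3 x 224 x 224
val_batch = transforms['val'](images)      # 16 x 3 x 224 x 224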
Example #3
    def __init__(self, opt):
        super(PostTensorTransform, self).__init__()
        self.random_crop = ProbTransform(
            A.RandomCrop((opt.input_height, opt.input_width), padding=opt.random_crop),
            p=0.8)
        self.random_rotation = ProbTransform(A.RandomRotation(opt.random_rotation),
                                             p=0.5)
        if opt.dataset == "cifar10":
            self.random_horizontal_flip = A.RandomHorizontalFlip(p=0.5)
Example #4
    def __init__(self, viz: bool = False):
        super().__init__()
        self.viz = viz
        # self.geometric = [
        #     K.augmentation.RandomAffine(60., p=0.75),
        # ]
        self.augmentations = nn.Sequential(
            augmentation.RandomRotation(degrees=30.),
            augmentation.RandomPerspective(distortion_scale=0.4),
            augmentation.RandomResizedCrop((224, 224)),
            augmentation.RandomHorizontalFlip(p=0.5),
            augmentation.RandomVerticalFlip(p=0.5),
            # K.augmentation.GaussianBlur((3, 3), (0.1, 2.0), p=1.0),
            # K.augmentation.ColorJitter(0.01, 0.01, 0.01, 0.01, p=0.25),
        )
        self.denorm = augmentation.Denormalize(Tensor(DATASET_IMAGE_MEAN),
                                               Tensor(DATASET_IMAGE_STD))
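The source elides the rest of this module; a plausible forward pass (an assumption, not from the original) would run the batch through the pipeline and denormalize only when viz is set:

    def forward(self, x: Tensor) -> Tensor:
        # apply the batched kornia pipeline
        out = self.augmentations(x)
        if self.viz:
            # map back to image range for visual inspection
            out = self.denorm(out)
        return out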
Example #5
    g_ema = Generator(
        args.size,
        args.latent_size,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier,
        constant_input=args.constant_input,
    ).to(device)
    g_ema.requires_grad_(False)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    augment_fn = nn.Sequential(
        nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)),  # zoom out
        augs.RandomHorizontalFlip(),
        RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2),
        RandomApply(augs.RandomRotation(180), p=0.2),
        augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)),
        RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1),  # zoom in
        RandomApply(augs.RandomErasing(), p=0.1),
    )
    contrast_learner = (
        ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0))
        if args.contrastive > 0
        else None
    )

    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = th.optim.Adam(
        generator.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
Example #6
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, x, y, transform=None):
        self.data = x
        self.targets = y
        self.transform = transform
    def __len__(self):
        return len(self.targets)
    def __getitem__(self, idx):
        img = self.data[idx]
        label = self.targets[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# %% DataLoaders
from torch.utils.data import DataLoader
from torchvision import transforms
import kornia.augmentation as k

tfms = transforms.Compose([transforms.RandomApply([k.RandomRotation(45.)], p=0.7),
                           transforms.Normalize((0.5,), (0.5,))])

# create a master dataset, then split it into train and validation sets
dsm = MyDataset(img, label, transform=tfms)  # master dataset
ds, dv = torch.utils.data.random_split(dsm, [60000, 10240])

dl = DataLoader(ds, batch_size=100, shuffle=True)  # training dataloader
dlv = DataLoader(dv, batch_size=100)               # validation dataloader

sampler = torch.utils.data.RandomSampler(ds, replacement=True, num_samples=500)
dls = DataLoader(ds, batch_size=100, sampler=sampler)

# %% Mixed1D2D Network
class Net(nn.Module):
    def __init__(self, out, kernel):
Example #7
def main(
    opt_alg,
    opt_args,
    clip_args,
    max_epochs,
    batch_size,
    latent_dim,
    _seed,
    _run,
    eval_every,
):
    # pyro.enable_validation(True)

    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)

    transforms = T.TransformSequence(T.Rotation())

    trs = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32))

    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim)
    decoder.cuda()

    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
    }

    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)

    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)

    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    for batch in train_dl:
        x = batch[0]
        x_orig = x.cuda()
        break

    for i in range(10000):
        encoder.train()
        decoder.train()
        x = augmentation.RandomRotation(180.0)(x_orig)
        l = svi.step(x, **svi_args)

        if i % 200 == 0:
            encoder.eval()
            decoder.eval()

            print("EPOCH", i, "LOSS", l)
            ex.log_scalar("train.loss", l, i)
            tb.add_scalar("train/loss", l, i)
            tb.add_image(f"train/originals", torchvision.utils.make_grid(x), i)
            bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
            fwd_trace = poutine.trace(
                poutine.replay(forward_model,
                               trace=bwd_trace)).get_trace(x, **svi_args)
            recon = fwd_trace.nodes["pixels"]["fn"].mean
            tb.add_image(f"train/recons", torchvision.utils.make_grid(recon),
                         i)

            canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(
                f"train/canonical_recon",
                torchvision.utils.make_grid(canonical_recon),
                i,
            )

            # sample from the prior

            prior_sample_args = {}
            prior_sample_args.update(svi_args)
            prior_sample_args["cond"] = False
            prior_sample_args["cond_label"] = False
            fwd_trace = poutine.trace(forward_model).get_trace(
                x, **prior_sample_args)
            prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
            prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(f"train/prior_samples",
                         torchvision.utils.make_grid(prior_sample), i)

            tb.add_image(
                f"train/canonical_prior_samples",
                torchvision.utils.make_grid(prior_canonical_sample),
                i,
            )
            tb.add_image(
                f"train/input_view",
                torchvision.utils.make_grid(
                    bwd_trace.nodes["attention_input"]["value"]),
                i,
            )
Example #8
    def __init__(self, degree: float = 5.0):
        self._degree = degree
        self._operation = aug.RandomRotation(degrees=degree)
Example #9
class TestVideoSequential:
    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]),
                        p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]),
                          p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation,
                                     data_format=data_format,
                                     same_on_frame=True)
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0),
             kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)

        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        torch.manual_seed(0)
        if data_format == 'BCTHW':
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            output_2 = output_2.transpose(1, 2)
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
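For reference, a standalone sketch of the API under test (assuming kornia.augmentation is imported as K): VideoSequential samples parameters once and applies them to every frame when same_on_frame=True.

import torch
import kornia.augmentation as K

aug = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                        data_format="BCTHW",
                        same_on_frame=True)
clip = torch.rand(2, 3, 4, 5, 6)  # B x C x T x H x W
out = aug(clip)                   # same shape; one jitter shared across the 4 frames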
Example #10
train_dataset = l2l.vision.datasets.MiniImagenet(root="data",
                                                 mode="train",
                                                 download=True)
val_dataset = l2l.vision.datasets.MiniImagenet(root="data",
                                               mode="validation",
                                               download=True)

transform = {
    "per_sample_transform":
    nn.Sequential(
        ApplyToKeys(
            DataKeys.INPUT,
            nn.Sequential(
                torchvision.transforms.ToTensor(),
                Kg.Resize((196, 196)),
                # SPATIAL
                Ka.RandomHorizontalFlip(p=0.25),
                Ka.RandomRotation(degrees=90.0, p=0.25),
                Ka.RandomAffine(degrees=1 * 5.0,
                                shear=1 / 5,
                                translate=1 / 20,
                                p=0.25),
                Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
                # PIXEL-LEVEL
                Ka.ColorJitter(brightness=1 / 30, p=0.25),  # brightness
                Ka.ColorJitter(saturation=1 / 30, p=0.25),  # saturation
                Ka.ColorJitter(contrast=1 / 30, p=0.25),  # contrast
                Ka.ColorJitter(hue=1 / 30, p=0.25),  # hue
                Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1,
                                    angle=1,
                                    direction=1.0,
                                    p=0.25),
                Ka.RandomErasing(scale=(1 / 100, 1 / 50),
Example #11
    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=[],
        fq_dict_size=256,
        attn_layers=[],
    ):
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent,
                           attn_layers=attn_layers)
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent,
                            attn_layers=attn_layers)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(
            self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 [self.S, self.G, self.D, self.SE, self.GE],
                 [self.G_opt, self.D_opt],
                 opt_level="O2")

        # experimental contrastive loss discriminator regularization
        if augment_fn is not None:
            self.augment_fn = augment_fn
        else:
            self.augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(augs.RandomAffine(degrees=0,
                                              translate=(0.25, 0.25),
                                              shear=(15, 15)),
                            p=0.3),
                RandomApply(nn.Sequential(
                    augs.RandomRotation(180),
                    augs.CenterCrop(size=(image_size, image_size))),
                            p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )

        self.D_cl = (ContrastiveLearner(self.D,
                                        image_size,
                                        augment_fn=self.augment_fn,
                                        fp16=fp16,
                                        hidden_layer="flatten")
                     if cl_reg else None)
Example #12
def get_gpu_transforms(augs: DictConfig, mode: str = '2d') -> dict:
    """Makes GPU augmentations from the augs section of a configuration.

    Parameters
    ----------
    augs : DictConfig
        Augmentation parameters
    mode : str, optional
        If '2d', stacks the clip into channels. If '3d', returns a 5-D tensor. Default: '2d'.

    Returns
    -------
    xform : dict
        keys: ['train', 'val', 'test']. Values: an nn.Sequential of Kornia augmentations.
        Example: auged_images = xform['train'](images)
    """
    # input is a tensor of shape N x C x F x H x W
    kornia_transforms = []
    
    if augs.LR > 0:
        kornia_transforms.append(K.RandomHorizontalFlip(p=augs.LR,
                                                        same_on_batch=False,
                                                        return_transform=False))
    if augs.UD > 0:
        kornia_transforms.append(K.RandomVerticalFlip(p=augs.UD,
                                                      same_on_batch=False,
                                                      return_transform=False))
    if augs.degrees > 0:
        kornia_transforms.append(K.RandomRotation(augs.degrees))

    if augs.brightness > 0 or augs.contrast > 0 or augs.saturation > 0 or augs.hue > 0:
        kornia_transforms.append(K.ColorJitter(brightness=augs.brightness,
                                              contrast=augs.contrast, 
                                              saturation=augs.saturation, 
                                              hue=augs.hue, 
                                              p=augs.color_p, 
                                              same_on_batch=False, 
                                              return_transform=False))
    if augs.grayscale > 0:
        kornia_transforms.append(K.RandomGrayscale(p=augs.grayscale))
    
    
    norm = NormalizeVideo(mean=augs.normalization.mean,
                          std=augs.normalization.std)
    # kornia_transforms.append(norm)
    
    kornia_transforms = VideoSequential(*kornia_transforms, 
                                        data_format='BCTHW', 
                                        same_on_frame=True)    
    
    train_transforms = [ToFloat(), 
                        kornia_transforms, 
                        norm]
    val_transforms = [ToFloat(), 
                      norm]

    denormalize = []
    if mode == '2d':
        train_transforms.append(StackClipInChannels())
        val_transforms.append(StackClipInChannels())
        denormalize.append(UnstackClip())
    denormalize.append(DenormalizeVideo(mean=augs.normalization.mean,
                                        std=augs.normalization.std))

    train_transforms = nn.Sequential(*train_transforms)
    val_transforms = nn.Sequential(*val_transforms)
    denormalize = nn.Sequential(*denormalize)

    gpu_transforms = dict(train=train_transforms,
                          val=val_transforms,
                          test=val_transforms,
                          denormalize=denormalize)
    log.info('GPU transforms: {}'.format(gpu_transforms))
    return gpu_transforms
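A hypothetical invocation, assuming an OmegaConf config carrying the fields the function reads (field names taken from the code above, values illustrative):

from omegaconf import OmegaConf

augs = OmegaConf.create({
    'LR': 0.5, 'UD': 0.0, 'degrees': 10,
    'brightness': 0.1, 'contrast': 0.1, 'saturation': 0.1, 'hue': 0.1,
    'color_p': 0.5, 'grayscale': 0.0,
    'normalization': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]},
})
xform = get_gpu_transforms(augs, mode='2d')
clips = torch.rand(4, 3, 11, 64, 64)  # N x C x F x H x W
auged = xform['train'](clips)         # frames stacked into channels in '2d' mode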
Example #13
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        project_hidden=True,
        project_dim=128,
        augment_both=True,
        use_nt_xent_loss=False,
        augment_fn=None,
        use_bilinear=False,
        use_momentum=False,
        momentum_value=0.999,
        key_encoder=None,
        temperature=0.1,
        batch_size=128,
    ):
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer)

        DEFAULT_AUG = nn.Sequential(
            # RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            # augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            augs.RandomVerticalFlip(),
            augs.RandomSolarize(),
            augs.RandomPosterize(),
            augs.RandomSharpness(),
            augs.RandomEqualize(),
            augs.RandomRotation(degrees=8.0),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size), p=0.1),
        )
        self.b = batch_size
        self.h = image_size
        self.w = image_size
        self.augment = default(augment_fn, DEFAULT_AUG)

        self.augment_both = augment_both

        self.temperature = temperature
        self.use_nt_xent_loss = use_nt_xent_loss

        self.project_hidden = project_hidden
        self.projection = None
        self.project_dim = project_dim

        self.use_bilinear = use_bilinear
        self.bilinear_w = None

        self.use_momentum = use_momentum
        self.ema_updater = EMA(momentum_value)
        self.key_encoder = key_encoder

        # for accumulating queries and keys across calls
        self.queries = None
        self.keys = None
        random_data = (
            (
                torch.randn(1, 3, image_size, image_size),
                torch.randn(1, 3, image_size, image_size),
                torch.randn(1, 3, image_size, image_size),
            ),
            torch.tensor([1]),
        )
        # send a mock image tensor to instantiate parameters
        self.forward(random_data)