def __init__(self, net, image_size, hidden_layer=-2, projection_size=256, projection_hidden_size=4096, augment_fn=None, moving_average_decay=0.99):
        """Build the online encoder, predictor, EMA updater and augmentation.

        Args:
            net: backbone handed to ``NetWrapper``.
            image_size: side length used by the random resized crop and by the
                mock batch pushed through ``forward`` at the end.
            hidden_layer: forwarded to ``NetWrapper`` as ``layer``.
            projection_size: forwarded to ``NetWrapper`` and used for the
                first two arguments of the predictor MLP.
            projection_hidden_size: forwarded to ``NetWrapper`` and the
                predictor MLP.
            augment_fn: optional augmentation; the pipeline built below is
                given to ``default`` as the fallback value.
            moving_average_decay: value passed to ``EMA``.
        """
        super().__init__()

        # default SimCLR augmentation, assembled stage by stage
        channel_mean = torch.tensor([0.485, 0.456, 0.406])
        channel_std = torch.tensor([0.229, 0.224, 0.225])
        simclr_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
            color.Normalize(mean=channel_mean, std=channel_std),
        ]
        fallback_augment = nn.Sequential(*simclr_stages)

        self.augment = default(augment_fn, fallback_augment)

        # The target encoder starts unset; only the online branch is built here.
        self.online_encoder = NetWrapper(
            net, projection_size, projection_hidden_size, layer=hidden_layer)
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MultiLayerPerceptron(
            projection_size, projection_size, projection_hidden_size)

        # send a mock image tensor to instantiate singleton parameters
        mock_batch = torch.randn(2, 3, image_size, image_size)
        self.forward(mock_batch)
Exemplo n.º 2
0
    def __init__(self,
                s_color=0.1, 
                p_color=0.8, 
                p_flip=0.5,
                p_gray=0.2, 
                p_blur=0.5, 
                kernel_min=0.1, 
                kernel_max=2.) -> None:
        """Assemble a flip / color-jitter / grayscale / blur pipeline.

        Args:
            s_color: strength multiplier for the color jitter; brightness,
                contrast and saturation use ``0.8 * s_color`` and hue uses
                ``0.2 * s_color``.
            p_color: probability passed to the color jitter as ``p``.
            p_flip: probability passed to the horizontal flip.
            p_gray: probability passed to the grayscale conversion.
            p_blur: probability passed to the Gaussian blur.
            kernel_min: lower bound of the blur ``sigma`` range.
            kernel_max: upper bound of the blur ``sigma`` range; also used to
                derive the blur kernel size.
        """
        super(KorniaAugmentationPipeline, self).__init__()

        T_hflip = K.RandomHorizontalFlip(p=p_flip)
        T_gray = K.RandomGrayscale(p=p_gray)
        # p_color is a probability, so it must be given as the ``p`` keyword.
        # Passing it positionally (as before) put it in the brightness slot and
        # shifted every strength argument one position to the right.
        T_color = K.ColorJitter(0.8*s_color, 0.8*s_color, 0.8*s_color, 0.2*s_color, p=p_color)

        radius = kernel_max*2  # https://dsp.stackexchange.com/questions/10057/gaussian-blur-standard-deviation-radius-and-kernel-size
        kernel_size = int(radius*2 + 1)
        # The kernel size needs to be odd; truncation alone does not guarantee
        # that (e.g. kernel_max=1.8 gives 8), so round even sizes up by one.
        if kernel_size % 2 == 0:
            kernel_size += 1
        kernel_size = (kernel_size, kernel_size)
        T_blur = K.GaussianBlur(kernel_size=kernel_size, sigma=(kernel_min, kernel_max), p=p_blur)
        #T_blur = KorniaRandomGaussianBlur(kernel_size=kernel_size, sigma=(kernel_min, kernel_max), p=p_blur)

        # Stages run in this order: flip, color jitter, grayscale, blur.
        self.transform = nn.Sequential(
            T_hflip,
            T_color,
            T_gray,
            T_blur
        )
Exemplo n.º 3
0
    def __init__(self, net, image_size=32,
                 layer_name_list=None,
                 projection_size=256,
                 projection_hidden_size=4096,
                 augment_fn=None,
                 moving_average_decay=0.99,
                 device_='cuda',
                 number_of_classes=10,
                 mean_data=None,
                 std_data=None):
        """Build the online encoder, three predictors and the augmentation.

        Args:
            net: backbone handed to ``NetWrapper``.
            image_size: side length for the random resized crop and for the
                mock batch pushed through ``forward``.
            layer_name_list: forwarded to ``NetWrapper``; ``None`` (default)
                means ``[-2]``.
            projection_size: forwarded to ``NetWrapper`` and the predictors.
            projection_hidden_size: forwarded to ``NetWrapper`` and the first
                predictor.
            augment_fn: optional augmentation; the pipeline built below is
                given to ``default`` as the fallback value.
            moving_average_decay: value passed to ``EMA``.
            device_: device the encoder, predictors and mock batch are moved to.
            number_of_classes: accepted but not used in this constructor.
            mean_data: per-channel mean for ``Normalize``; ``None`` (default)
                means ``[0.485, 0.456, 0.406]``.
            std_data: per-channel std for ``Normalize``; ``None`` (default)
                means ``[0.229, 0.224, 0.225]``.
        """
        super().__init__()

        # Defaults are created per call: a list or tensor in the signature
        # would be one shared object reused (and possibly mutated) by every
        # instance.
        if layer_name_list is None:
            layer_name_list = [-2]
        if mean_data is None:
            mean_data = torch.tensor([0.485, 0.456, 0.406])
        if std_data is None:
            std_data = torch.tensor([0.229, 0.224, 0.225])

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
            augs.Normalize(mean=mean_data, std=std_data)
        )

        self.augment = default(augment_fn, DEFAULT_AUG)
        self.device = device_

        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer_name_list=layer_name_list).to(self.device)
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        # One full-width predictor plus two with a fixed hidden size of 512.
        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size).to(self.device)
        self.online_predictor1 = MLP(projection_size, projection_size, 512).to(self.device)
        self.online_predictor2 = MLP(projection_size, projection_size, 512).to(self.device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size).to(self.device))
Exemplo n.º 4
0
    def __init__(self, opt):
        """Wrap the configured dataset and build its augmentation pipeline.

        Reads ``'dataset'``, ``'crop_size'`` and ``'normalize'`` directly from
        ``opt``; ``'key1'``, ``'key2'`` and ``'for_sr'`` are fetched through
        ``opt_get`` with the default arguments ``'hq'``, ``'lq'`` and ``False``.
        """
        super().__init__()
        self.wrapped_dataset = create_dataset(opt['dataset'])
        self.cropped_img_size = opt['crop_size']
        self.key1 = opt_get(opt, ['key1'], 'hq')
        self.key2 = opt_get(opt, ['key2'], 'lq')
        # When set, color alterations and blurs are disabled.
        for_sr = opt_get(opt, ['for_sr'], False)

        # Geometric stages are always present.
        crop_hw = (self.cropped_img_size, self.cropped_img_size)
        stages = [
            augs.RandomHorizontalFlip(),
            augs.RandomResizedCrop(crop_hw),
        ]

        # Photometric stages are skipped for super-resolution use.
        if not for_sr:
            stages += [
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
                augs.RandomGrayscale(p=0.2),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            ]

        if opt['normalize']:
            # The paper calls for normalization. Most datasets/models in this repo don't use this.
            # Recommend setting true if you want to train exactly like the paper.
            channel_mean = torch.tensor([0.485, 0.456, 0.406])
            channel_std = torch.tensor([0.229, 0.224, 0.225])
            stages.append(augs.Normalize(mean=channel_mean, std=channel_std))

        self.aug = nn.Sequential(*stages)
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        project_hidden=True,
        project_dim=128,
        augment_both=True,
        use_nt_xent_loss=False,
        augment_fn=None,
        use_bilinear=False,
        use_momentum=False,
        momentum_value=0.999,
        key_encoder=None,
        temperature=0.1,
        fp16=False,
    ):
        """Store the configuration and run one mock forward pass.

        Args:
            net: backbone wrapped by ``OutputHiddenLayer``.
            image_size: side length for the random resized crop and for the
                mock input.
            hidden_layer: forwarded to ``OutputHiddenLayer`` as ``layer``.
            project_hidden, project_dim, augment_both, use_nt_xent_loss,
            use_bilinear, use_momentum, key_encoder, temperature, fp16:
                stored on the instance unchanged.
            augment_fn: optional augmentation; the pipeline built below is
                given to ``default`` as the fallback value.
            momentum_value: value passed to ``EMA``.
        """
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer)

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
        )

        self.augment = default(augment_fn, DEFAULT_AUG)

        self.augment_both = augment_both

        self.temperature = temperature
        self.use_nt_xent_loss = use_nt_xent_loss

        self.project_hidden = project_hidden
        self.projection = None
        self.project_dim = project_dim

        self.use_bilinear = use_bilinear
        self.bilinear_w = None

        self.use_momentum = use_momentum
        self.ema_updater = EMA(momentum_value)
        self.key_encoder = key_encoder

        # for accumulating queries and keys across calls
        self.queries = None
        self.keys = None

        self.fp16 = fp16

        # send a mock image tensor to instantiate parameters
        # The tensor is created on the device of net's first parameter rather
        # than an unconditional "cuda", so a CPU-resident net can be wrapped
        # too. A net without parameters keeps the previous CUDA default.
        # NOTE(review): assumes net exposes nn.Module.parameters() - confirm.
        first_param = next(net.parameters(), None)
        if first_param is None:
            init_device = torch.device("cuda")
        else:
            init_device = first_param.device
        init = torch.randn(1, 3, image_size, image_size, device=init_device)
        if self.fp16:
            init = init.half()
        self.forward(init)
Exemplo n.º 6
0
 def __init__(self, opt):
     """Wrap the configured dataset and build the augmentation and crop modules.

     Reads ``'dataset'`` and ``'latent_multiple'`` directly from ``opt``;
     ``'jitter_range'`` is fetched through ``opt_get`` with default argument 0.
     """
     super().__init__()
     self.wrapped_dataset = create_dataset(opt['dataset'])

     # Photometric stages only: color jitter, grayscale, blur.
     self.aug = nn.Sequential(
         RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
         augs.RandomGrayscale(p=0.2),
         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
     )

     latent_multiple = opt['latent_multiple']
     jitter_range = opt_get(opt, ['jitter_range'], 0)
     self.rrc = RandomSharedRegionCrop(latent_multiple, jitter_range)
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_size = 256,
        projection_hidden_size = 2048,
        augment_fn = None,
        augment_fn2 = None,
        moving_average_decay = 0.99,
        ppm_num_layers = 1,
        ppm_gamma = 2,
        distance_thres = 0.1, # the paper uses 0.7, but that leads to nearly all positive hits. need clarification on how the coordinates are normalized before distance calculation.
        similarity_temperature = 0.3,
        alpha = 1.
    ):
        """Build the online encoder, pixel propagation module and augmentations.

        ``augment_fn`` / ``augment_fn2`` are optional augmentations. The first
        is resolved against the SimCLR-style pipeline below via ``default``,
        the second against the first. ``distance_thres``,
        ``similarity_temperature`` and ``alpha`` are stored unchanged. The
        module is moved to the device reported by ``get_module_device(net)``
        before a mock batch is pushed through ``forward``.
        """
        super().__init__()

        # default SimCLR augmentation
        channel_mean = torch.tensor([0.485, 0.456, 0.406])
        channel_std = torch.tensor([0.229, 0.224, 0.225])
        simclr_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
            augs.Normalize(mean=channel_mean, std=channel_std),
        ]
        fallback_augment = nn.Sequential(*simclr_stages)

        self.augment1 = default(augment_fn, fallback_augment)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
            net, projection_size, projection_hidden_size, layer=hidden_layer)

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature
        self.alpha = alpha

        self.propagate_pixels = PPM(chan=projection_size,
                                    num_layers=ppm_num_layers,
                                    gamma=ppm_gamma)

        # get device of network and make wrapper same device
        net_device = get_module_device(net)
        self.to(net_device)

        # send a mock image tensor to instantiate singleton parameters
        mock_batch = torch.randn(2, 3, image_size, image_size, device=net_device)
        self.forward(mock_batch)
Exemplo n.º 8
0
def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Return the default augmentation pipeline as an ``nn.Sequential``.

    Stages, in order: resize to ``image_size``, randomly applied color
    jitter, random grayscale, random horizontal flip, randomly applied
    Gaussian blur, random resized crop to ``image_size``, and normalization.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406])
    channel_std = torch.tensor([0.229, 0.224, 0.225])

    stages = [
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size),
        aug.Normalize(mean=channel_mean, std=channel_std),
    ]
    return nn.Sequential(*stages)
Exemplo n.º 9
0
 def __init__(self, opt):
     """Wrap the configured dataset and build two augmentation modules.

     ``self.aug`` holds the photometric stages (color jitter, grayscale,
     blur); ``self.rrc`` holds the geometric ones (horizontal flip and a
     random resized crop to ``opt['crop_size']``).
     """
     super().__init__()
     self.wrapped_dataset = create_dataset(opt['dataset'])
     self.cropped_img_size = opt['crop_size']
     self.includes_labels = opt['includes_labels']

     self.aug = nn.Sequential(
         RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
         augs.RandomGrayscale(p=0.2),
         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
     )

     crop_hw = (self.cropped_img_size, self.cropped_img_size)
     self.rrc = nn.Sequential(
         augs.RandomHorizontalFlip(),
         augs.RandomResizedCrop(crop_hw),
     )
Exemplo n.º 10
0
 def __init__(self, model, imageSize, embeddingLayer=-2, projectionDim=256, projectionHiddenDim=4096, emaDecay=0.99):
     """Build the augmentation, online/target encoders, predictor and EMA.

     The target encoder is a deep copy of the freshly built online encoder.
     """
     super(BYOL, self).__init__()

     # Default SimCLR augmentations, assembled stage by stage
     channelMean = torch.tensor([0.485, 0.456, 0.406])
     channelStd = torch.tensor([0.229, 0.224, 0.225])
     augmentStages = [
         RandomApply(augmentation.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
         augmentation.RandomGrayscale(p=0.2),
         augmentation.RandomHorizontalFlip(),
         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
         augmentation.RandomResizedCrop((imageSize, imageSize)),
         color.Normalize(mean=channelMean, std=channelStd),
     ]
     self.augment = nn.Sequential(*augmentStages)

     # Initialize models, predictors and EMA
     self.onlineEncoder = ModelWrapper(
         model, projectionDim, projectionHiddenDim, embeddingLayer)
     self.onlinePredictor = MLP(
         projectionDim, projectionDim, projectionHiddenDim)
     self.targetEncoder = copy.deepcopy(self.onlineEncoder)
     self.targetEMA = EMA(emaDecay)
Exemplo n.º 11
0
    def __init__(self,
                 encoder,
                 predictor,
                 image_size,
                 hidden_layer=-2,
                 projection_size=256,
                 projection_hidden_size=4096,
                 augment_fn=None,
                 augment_fn2=None,
                 moving_average_decay=0.99,
                 use_momentum=True,
                 device=None):
        """Wire a pre-built encoder and predictor into the learner.

        Args:
            encoder: module used as the online encoder.
            predictor: module used as the online predictor.
            image_size: side length of the mock batch pushed through
                ``forward``.
            hidden_layer, projection_size, projection_hidden_size: accepted
                but not used in this constructor (the ``NetWrapper`` call that
                consumed them is commented out).
            augment_fn: optional first augmentation; the pipeline built below
                is given to ``default`` as the fallback value.
            augment_fn2: optional second augmentation; falls back to the
                first one.
            moving_average_decay: value passed to ``EMA``.
            use_momentum: stored on the instance unchanged.
            device: anything ``torch.device`` accepts. ``None`` (default)
                keeps the previous hard-coded choice, device ordinal 2.
        """
        super().__init__()

        # default SimCLR augmentation

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            # augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            # augs.RandomResizedCrop((image_size, image_size)),
            # augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)

        # self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
        self.online_encoder = encoder

        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = predictor

        # Pick the device for the wrapper. The ordinal used to be fixed at 2,
        # which only works on hosts with at least three devices; callers can
        # now override it, while omitting the argument keeps the old value.
        if device is None:
            device = torch.device(2)
        else:
            device = torch.device(device)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
Exemplo n.º 12
0
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        projection_size=256,
        projection_hidden_size=4096,
        moving_average_decay=0.99,
        use_momentum=True,
        structural_mlp=False,
    ):
        """Build the online encoder, augmentation, EMA updater and predictor.

        The module is moved to the device reported by
        ``get_module_device(net)``, then two mock batches are pushed through
        ``forward``.
        """
        super().__init__()

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer=hidden_layer,
            use_structural_mlp=structural_mlp,
        )

        # Fixed pipeline; this constructor takes no augmentation override.
        self.aug = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
        )
        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(
            projection_size, projection_size, projection_hidden_size)

        # get device of network and make wrapper same device
        net_device = get_module_device(net)
        self.to(net_device)

        # send a mock image tensor to instantiate singleton parameters
        mock_a = torch.randn(2, 3, image_size, image_size, device=net_device)
        mock_b = torch.randn(2, 3, image_size, image_size, device=net_device)
        self.forward(mock_a, mock_b)
Exemplo n.º 13
0
def get_frame_transform(frame_transform_str, img_size, cuda=True):
    """Return a list of frame transforms selected by ``frame_transform_str``.

    Each of the markers ``'gray'``, ``'crop'``, ``'cj'`` and ``'flip'`` is
    looked up with ``in``. Without ``'crop'`` a plain resize to
    ``(img_size, img_size)`` is used instead. ``cuda`` is accepted but not
    used in this function.
    """
    def requested(marker):
        return marker in frame_transform_str

    pipeline = []

    if requested('gray'):
        pipeline.append(K.RandomGrayscale(p=1))

    if requested('crop'):
        pipeline.append(
            K.RandomResizedCrop(img_size, scale=(0.8, 0.95), ratio=(0.7, 1.3)))
    else:
        pipeline.append(kornia.geometry.transform.Resize((img_size, img_size)))

    if requested('cj'):
        jitter = 0.1
        pipeline.append(K.ColorJitter(jitter, jitter, jitter, jitter))

    if requested('flip'):
        pipeline.append(K.RandomHorizontalFlip())

    return pipeline
Exemplo n.º 14
0
def get_frame_transform(args, cuda=True):
    """Build the frame transform pipeline and return a callable applying it.

    Reads ``args.img_size``, ``args.frame_transforms``, ``args.npatch``,
    ``args.frame_aug`` and ``args.visualize``.

    Args:
        args: configuration object providing the attributes listed above.
        cuda: when true, the returned callable moves its input to CUDA first.

    Returns:
        ``with_orig``, a function mapping ``x`` to a 2-tuple: the transformed
        frames and the resize+normalize version of ``x[0:1]``, both moved to
        the CPU.
    """
    imsz = args.img_size
    norm_size = kornia.geometry.transform.Resize((imsz, imsz))
    norm_imgs = kornia.color.Normalize(mean=IMG_MEAN, std=IMG_STD)

    tt = []
    fts = args.frame_transforms  #.split(',')

    if 'gray' in fts:
        tt.append(K.RandomGrayscale(p=1))

    # Either a random resized crop or a plain resize; never both.
    if 'crop' in fts:
        tt.append(
            K.RandomResizedCrop(imsz, scale=(0.8, 0.95), ratio=(0.7, 1.3)))
    else:
        tt.append(norm_size)

    # 'cj2' is tested first; if fts is a plain string, 'cj' would also match
    # inside 'cj2', so the elif keeps the two variants exclusive.
    if 'cj2' in fts:
        _cj = 0.2
        tt += [
            K.RandomGrayscale(p=0.2),
            K.ColorJitter(_cj, _cj, _cj, _cj),
        ]
    elif 'cj' in fts:
        _cj = 0.1
        tt += [
            # K.RandomGrayscale(p=0.2),
            K.ColorJitter(_cj, _cj, _cj, 0),
        ]

    if 'flip' in fts:
        tt += [K.RandomHorizontalFlip()]

    # The last stage is either the frame augmentation or plain normalization.
    if args.npatch > 1 and args.frame_aug != '':
        tt += [get_frame_aug(args)]
    else:
        tt += [norm_imgs]

    print('Frame transforms:', tt, args.frame_transforms)

    # frame_transform_train = MapTransform(transforms.Compose(tt))
    frame_transform_train = transforms.Compose(tt)
    plain = nn.Sequential(norm_size, norm_imgs)

    def with_orig(x):
        """Apply the pipeline to ``x`` and also return a plain first item."""
        if cuda:
            x = x.cuda()
        # Non-negative input with values above 1 is rescaled by subtracting
        # its minimum and dividing by the resulting maximum.
        # NOTE(review): the in-place ops may write through to the caller's
        # tensor when neither .cuda() nor .float() made a copy - confirm.
        if x.max() > 1 and x.min() >= 0:
            x = x.float()
            x -= x.min()
            x /= x.max()
        # A trailing axis of size 3 is moved to position 1.
        if x.shape[-1] == 3:
            x = x.permute(0, 3, 1, 2)

        # NOTE(review): "... or True" makes patchify always True, so the
        # plain(x) alternative in the conditional below is never taken.
        patchify = (not args.visualize) or True

        x = (frame_transform_train(x) if patchify else plain(x)).cpu(), \
                plain(x[0:1]).cpu()

        return x

    return with_orig
Exemplo n.º 15
0
class TestVideoSequential:
    """Tests for ``K.VideoSequential``.

    Apart from ``test_jit``, every test is parametrized over the two data
    layouts ``"BCTHW"`` and ``"BTCHW"``. ``device`` and ``dtype`` are
    supplied from outside this class (presumably pytest fixtures).
    """
    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        """Inputs with 2, 3, 4 or 6 dimensions must raise AssertionError."""
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]),
                        p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]),
                          p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        """Run each single augmentation through ``reproducibility_test``."""
        # One random clip scaled by 1/255, duplicated along the batch axis.
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation,
                                     data_format=data_format,
                                     same_on_frame=True)
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0),
             kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        """With ``same_on_frame=True`` identical frames must stay identical."""
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        # The input is a single frame repeated four times along the time
        # axis, so every output frame of a clip is expected to be equal.
        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            # When labels are returned, only the image part is compared.
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        """Compare ``same_on_frame=False`` output with ``torch.nn.Sequential``.

        Both pipelines share the same augmentation objects and are run under
        the same seed; the reference one receives the frames flattened into
        the batch axis.
        """
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)

        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        # Re-seed so the reference pipeline starts from the same RNG state.
        torch.manual_seed(0)
        # Swap axes 1 and 2 so frames can be folded into the batch axis.
        if data_format == 'BCTHW':
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        # Restore the original axis order before comparing.
        if data_format == 'BCTHW':
            output_2 = output_2.transpose(1, 2)
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        """Scripted and eager outputs should match (currently skipped)."""
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
def default(val, def_val):
    """Return ``val`` unless it is ``None``, in which case return ``def_val``."""
    if val is None:
        return def_val
    return val

# augmentation utils

class RandomApply(nn.Module):
    """Wrap ``fn`` so that each call applies it only some of the time.

    On every forward pass one value is drawn with ``random.random()``; when
    the draw is greater than ``p`` the input is returned untouched,
    otherwise ``fn(x)`` is returned.
    """

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        # One draw per call decides whether the wrapped callable is skipped.
        skip = random.random() > self.p
        return x if skip else self.fn(x)


# Module-level default SimCLR augmentation for square crops of `image_size`.
image_size = 256
DEFAULT_AUG = nn.Sequential(*[
    RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
    augs.RandomGrayscale(p=0.2),
    augs.RandomHorizontalFlip(),
    RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
    augs.RandomResizedCrop((image_size, image_size)),
    # Normalization stage is left commented out:
    # color.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
])


# Script entry point: create the running-average meter.
if __name__ == '__main__':
    meter = AverageMeter()
Exemplo n.º 17
0
    def __init__(self,
                 net,
                 image_size,
                 hidden_layer_pixel=-2,
                 hidden_layer_instance=-2,
                 projection_size=256,
                 projection_hidden_size=2048,
                 augment_fn=None,
                 augment_fn2=None,
                 prob_rand_hflip=0.25,
                 moving_average_decay=0.99,
                 ppm_num_layers=1,
                 ppm_gamma=2,
                 distance_thres=0.7,
                 similarity_temperature=0.3,
                 alpha=1.,
                 use_pixpro=True,
                 cutout_ratio_range=(0.6, 0.8),
                 cutout_interpolate_mode='nearest',
                 coord_cutout_interpolate_mode='bilinear'):
        """Build the online encoder, optional pixel propagation and predictor.

        ``augment_fn`` / ``augment_fn2`` are optional augmentations. The first
        is resolved against the pipeline below via ``default``, the second
        against the first. The pixel propagation module is only created when
        ``use_pixpro`` is true. The remaining scalar and string arguments are
        stored on the instance unchanged. The module is moved to the device
        reported by ``get_module_device(net)`` before a mock batch is pushed
        through ``forward``.
        """
        super().__init__()

        # Photometric-only pipeline: no flip or crop stage is included here.
        channel_mean = torch.tensor([0.485, 0.456, 0.406])
        channel_std = torch.tensor([0.229, 0.224, 0.225])
        fallback_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomSolarize(p=0.5),
            augs.Normalize(mean=channel_mean, std=channel_std),
        ]
        fallback_augment = nn.Sequential(*fallback_stages)

        self.augment1 = default(augment_fn, fallback_augment)
        self.augment2 = default(augment_fn2, self.augment1)
        self.prob_rand_hflip = prob_rand_hflip

        encoder_kwargs = dict(
            net=net,
            projection_size=projection_size,
            projection_hidden_size=projection_hidden_size,
            layer_pixel=hidden_layer_pixel,
            layer_instance=hidden_layer_instance,
        )
        self.online_encoder = NetWrapper(**encoder_kwargs)

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature
        self.alpha = alpha

        self.use_pixpro = use_pixpro

        if use_pixpro:
            self.propagate_pixels = PPM(
                chan=projection_size,
                num_layers=ppm_num_layers,
                gamma=ppm_gamma,
            )

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # instance level predictor
        self.online_predictor = MLP(
            projection_size, projection_size, projection_hidden_size)

        # get device of network and make wrapper same device
        net_device = get_module_device(net)
        self.to(net_device)

        # send a mock image tensor to instantiate singleton parameters
        mock_batch = torch.randn(2, 3, image_size, image_size, device=net_device)
        self.forward(mock_batch)
Exemplo n.º 18
0
    def __init__(
        self,
        net,
        image_size,
        hidden_layer_pixel=-2,
        hidden_layer_instance=-2,
        instance_projection_size=256,
        instance_projection_hidden_size=2048,
        pix_projection_size=256,
        pix_projection_hidden_size=2048,
        augment_fn=None,
        augment_fn2=None,
        prob_rand_hflip=0.25,
        moving_average_decay=0.99,
        ppm_num_layers=1,
        ppm_gamma=2,
        distance_thres=0.7,
        similarity_temperature=0.3,
        cutout_ratio_range=(0.6, 0.8),
        cutout_interpolate_mode='nearest',
        coord_cutout_interpolate_mode='bilinear',
        max_latent_dim=None  # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root.
    ):
        """Build the online encoder, pixel propagation module and predictor.

        Args:
            net: backbone handed to ``NetWrapper``; its device (via
                ``get_module_device``) becomes the device of this module.
            image_size: side length of the mock batch pushed through
                ``forward``.
            hidden_layer_pixel, hidden_layer_instance: forwarded to
                ``NetWrapper`` as ``layer_pixel`` / ``layer_instance``.
            instance_projection_size, instance_projection_hidden_size:
                forwarded to ``NetWrapper`` and the instance-level predictor.
            pix_projection_size, pix_projection_hidden_size: forwarded to
                ``NetWrapper``; ``pix_projection_size`` is also the ``chan``
                of the pixel propagation module.
            augment_fn: optional first augmentation; the pipeline built below
                is given to ``default`` as the fallback value.
            augment_fn2: optional second augmentation; falls back to the
                first one.
            prob_rand_hflip, distance_thres, similarity_temperature,
            cutout_ratio_range, cutout_interpolate_mode,
            coord_cutout_interpolate_mode: stored on the instance unchanged.
            moving_average_decay: value passed to ``EMA``.
            ppm_num_layers, ppm_gamma: forwarded to ``PPM``.
            max_latent_dim: optional; when given it must have an integer
                square root.

        Raises:
            AssertionError: if ``max_latent_dim`` is given and is not a
                perfect square.
        """
        super().__init__()

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomSolarize(p=0.5),
            # Normalize left out because it should be done at the model level.
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)
        self.prob_rand_hflip = prob_rand_hflip

        self.online_encoder = NetWrapper(
            net=net,
            instance_projection_size=instance_projection_size,
            instance_projection_hidden_size=instance_projection_hidden_size,
            pix_projection_size=pix_projection_size,
            pix_projection_hidden_size=pix_projection_hidden_size,
            layer_pixel=hidden_layer_pixel,
            layer_instance=hidden_layer_instance)

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature

        # This requirement is due to the way that these are processed, not a hard requirement.
        # None (the default) means "not set", so only a real value is checked;
        # the unconditional math.sqrt(None) used to raise TypeError.
        if max_latent_dim is not None:
            assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim)), \
                'max_latent_dim must have an integer square root'
        self.max_latent_dim = max_latent_dim

        self.propagate_pixels = PPM(chan=pix_projection_size,
                                    num_layers=ppm_num_layers,
                                    gamma=ppm_gamma)

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # instance level predictor
        self.online_predictor = MLP(instance_projection_size,
                                    instance_projection_size,
                                    instance_projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
Exemplo n.º 19
0
def get_gpu_transforms(augs: DictConfig, mode: str = '2d') -> dict:
    """Makes GPU augmentations from the augs section of a configuration.

    Parameters
    ----------
    augs : DictConfig
        Augmentation parameters. The fields read here are ``LR``, ``UD``,
        ``degrees``, ``brightness``, ``contrast``, ``saturation``, ``hue``,
        ``color_p``, ``grayscale``, ``normalization.mean`` and
        ``normalization.std``. A random augmentation is only added when its
        controlling value is greater than zero.
    mode : str, optional
        If '2d', stacks clip in channels. If 3d, returns 5-D tensor, by default '2d'.
        Any value other than '2d' skips the stacking step.

    Returns
    -------
    xform : dict
        keys: ['train', 'val', 'test', 'denormalize']. Values: a nn.Sequential.
        'train' applies the random Kornia augmentations followed by
        normalization; 'val' and 'test' share one module that only converts
        to float and normalizes; 'denormalize' undoes the normalization (and
        the channel stacking when mode is '2d').
        Example: auged_images = xform['train'](images)
    """
    # input is a tensor of shape N x C x F x H x W
    kornia_transforms = []

    if augs.LR > 0:
        kornia_transforms.append(K.RandomHorizontalFlip(p=augs.LR,
                                                        same_on_batch=False,
                                                        return_transform=False))
    if augs.UD > 0:
        kornia_transforms.append(K.RandomVerticalFlip(p=augs.UD,
                                                      same_on_batch=False,
                                                      return_transform=False))
    if augs.degrees > 0:
        kornia_transforms.append(K.RandomRotation(augs.degrees))

    # A single ColorJitter handles all four colour parameters; it is added as
    # soon as any one of them is enabled.
    if augs.brightness > 0 or augs.contrast > 0 or augs.saturation > 0 or augs.hue > 0:
        kornia_transforms.append(K.ColorJitter(brightness=augs.brightness,
                                               contrast=augs.contrast,
                                               saturation=augs.saturation,
                                               hue=augs.hue,
                                               p=augs.color_p,
                                               same_on_batch=False,
                                               return_transform=False))
    if augs.grayscale > 0:
        kornia_transforms.append(K.RandomGrayscale(p=augs.grayscale))

    # Normalization is deliberately kept outside the VideoSequential: it is
    # appended as its own stage so the train and val pipelines share it.
    norm = NormalizeVideo(mean=augs.normalization.mean,
                          std=augs.normalization.std)

    kornia_transforms = VideoSequential(*kornia_transforms,
                                        data_format='BCTHW',
                                        same_on_frame=True)

    train_transforms = [ToFloat(),
                        kornia_transforms,
                        norm]
    val_transforms = [ToFloat(),
                      norm]

    # The inverse pipeline runs in reverse order: unstack first (2d only),
    # then undo the normalization.
    denormalize = []
    if mode == '2d':
        train_transforms.append(StackClipInChannels())
        val_transforms.append(StackClipInChannels())
        denormalize.append(UnstackClip())
    denormalize.append(DenormalizeVideo(mean=augs.normalization.mean,
                                        std=augs.normalization.std))

    train_transforms = nn.Sequential(*train_transforms)
    val_transforms = nn.Sequential(*val_transforms)
    denormalize = nn.Sequential(*denormalize)

    gpu_transforms = dict(train=train_transforms,
                          val=val_transforms,
                          test=val_transforms,
                          denormalize=denormalize)
    log.info('GPU transforms: {}'.format(gpu_transforms))
    return gpu_transforms
Exemplo n.º 20
0
    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=None,
        fq_dict_size=256,
        attn_layers=None,
    ):
        """Build the style/generator/discriminator networks, their optimizers
        and the augmentation pipeline used for contrastive regularization.

        Args:
            image_size: Forwarded to ``Generator`` and ``Discriminator``; also
                sizes the padding and crops of the default augmentation.
            latent_dim: Forwarded to ``StyleVectorizer`` and ``Generator``.
            style_depth: Forwarded to ``StyleVectorizer``.
            network_capacity: Forwarded to ``Generator`` and ``Discriminator``.
            transparent: Forwarded to ``Generator`` and ``Discriminator``.
            fp16: When truthy, the networks and optimizers are passed through
                ``amp.initialize`` with ``opt_level="O2"``; also forwarded to
                ``ContrastiveLearner``.
            cl_reg: When truthy, ``self.D_cl`` is a ``ContrastiveLearner``
                wrapping the discriminator; otherwise it is ``None``.
            augment_fn: Augmentation module to use. When ``None`` a default
                ``nn.Sequential`` of random augmentations is built.
            steps: Stored as ``self.steps``.
            lr: Learning rate given to both optimizers.
            fq_layers: Forwarded to ``Discriminator``. ``None`` (the default)
                means an empty list.
            fq_dict_size: Forwarded to ``Discriminator``.
            attn_layers: Forwarded to both generators and the discriminator.
                ``None`` (the default) means an empty list.
        """
        super().__init__()

        # Use a fresh list per instance instead of a shared mutable default,
        # so nothing downstream can leak state between instances.
        fq_layers = [] if fq_layers is None else fq_layers
        attn_layers = [] if attn_layers is None else attn_layers

        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent,
                           attn_layers=attn_layers)
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        # Second style/generator pair with gradients disabled.
        # NOTE(review): presumably the moving-average copies maintained via
        # self.ema_updater -- confirm against reset_parameter_averaging().
        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent,
                            attn_layers=attn_layers)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        # The generator optimizer covers both the synthesis network and the
        # style mapping network.
        generator_params = list(self.G.parameters()) + list(
            self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        # amp.initialize returns wrapped models/optimizers, so every
        # attribute is rebound to the returned object.
        if fp16:
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 [self.S, self.G, self.D, self.SE, self.GE],
                 [self.G_opt, self.D_opt],
                 opt_level="O2")

        # experimental contrastive loss discriminator regularization
        if augment_fn is not None:
            self.augment_fn = augment_fn
        else:
            self.augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(augs.RandomAffine(degrees=0,
                                              translate=(0.25, 0.25),
                                              shear=(15, 15)),
                            p=0.3),
                RandomApply(nn.Sequential(
                    augs.RandomRotation(180),
                    augs.CenterCrop(size=(image_size, image_size))),
                            p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )

        # The augmentation pipeline is always stored; the contrastive learner
        # itself is only created when cl_reg is enabled.
        self.D_cl = (ContrastiveLearner(self.D,
                                        image_size,
                                        augment_fn=self.augment_fn,
                                        fp16=fp16,
                                        hidden_layer="flatten")
                     if cl_reg else None)