Beispiel #1
0
def generate_kornia_transforms(image_size=224, resize=256, mean=None, std=None, include_jitter=False):
    """Build Kornia-based train and validation transform pipelines.

    Args:
        image_size: Side length of the square crop produced by both pipelines.
        resize: Side length every image is resized to before cropping.
        mean: Per-channel normalization mean as a sequence, array or tensor.
            ``None`` or an empty value selects ``[0.5, 0.5, 0.5]``.
        std: Per-channel normalization std as a sequence, array or tensor.
            ``None`` or an empty value selects ``[0.1, 0.1, 0.1]``.
        include_jitter: When true, a ``ColorJitter`` stage is inserted right
            after the resize in the training pipeline.

    Returns:
        A dict with keys ``"train"`` and ``"val"``, each an ``nn.Sequential``.
        Both pipelines and the normalization statistics are moved to CUDA when
        ``torch.cuda.is_available()`` is true.
    """

    def _as_stats(values, fallback):
        # A bare truth test raises for multi-element tensors/arrays, so the
        # emptiness check is done per input kind instead.
        if values is None:
            return torch.tensor(fallback)
        if torch.is_tensor(values):
            return values.detach().clone() if values.numel() else torch.tensor(fallback)
        if hasattr(values, "__len__"):
            return torch.tensor(values if len(values) else fallback)
        return torch.tensor(values if values else fallback)

    mean = _as_stats(mean, [0.5, 0.5, 0.5])
    std = _as_stats(std, [0.1, 0.1, 0.1])

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # The statistics are moved explicitly, in addition to the pipelines
        # below, exactly as the original code did.
        mean = mean.cuda()
        std = std.cuda()

    train_transforms = [G.Resize((resize, resize))]
    if include_jitter:
        train_transforms.append(
            K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1))
    train_transforms.extend([
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        K.RandomRotation(90),
        K.RandomResizedCrop((image_size, image_size)),
        K.Normalize(mean, std),
    ])

    val_transforms = [
        G.Resize((resize, resize)),
        K.CenterCrop((image_size, image_size)),
        K.Normalize(mean, std),
    ]

    transforms = dict(train=nn.Sequential(*train_transforms),
                      val=nn.Sequential(*val_transforms))
    if use_cuda:
        for split in transforms:
            transforms[split] = transforms[split].cuda()
    return transforms
Beispiel #2
0
    def __init__(self,
                 net,
                 layer_name_list=['avgpool'],
                 image_size=32,
                 projection_size=256,
                 projection_hidden_size=4096,
                 augment_fn=None,
                 moving_average_decay=0.99,
                 device_='cuda',
                 number_of_classes=10,
                 mean_data=torch.tensor([0.485, 0.456, 0.406]),
                 std_data=torch.tensor([0.229, 0.224, 0.225])):
        """Set up the online encoder, predictor and EMA helper around ``net``.

        Args:
            net: Backbone handed to ``NetWrapper``.
            layer_name_list: Forwarded to ``NetWrapper`` as ``layer_name_list``.
            image_size: Side length used for the fallback crop and the mock
                input batch.
            projection_size: Projection width, also the predictor in/out width.
            projection_hidden_size: Hidden width of projector and predictor.
            augment_fn: Optional augmentation; combined with the fallback
                pipeline through the ``default`` helper.
            moving_average_decay: Value passed to ``EMA``.
            device_: Device the encoder, predictor and mock batch are moved to.
            number_of_classes: Not referenced in this constructor.
            mean_data: Mean passed to the fallback ``Normalize`` stage.
            std_data: Std passed to the fallback ``Normalize`` stage.
        """
        super().__init__()

        # The fallback pipeline is always constructed, even when a custom
        # augmentation is supplied.
        fallback_stages = [
            augs.RandomHorizontalFlip(),
            augs.RandomResizedCrop((image_size, image_size)),
            augs.Normalize(mean=mean_data, std=std_data),
        ]
        self.augment = default(augment_fn, nn.Sequential(*fallback_stages))
        self.device = device_

        encoder = NetWrapper(net,
                             projection_size,
                             projection_hidden_size,
                             layer_name_list=layer_name_list)
        self.online_encoder = encoder.to(self.device)
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        predictor = MLP(projection_size, projection_size, projection_hidden_size)
        self.online_predictor = predictor.to(self.device)

        # Run one mock batch through forward() so that singleton parameters
        # get instantiated.
        mock_batch = torch.randn(2, 3, image_size, image_size)
        self.forward(mock_batch.to(self.device))
    def __init__(self, net, image_size, hidden_layer=-2, projection_size=256, projection_hidden_size=4096, augment_fn=None, moving_average_decay=0.99):
        """Set up the online encoder, predictor and EMA helper around ``net``.

        Args:
            net: Backbone handed to ``NetWrapper``.
            image_size: Side length used for the fallback crop and mock batch.
            hidden_layer: Forwarded to ``NetWrapper`` as ``layer``.
            projection_size: Projection width, also the predictor in/out width.
            projection_hidden_size: Hidden width of projector and predictor.
            augment_fn: Optional augmentation; combined with the fallback
                pipeline through the ``default`` helper.
            moving_average_decay: Value passed to ``EMA``.
        """
        super().__init__()

        # Fallback SimCLR-style augmentation, always constructed.
        normalize = color.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        )
        simclr_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
            normalize,
        ]
        self.augment = default(augment_fn, nn.Sequential(*simclr_stages))

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer=hidden_layer,
        )
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MultiLayerPerceptron(
            projection_size,
            projection_size,
            projection_hidden_size,
        )

        # Run one mock batch through forward() so that singleton parameters
        # get instantiated.
        mock_batch = torch.randn(2, 3, image_size, image_size)
        self.forward(mock_batch)
 def default_train_transforms():
     """Return the default training transforms keyed by pipeline hook name.

     The Kornia variant is chosen when ``_KORNIA_AVAILABLE`` is truthy and the
     ``FLASH_TESTING`` environment variable is not ``"1"``; otherwise the
     torchvision variant is returned.
     """
     image_size = ImageClassificationData.image_size
     use_kornia = _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1"

     if not use_kornia:
         from torchvision import transforms as T  # noqa F811

         pil_stage = nn.Sequential(
             T.RandomResizedCrop(image_size),
             T.RandomHorizontalFlip(),
         )
         return {
             "pre_tensor_transform": pil_stage,
             "to_tensor_transform": torchvision.transforms.ToTensor(),
             "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406],
                                                  [0.229, 0.224, 0.225]),
         }

     # Kornia path: every transform after ToTensor operates on tensors.
     tensor_stage = nn.Sequential(
         K.RandomResizedCrop(image_size),
         K.RandomHorizontalFlip(),
     )
     device_stage = nn.Sequential(
         K.Normalize(torch.tensor([0.485, 0.456, 0.406]),
                     torch.tensor([0.229, 0.224, 0.225])),
     )
     return {
         "to_tensor_transform": torchvision.transforms.ToTensor(),
         "post_tensor_transform": tensor_stage,
         "per_batch_transform_on_device": device_stage,
     }
    def __init__(self, opt):
        """Wrap a dataset and build the augmentation pipeline from ``opt``.

        Keys read from ``opt``: ``'dataset'``, ``'crop_size'`` and
        ``'normalize'`` (all by direct indexing), plus ``'key1'``, ``'key2'``
        and ``'for_sr'`` through ``opt_get`` with defaults ``'hq'``, ``'lq'``
        and ``False``.
        """
        super().__init__()
        self.wrapped_dataset = create_dataset(opt['dataset'])
        self.cropped_img_size = opt['crop_size']
        self.key1 = opt_get(opt, ['key1'], 'hq')
        self.key2 = opt_get(opt, ['key2'], 'lq')
        # When set, color alterations and blurs are disabled.
        for_sr = opt_get(opt, ['for_sr'], False)

        crop_hw = (self.cropped_img_size, self.cropped_img_size)
        stages = [
            augs.RandomHorizontalFlip(),
            augs.RandomResizedCrop(crop_hw),
        ]
        if not for_sr:
            stages += [
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
                augs.RandomGrayscale(p=0.2),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            ]
        if opt['normalize']:
            # The paper calls for normalization. Most datasets/models in this
            # repo don't use it; enable it to train exactly like the paper.
            stages.append(
                augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                               std=torch.tensor([0.229, 0.224, 0.225])))
        self.aug = nn.Sequential(*stages)
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        project_hidden=True,
        project_dim=128,
        augment_both=True,
        use_nt_xent_loss=False,
        augment_fn=None,
        use_bilinear=False,
        use_momentum=False,
        momentum_value=0.999,
        key_encoder=None,
        temperature=0.1,
        fp16=False,
    ):
        """Store the contrastive-learning configuration and warm up the model.

        ``net`` is wrapped in ``OutputHiddenLayer`` using ``hidden_layer``.
        ``augment_fn`` and the fallback pipeline are combined through the
        ``default`` helper. The remaining arguments are stored on the
        instance; ``momentum_value`` is passed to ``EMA``. A single mock
        image, created on the ``"cuda"`` device and cast to half precision
        when ``fp16`` is set, is finally passed through ``forward``.
        """
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer)

        # Fallback augmentation, always constructed.
        fallback_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
        ]
        self.augment = default(augment_fn, nn.Sequential(*fallback_stages))
        self.augment_both = augment_both

        # Loss configuration.
        self.temperature = temperature
        self.use_nt_xent_loss = use_nt_xent_loss

        # Projection settings; ``projection`` starts out unset.
        self.project_hidden = project_hidden
        self.projection = None
        self.project_dim = project_dim

        # Bilinear settings; ``bilinear_w`` starts out unset.
        self.use_bilinear = use_bilinear
        self.bilinear_w = None

        # Momentum / key encoder settings.
        self.use_momentum = use_momentum
        self.ema_updater = EMA(momentum_value)
        self.key_encoder = key_encoder

        # Accumulators for queries and keys across calls.
        self.queries = None
        self.keys = None

        self.fp16 = fp16

        # Send a mock image through forward() to instantiate parameters.
        mock_image = torch.randn(1, 3, image_size, image_size, device="cuda")
        if self.fp16:
            mock_image = mock_image.half()
        self.forward(mock_image)
Beispiel #7
0
def default_aug(image_size: Tuple[int, int] = (360, 360)) -> nn.Module:
    """Return the default augmentation pipeline as an ``nn.Sequential``.

    Args:
        image_size: Output size handed to ``RandomResizedCrop``.
    """
    jitter = aug.ColorJitter(contrast=0.1, brightness=0.1, saturation=0.1, p=0.8)
    blur = RandomApply(filters.GaussianBlur2d((3, 3), (0.5, 0.5)), p=0.1)
    crop = aug.RandomResizedCrop(size=image_size, scale=(0.5, 1))
    normalize = aug.Normalize(
        mean=torch.tensor([0.485, 0.456, 0.406]),
        std=torch.tensor([0.229, 0.224, 0.225]),
    )
    stages = [
        jitter,
        aug.RandomVerticalFlip(),
        aug.RandomHorizontalFlip(),
        blur,
        crop,
        normalize,
    ]
    return nn.Sequential(*stages)
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_size = 256,
        projection_hidden_size = 2048,
        augment_fn = None,
        augment_fn2 = None,
        moving_average_decay = 0.99,
        ppm_num_layers = 1,
        ppm_gamma = 2,
        distance_thres = 0.1, # the paper uses 0.7, but that leads to nearly all positive hits. need clarification on how the coordinates are normalized before distance calculation.
        similarity_temperature = 0.3,
        alpha = 1.
    ):
        """Build the online encoder, pixel propagation module and EMA helper.

        ``augment_fn`` and the fallback pipeline are combined through the
        ``default`` helper to give ``augment1``; ``augment_fn2`` and
        ``augment1`` are combined the same way to give ``augment2``. The
        wrapper is moved to the device reported by ``get_module_device(net)``
        and a mock batch on that device is passed through ``forward``.
        """
        super().__init__()

        # Fallback SimCLR-style augmentation, always constructed.
        imagenet_norm = augs.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        )
        fallback_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
            imagenet_norm,
        ]
        self.augment1 = default(augment_fn, nn.Sequential(*fallback_stages))
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer=hidden_layer,
        )

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature
        self.alpha = alpha

        self.propagate_pixels = PPM(
            chan=projection_size,
            num_layers=ppm_num_layers,
            gamma=ppm_gamma,
        )

        # Put the wrapper on the same device as the wrapped network.
        device = get_module_device(net)
        self.to(device)

        # Run one mock batch through forward() so that singleton parameters
        # get instantiated.
        mock_batch = torch.randn(2, 3, image_size, image_size, device=device)
        self.forward(mock_batch)
Beispiel #9
0
def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Return the default augmentation pipeline as an ``nn.Sequential``.

    Args:
        image_size: Size used both for the initial ``Resize`` and for the
            final ``RandomResizedCrop``.
    """
    normalize = aug.Normalize(
        mean=torch.tensor([0.485, 0.456, 0.406]),
        std=torch.tensor([0.229, 0.224, 0.225]),
    )
    stages = [
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size),
        normalize,
    ]
    return nn.Sequential(*stages)
 def __init__(self, opt):
     """Wrap a dataset and build two separate augmentation pipelines.

     Keys read from ``opt``: ``'dataset'``, ``'crop_size'`` and
     ``'includes_labels'``. ``self.aug`` holds the color/blur stages and
     ``self.rrc`` holds the flip plus random-resized-crop stages.
     """
     super().__init__()
     self.wrapped_dataset = create_dataset(opt['dataset'])
     self.cropped_img_size = opt['crop_size']
     self.includes_labels = opt['includes_labels']

     # Color and blur stages.
     self.aug = nn.Sequential(
         RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
         augs.RandomGrayscale(p=0.2),
         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
     )

     # Flip and crop stages.
     crop_hw = (self.cropped_img_size, self.cropped_img_size)
     self.rrc = nn.Sequential(
         augs.RandomHorizontalFlip(),
         augs.RandomResizedCrop(crop_hw),
     )
Beispiel #11
0
 def __init__(self, viz: bool = False):
     """Build the geometric augmentation pipeline and a denormalizer.

     Args:
         viz: Stored as ``self.viz``; not otherwise used in this constructor.
     """
     super().__init__()
     self.viz = viz
     '''self.geometric = [
         K.augmentation.RandomAffine(60., p=0.75),
     ]'''
     # Blur and color jitter stages are currently disabled:
     # K.augmentation.GaussianBlur((3, 3), (0.1, 2.0), p=1.0),
     # K.augmentation.ColorJitter(0.01, 0.01, 0.01, 0.01, p=0.25),
     stages = [
         augmentation.RandomRotation(degrees=30.),
         augmentation.RandomPerspective(distortion_scale=0.4),
         augmentation.RandomResizedCrop((224, 224)),
         augmentation.RandomHorizontalFlip(p=0.5),
         augmentation.RandomVerticalFlip(p=0.5),
     ]
     self.augmentations = nn.Sequential(*stages)

     dataset_mean = Tensor(DATASET_IMAGE_MEAN)
     dataset_std = Tensor(DATASET_IMAGE_STD)
     self.denorm = augmentation.Denormalize(dataset_mean, dataset_std)
Beispiel #12
0
def n_patches(x, n, transform, shape=(64, 64, 3), scale=[0.2, 0.8]):
    """Produce ``n`` randomly cropped, transformed patches of ``x``.

    Args:
        x: Input image. Tensors are converted with ``.numpy()`` and transposed
            with axes ``(1, 2, 0)`` before cropping.
        n: Number of patches to generate.
        transform: Callable applied to every cropped patch.
        shape: Patch shape; only ``shape[0]`` is used as the crop size. When
            ``shape[-1] == 0`` a side length is drawn uniformly from
            ``[64, 128)`` instead.
        scale: Passed to ``RandomResizedCrop`` as ``scale``.

    Returns:
        The transformed patches concatenated along dimension 0.
    """
    if shape[-1] == 0:
        # NOTE(review): np.random.uniform returns a float, so the crop size
        # below is non-integer on this path - confirm this is intended.
        side = np.random.uniform(64, 128)
        shape = (side, side, 3)

    # NOTE(review): ``size`` receives a single number here, not a tuple -
    # verify against the RandomResizedCrop implementation bound to ``K``.
    cropper = K.RandomResizedCrop(size=shape[0], scale=scale, ratio=(0.7, 1.3))
    if torch.is_tensor(x):
        x = x.numpy().transpose(1, 2, 0)

    patches = [transform(cropper(x)) for _ in range(n)]
    return torch.cat(patches, dim=0)
Beispiel #13
0
 def __init__(self, model, imageSize, embeddingLayer=-2, projectionDim=256, projectionHiddenDim=4096, emaDecay=0.99):
     """Create the augmentation pipeline, online/target encoders and EMA.

     Args:
         model: Backbone handed to ``ModelWrapper``.
         imageSize: Side length of the square random-resized crop.
         embeddingLayer: Fourth positional argument of ``ModelWrapper``.
         projectionDim: Projection width, also the predictor in/out width.
         projectionHiddenDim: Hidden width of projector and predictor.
         emaDecay: Value passed to ``EMA``.
     """
     super(BYOL, self).__init__()

     # SimCLR-style augmentation stages.
     imagenet_norm = color.Normalize(
         mean=torch.tensor([0.485, 0.456, 0.406]),
         std=torch.tensor([0.229, 0.224, 0.225]),
     )
     stages = [
         RandomApply(augmentation.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
         augmentation.RandomGrayscale(p=0.2),
         augmentation.RandomHorizontalFlip(),
         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
         augmentation.RandomResizedCrop((imageSize, imageSize)),
         imagenet_norm,
     ]
     self.augment = nn.Sequential(*stages)

     # Online branch, then a deep copy of the online encoder as the target.
     self.onlineEncoder = ModelWrapper(
         model, projectionDim, projectionHiddenDim, embeddingLayer)
     self.onlinePredictor = MLP(
         projectionDim, projectionDim, projectionHiddenDim)
     self.targetEncoder = copy.deepcopy(self.onlineEncoder)
     self.targetEMA = EMA(emaDecay)
    def __init__(self,
                 net,
                 image_size=32,
                 layer_name_list=[-2],
                 projection_size=256,
                 projection_hidden_size=4096,
                 augment_fn=None,
                 moving_average_decay=0.99,
                 device_='cuda',
                 number_of_classes=10,
                 mean_data=torch.tensor([0.485, 0.456, 0.406]),
                 std_data=torch.tensor([0.229, 0.224, 0.225])):
        """Set up the online encoder, three predictors and the EMA helper.

        Args:
            net: Backbone handed to ``NetWrapper``.
            image_size: Side length used for the fallback crop and mock batch.
            layer_name_list: Forwarded to ``NetWrapper`` as ``layer_name_list``.
            projection_size: Projection width, also each predictor's in/out
                width.
            projection_hidden_size: Hidden width of the projector and of
                ``online_predictor``; the two extra predictors use 512.
            augment_fn: Optional augmentation; combined with the fallback
                pipeline through the ``default`` helper.
            moving_average_decay: Value passed to ``EMA``.
            device_: Device every submodule and the mock batch are moved to.
            number_of_classes: Not referenced in this constructor.
            mean_data: Mean passed to the fallback ``Normalize`` stage.
            std_data: Std passed to the fallback ``Normalize`` stage.
        """
        super().__init__()

        # Fallback SimCLR-style augmentation, always constructed.
        fallback_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
            augs.Normalize(mean=mean_data, std=std_data),
        ]
        self.augment = default(augment_fn, nn.Sequential(*fallback_stages))
        self.device = device_

        encoder = NetWrapper(net,
                             projection_size,
                             projection_hidden_size,
                             layer_name_list=layer_name_list)
        self.online_encoder = encoder.to(self.device)
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        main_predictor = MLP(projection_size, projection_size,
                             projection_hidden_size)
        self.online_predictor = main_predictor.to(self.device)
        first_small = MLP(projection_size, projection_size, 512)
        self.online_predictor1 = first_small.to(self.device)
        second_small = MLP(projection_size, projection_size, 512)
        self.online_predictor2 = second_small.to(self.device)

        # Run one mock batch through forward() so that singleton parameters
        # get instantiated.
        mock_batch = torch.randn(2, 3, image_size, image_size)
        self.forward(mock_batch.to(self.device))
Beispiel #15
0
    def __init__(self,
                 N_TFMS: int,
                 MAGN: int,
                 mean: Union[tuple, list, torch.tensor],
                 std: Union[tuple, list, torch.tensor],
                 transform_list: list = None,
                 use_resize: int = None,
                 image_size: tuple = None,
                 use_mix: int = None,
                 mix_p: float = .5):
        """Store the augmentation configuration and build helper transforms.

        Args:
            N_TFMS: Stored as ``self.N_TFMS``.
            MAGN: Stored as ``self.MAGN``; also passed to ``kornia_list`` when
                no ``transform_list`` is given.
            mean: Normalization mean; non-tensors go through ``torch.Tensor``.
            std: Normalization std; non-tensors go through ``torch.Tensor``.
            transform_list: Optional list of transforms; ``kornia_list(MAGN)``
                is used when it is ``None``.
            use_resize: Optional index into the resize list (random resized
                crop, random crop, center crop). ``self.resize`` is only set
                when the value is below 3.
            image_size: ``(h, w)`` tuple; indexed when ``use_mix`` is given
                and length-checked when ``use_resize`` is given.
            use_mix: When not ``None``, ``self.mix_list`` is created.
            mix_p: Stored as ``self.mix_p``.

        Raises:
            AssertionError: If ``use_resize`` is given and ``image_size`` does
                not have exactly two entries.
        """
        super().__init__()

        self.N_TFMS = N_TFMS
        self.MAGN = MAGN
        self.use_mix = use_mix
        self.mix_p = mix_p
        self.image_size = image_size

        mean = mean if isinstance(mean, torch.Tensor) else torch.Tensor(mean)
        std = std if isinstance(std, torch.Tensor) else torch.Tensor(std)

        if self.use_mix is not None:
            height, width = self.image_size[0], self.image_size[1]
            self.mix_list = [
                K.RandomCutMix(height, width, p=1),
                K.RandomMixUp(p=1),
            ]

        self.use_resize = use_resize
        if use_resize is not None:
            assert len(image_size) == 2, \
                'Invalid `image_size`. Must be a tuple of form (h, w)'
            self.resize_list = [
                K.RandomResizedCrop(image_size),
                K.RandomCrop(image_size),
                K.CenterCrop(image_size),
            ]
            if self.use_resize < 3:
                self.resize = self.resize_list[use_resize]

        self.normalize = K.Normalize(mean, std)

        self.transform_list = (kornia_list(MAGN)
                               if transform_list is None else transform_list)
Beispiel #16
0
    def transform(x):
        """Split ``x`` into patches, jitter each patch and regroup them.

        ``shape`` and ``stride`` are taken from the enclosing scope. The
        patches come from ``kornia.contrib.extract_tensor_patches``; the
        first three dimensions of its result are unpacked as ``T, N, C`` and
        the jittered result is viewed as ``(T, N * C, ...)``.

        The leftover benchmarking code (an unused ``unfold`` result, timing
        calls and a per-call ``print``) has been removed.
        """
        spatial_jitter = K.RandomResizedCrop(size=shape[:2],
                                             scale=(0.7, 0.9),
                                             ratio=(0.7, 1.3))

        x = kornia.contrib.extract_tensor_patches(x,
                                                  window_size=shape[:2],
                                                  stride=stride[:2])

        T, N, C = x.shape[:3]
        # NOTE(review): the call below resolves the name ``transform``; if this
        # def shadows an outer ``transform`` in the same scope it calls itself
        # recursively - confirm against the enclosing function.
        x = transform(spatial_jitter(x.flatten(0, 1))).view(
            T, N * C, *x.shape[3:])

        return x
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        projection_size=256,
        projection_hidden_size=4096,
        moving_average_decay=0.99,
        use_momentum=True,
        structural_mlp=False,
    ):
        """Build the online encoder, augmentation, predictor and EMA helper.

        ``structural_mlp`` is forwarded to ``NetWrapper`` as
        ``use_structural_mlp``. The wrapper is moved to the device reported by
        ``get_module_device(net)`` and two mock batches on that device are
        passed through ``forward``.
        """
        super().__init__()

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer=hidden_layer,
            use_structural_mlp=structural_mlp,
        )

        self.aug = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
        )
        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(
            projection_size,
            projection_size,
            projection_hidden_size,
        )

        # Put the wrapper on the same device as the wrapped network.
        device = get_module_device(net)
        self.to(device)

        # Run two mock batches through forward() so that singleton parameters
        # get instantiated.
        first_mock = torch.randn(2, 3, image_size, image_size, device=device)
        second_mock = torch.randn(2, 3, image_size, image_size, device=device)
        self.forward(first_mock, second_mock)
Beispiel #18
0
def get_frame_transform(frame_transform_str, img_size, cuda=True):
    """Assemble a list of frame transforms selected by keyword.

    Args:
        frame_transform_str: Container or string tested with ``in`` for the
            keywords ``'gray'``, ``'crop'``, ``'cj'`` and ``'flip'``.
        img_size: Size for the random resized crop, or for the plain resize
            when ``'crop'`` is not requested.
        cuda: Not used by this function.

    Returns:
        The list of transform objects in application order.
    """
    stages = []

    if 'gray' in frame_transform_str:
        stages.append(K.RandomGrayscale(p=1))

    if 'crop' in frame_transform_str:
        spatial = K.RandomResizedCrop(img_size,
                                      scale=(0.8, 0.95),
                                      ratio=(0.7, 1.3))
    else:
        spatial = kornia.geometry.transform.Resize((img_size, img_size))
    stages.append(spatial)

    if 'cj' in frame_transform_str:
        jitter_strength = 0.1
        stages.append(
            K.ColorJitter(jitter_strength, jitter_strength,
                          jitter_strength, jitter_strength))

    if 'flip' in frame_transform_str:
        stages.append(K.RandomHorizontalFlip())

    return stages
Beispiel #19
0
def patch_grid(x, transform, shape=(64, 64, 3), stride=[1.0, 1.0]):
    """Cut ``x`` into a grid of patches, jitter each one and regroup them.

    Args:
        x: Input tensor handed to ``kornia.contrib.extract_tensor_patches``.
        transform: Callable applied to the jittered, flattened patches.
        shape: Patch shape; ``shape[:2]`` is the window size and the crop
            size.
        stride: ``[low, high]`` range from which one stride factor is drawn
            uniformly; the pixel stride is that factor times ``shape``.

    Returns:
        The transformed patches viewed as ``(T, N * C, ...)``, where
        ``T, N, C`` are the first three dimensions of the extracted patches.

    The leftover benchmarking code has been removed: an ``unfold`` whose
    result was discarded (with hard-coded 64/32 sizes that ignored ``shape``),
    the timing calls and a per-call ``print``.
    """
    # Draw one stride factor from the requested range, then convert to pixels.
    stride = np.random.random() * (stride[1] - stride[0]) + stride[0]
    stride = [int(shape[0] * stride), int(shape[1] * stride), shape[2]]

    spatial_jitter = K.RandomResizedCrop(size=shape[:2],
                                         scale=(0.7, 0.9),
                                         ratio=(0.7, 1.3))

    x = kornia.contrib.extract_tensor_patches(x,
                                              window_size=shape[:2],
                                              stride=stride[:2])

    T, N, C = x.shape[:3]
    x = transform(spatial_jitter(x.flatten(0, 1))).view(T, N * C, *x.shape[3:])

    return x
Beispiel #20
0
    def __init__(self, args=None, vis=None):
        """Configure the model from ``args`` and build all submodules.

        Args:
            args: Optional namespace of hyperparameters. Values are read with
                ``getattr`` and fall back to defaults when missing; when
                ``args`` is ``None`` a fixed set of defaults is used.
            vis: Optional object stored as ``self._viz`` when given.

        Side effects visible in this constructor: prints the temperature,
        moves the encoder to CUDA, opens a visdom connection on port 8095 and
        imports ``resnet3d`` and ``resnet2d``.
        """
        super(TimeCycle, self).__init__()
        
        self.args = args

        # Hyperparameters: read from args when available, else use defaults.
        # NOTE(review): long_coef and skip_coef are only assigned in the
        # args-is-None branch - confirm nothing reads them otherwise.
        if args is not None:
            self.kldv_coef = getattr(args, 'kldv_coef', 0)
            self.xent_coef = getattr(args, 'xent_coef', 0)
            self.zero_diagonal = getattr(args, 'zero_diagonal', 0)
            self.dropout_rate = getattr(args, 'dropout', 0)
            self.featdrop_rate = getattr(args, 'featdrop', 0)
            self.model_type = getattr(args, 'model_type', 'scratch')
            # 'temp' takes precedence over 'temperature'; default is 1.
            self.temperature = getattr(args, 'temp', getattr(args, 'temperature',1))
            self.shuffle = getattr(args, 'shuffle', 0)
            self.xent_weight = getattr(args, 'xent_weight', False)
        else:
            self.kldv_coef = 0
            self.xent_coef = 0
            self.long_coef = 1
            self.skip_coef = 0

            # self.sk_align = False
            # self.sk_targets = True
            
            self.zero_diagonal = 0
            self.dropout_rate = 0
            self.featdrop_rate = 0
            self.model_type = 'scratch'
            self.temperature = 1
            self.shuffle = False
            self.xent_weight = False

        print('Model temp:', self.temperature)
        # The encoder is built from args and moved to CUDA unconditionally.
        self.encoder = utils.make_encoder(args).cuda()


        # Defined elsewhere in the class; presumably derives feature
        # dimensions such as enc_hid_dim from the encoder - TODO confirm.
        self.infer_dims()

        # Projection heads; self.garg is defined elsewhere and is used here
        # as a lookup with a default value.
        self.selfsim_fc = self.make_head(depth=self.garg('head_depth', 0))
        self.selfsim_head = self.make_conv3d_head(depth=1)
        self.context_head = self.make_conv3d_head(depth=1)

        # self.selfsim_head = self.make_head([self.enc_hid_dim, 2*self.enc_hid_dim, self.enc_hid_dim])
        # self.context_head = self.make_head([self.enc_hid_dim, 2*self.enc_hid_dim, self.enc_hid_dim])

        # Function-level import; neither name is used in the active code of
        # this constructor.
        import resnet3d, resnet2d
        if self.garg('cal_coef', 0) > 0:
            self.stack_encoder = utils.make_stack_encoder(self.enc_hid_dim)
            # self.aff_encoder = resnet2d.Bottleneck(1, 128,)

        # # assuming no fc pre-training
        # for m in self.modules():
        #     if isinstance(m, nn.Linear):
        #         m.weight.data.normal_(0, 0.01)
        #         m.bias.data.zero_()

        # getattr with a default also works when args is None.
        self.edge = getattr(args, 'edgefunc', 'softmax')

        # Loss modules.
        # self.kldv = torch.nn.KLDivLoss(reduction="batchmean")
        self.kldv = torch.nn.KLDivLoss(reduction="batchmean")
        self.xent = torch.nn.CrossEntropyLoss(reduction="none")

        self.target_temp = 1

        # Empty dicts; presumably caches for loss targets - TODO confirm.
        self._xent_targets = {}
        self._kldv_targets = {}
        
        # Optional attention restriction, enabled when 'restrict' > 0.
        if self.garg('restrict', 0) > 0:
            self.restrict = utils.RestrictAttention(int(args.restrict))
        else:
            self.restrict =  None

        self.dropout = torch.nn.Dropout(p=self.dropout_rate, inplace=False)
        self.featdrop = torch.nn.Dropout(p=self.featdrop_rate, inplace=False)

        # Visdom client; the env name is '<args.name or "test">_'.
        self.viz = visdom.Visdom(port=8095, env='%s_%s' % (getattr(args, 'name', 'test'), '')) #int(time.time())))
        self.viz.close()

        # Drop the client when no connection is available.
        if not self.viz.check_connection():
            self.viz = None

        # Note: the supplied object goes to self._viz, not self.viz.
        if vis is not None:
            self._viz = vis
    
        # Patch size and stride shared by the patch crop and the unfold below.
        p_sz, stride = 64, 32
        self.k_patch =  nn.Sequential(
            K.RandomResizedCrop(size=(p_sz, p_sz), scale=(0.7, 0.9), ratio=(0.7, 1.3))
        )

        # NOTE(review): mmm and sss are only referenced in the commented-out
        # stages below.
        mmm, sss = torch.Tensor([0.485, 0.456, 0.406]), torch.Tensor([0.229, 0.224, 0.225])

        # Currently an empty Sequential: every stage is commented out.
        self.k_frame = nn.Sequential(
            # kornia.color.Normalize(mean=-mmm/sss, std=1/sss),
            # K.ColorJitter(0.1, 0.1, 0.1, 0),
            # K.RandomResizedCrop(size=(256, 256), scale=(0.8, 0.9), ratio=(0.7, 1.3)),
            # kornia.color.Normalize(mean=mmm, std=sss)
        )
        
        # self.k_frame_same = nn.Sequential(
        #     K.RandomResizedCrop(size=(256, 256), scale=(0.8, 1.0), same_on_batch=True)
        # )
        # self.k_frame_same = None
        
        # Resize to 276, then flip and crop to 256x256 with same_on_batch=True.
        self.k_frame_same = nn.Sequential(
            kornia.geometry.transform.Resize(256 + 20),
            K.RandomHorizontalFlip(same_on_batch=True),
            K.RandomCrop((256, 256), same_on_batch=True),
        )

        self.unfold = torch.nn.Unfold((p_sz,p_sz), dilation=1, padding=0, stride=(stride, stride))

        self.ent_stats = utils.RunningStats(1000)
Beispiel #21
0
def get_frame_transform(args, cuda=True):
    """Build a frame-transform callable configured by ``args``.

    Attributes read from ``args``: ``img_size``, ``frame_transforms``,
    ``npatch``, ``frame_aug`` and ``visualize``.

    Args:
        args: Namespace holding the attributes listed above.
        cuda: When true, inputs are moved to CUDA before being transformed.

    Returns:
        A function mapping a batch ``x`` to a pair of CPU tensors: the
        augmented batch, and the plainly resized and normalized first item.
    """
    imsz = args.img_size
    norm_size = kornia.geometry.transform.Resize((imsz, imsz))
    norm_imgs = kornia.color.Normalize(mean=IMG_MEAN, std=IMG_STD)

    stages = []
    requested = args.frame_transforms  #.split(',')

    if 'gray' in requested:
        stages.append(K.RandomGrayscale(p=1))

    if 'crop' in requested:
        stages.append(
            K.RandomResizedCrop(imsz, scale=(0.8, 0.95), ratio=(0.7, 1.3)))
    else:
        stages.append(norm_size)

    # 'cj2' is tested first, so it wins over the plain 'cj' setting.
    if 'cj2' in requested:
        strength = 0.2
        stages.append(K.RandomGrayscale(p=0.2))
        stages.append(K.ColorJitter(strength, strength, strength, strength))
    elif 'cj' in requested:
        strength = 0.1
        stages.append(K.ColorJitter(strength, strength, strength, 0))

    if 'flip' in requested:
        stages.append(K.RandomHorizontalFlip())

    # The last stage is either the patch augmentation or plain normalization.
    if args.npatch > 1 and args.frame_aug != '':
        stages.append(get_frame_aug(args))
    else:
        stages.append(norm_imgs)

    print('Frame transforms:', stages, args.frame_transforms)

    frame_transform_train = transforms.Compose(stages)
    plain = nn.Sequential(norm_size, norm_imgs)

    def with_orig(x):
        """Return ``(augmented batch, plain first item)`` as CPU tensors."""
        if cuda:
            x = x.cuda()
        if x.max() > 1 and x.min() >= 0:
            # Rescale using in-place ops, as in the original code.
            x = x.float()
            x -= x.min()
            x /= x.max()
        if x.shape[-1] == 3:
            # Channels-last input: move the last axis to position 1.
            x = x.permute(0, 3, 1, 2)

        # The trailing ``or True`` makes this always truthy.
        patchify = (not args.visualize) or True

        augmented = (frame_transform_train(x) if patchify else plain(x)).cpu()
        reference = plain(x[0:1]).cpu()
        return augmented, reference

    return with_orig
Beispiel #22
0
        args.size,
        args.latent_size,
        args.n_mlp,
        channel_multiplier=args.channel_multiplier,
        constant_input=args.constant_input,
    ).to(device)
    g_ema.requires_grad_(False)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    augment_fn = nn.Sequential(
        nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)),  # zoom out
        augs.RandomHorizontalFlip(),
        RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2),
        RandomApply(augs.RandomRotation(180), p=0.2),
        augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)),
        RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1),  # zoom in
        RandomApply(augs.RandomErasing(), p=0.1),
    )
    contrast_learner = (
        ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0))
        if args.contrastive > 0
        else None
    )

    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = th.optim.Adam(
        generator.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=[],
        fq_dict_size=256,
        attn_layers=[],
    ):
        """Build the networks, their averaged copies, and the optimizers.

        Creates the style vectorizer ``S``, generator ``G`` and discriminator
        ``D``, plus the copies ``SE`` and ``GE``, which have gradients
        disabled. Both optimizers are ``DiffGrad`` with ``lr`` and betas
        ``(0.5, 0.9)``. The module is moved to CUDA; with ``fp16`` the
        networks and optimizers go through ``amp.initialize`` at opt level
        ``"O2"``. With ``cl_reg`` a ``ContrastiveLearner`` is attached to the
        discriminator, using ``augment_fn`` or a built-in pipeline.
        """
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        # Trainable networks.
        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(
            image_size,
            latent_dim,
            network_capacity,
            transparent=transparent,
            attn_layers=attn_layers,
        )
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        # Averaged copies with gradients disabled.
        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(
            image_size,
            latent_dim,
            network_capacity,
            transparent=transparent,
            attn_layers=attn_layers,
        )
        for averaged in (self.SE, self.GE):
            set_requires_grad(averaged, False)

        # The generator optimizer covers G's parameters followed by S's.
        generator_params = [*self.G.parameters(), *self.S.parameters()]
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(), lr=self.lr, betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            models = [self.S, self.G, self.D, self.SE, self.GE]
            optimizers = [self.G_opt, self.D_opt]
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 models, optimizers, opt_level="O2")

        # Experimental contrastive-loss discriminator regularization.
        if augment_fn is None:
            rotate_and_crop = nn.Sequential(
                augs.RandomRotation(180),
                augs.CenterCrop(size=(image_size, image_size)),
            )
            shear_translate = augs.RandomAffine(degrees=0,
                                                translate=(0.25, 0.25),
                                                shear=(15, 15))
            augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(shear_translate, p=0.3),
                RandomApply(rotate_and_crop, p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )
        self.augment_fn = augment_fn

        if cl_reg:
            self.D_cl = ContrastiveLearner(self.D,
                                           image_size,
                                           augment_fn=self.augment_fn,
                                           fp16=fp16,
                                           hidden_layer="flatten")
        else:
            self.D_cl = None
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        project_hidden=True,
        project_dim=128,
        augment_both=True,
        use_nt_xent_loss=False,
        augment_fn=None,
        use_bilinear=False,
        use_momentum=False,
        momentum_value=0.999,
        key_encoder=None,
        temperature=0.1,
        batch_size=128,
    ):
        """Store the contrastive-learning configuration and warm up the model.

        ``net`` is wrapped in ``OutputHiddenLayer`` using ``hidden_layer``.
        ``augment_fn`` and the fallback pipeline are combined through the
        ``default`` helper. ``batch_size`` and ``image_size`` are kept as
        ``self.b``, ``self.h`` and ``self.w``. A mock sample made of three
        random images and a one-element label tensor is finally passed
        through ``forward``.
        """
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer)

        # Fallback augmentation, always constructed. ColorJitter and
        # RandomGrayscale are intentionally left out of this variant:
        # RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        # augs.RandomGrayscale(p=0.2),
        fallback_stages = [
            augs.RandomHorizontalFlip(),
            augs.RandomVerticalFlip(),
            augs.RandomSolarize(),
            augs.RandomPosterize(),
            augs.RandomSharpness(),
            augs.RandomEqualize(),
            augs.RandomRotation(degrees=8.0),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size), p=0.1),
        ]
        fallback_aug = nn.Sequential(*fallback_stages)

        self.b = batch_size
        self.h = image_size
        self.w = image_size
        self.augment = default(augment_fn, fallback_aug)
        self.augment_both = augment_both

        # Loss configuration.
        self.temperature = temperature
        self.use_nt_xent_loss = use_nt_xent_loss

        # Projection settings; ``projection`` starts out unset.
        self.project_hidden = project_hidden
        self.projection = None
        self.project_dim = project_dim

        # Bilinear settings; ``bilinear_w`` starts out unset.
        self.use_bilinear = use_bilinear
        self.bilinear_w = None

        # Momentum / key encoder settings.
        self.use_momentum = use_momentum
        self.ema_updater = EMA(momentum_value)
        self.key_encoder = key_encoder

        # Accumulators for queries and keys across calls.
        self.queries = None
        self.keys = None

        # Send a mock sample through forward() to instantiate parameters.
        mock_images = tuple(
            torch.randn(1, 3, image_size, image_size) for _ in range(3))
        mock_sample = (mock_images, torch.tensor([1]))
        self.forward(mock_sample)
Beispiel #25
0
class TestVideoSequential:
    """Tests for ``K.VideoSequential`` applied to 5-D video tensors.

    Every test builds a ``K.VideoSequential`` container and runs it in both
    supported layouts, ``"BCTHW"`` and ``"BTCHW"``. ``device`` and ``dtype``
    are pytest fixtures supplied from outside this class.
    """

    # None of these shapes is 5-D (they have 2, 3, 4 and 6 dimensions).
    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        """Expect an ``AssertionError`` when the input is not a 5-D tensor."""
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]),
                        p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]),
                          p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        """Wrap one augmentation in ``VideoSequential`` and run ``reproducibility_test``.

        ``reproducibility_test`` is a helper defined outside this class; its
        exact checks are not visible here.
        """
        # Integers drawn from [0, 255) scaled by 1/255, so values lie in [0, 1).
        # One random clip is repeated twice along the batch axis. Dims 1 and 2
        # are both 3, so the same tensor is used for either data format.
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        # Seed after creating the input and before building/running the pipeline.
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation,
                                     data_format=data_format,
                                     same_on_frame=True)
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0),
             kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        """With ``same_on_frame=True``, identical input frames must stay identical.

        The clip is a single random frame repeated four times along the time
        axis; after augmentation each output frame is compared with the next.
        The pipeline is finally passed to ``reproducibility_test`` (helper
        defined outside this class).
        """
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        if data_format == 'BCTHW':
            # Time is dim 2 in this layout: repeat the single frame 4 times there.
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            # When return_label is set the call result is a pair; only its
            # first element (the augmented tensor) is checked below.
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            # Time is dim 1 in this layout: repeat the single frame 4 times there.
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            # Same pair handling as in the BCTHW branch.
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        """Compare ``VideoSequential(same_on_frame=False)`` with ``torch.nn.Sequential``.

        The same augmentation objects are run once through ``VideoSequential``
        on the 5-D clip and once through a plain ``torch.nn.Sequential`` on the
        frames flattened into the batch axis. Both runs start from
        ``torch.manual_seed(0)`` and their outputs must be element-wise equal.
        """
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        # Batch of 2 clips, each a single random frame repeated 4 times in time.
        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)

        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        # Re-seed so the reference run starts from the same RNG state.
        torch.manual_seed(0)
        if data_format == 'BCTHW':
            # Swap dims 1 and 2 (to B, T, C, H, W) so frames can be folded into
            # the batch axis by the reshape below.
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        # Unfold the batch axis back into (batch=2, time=4).
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            # Return to the original B, C, T, H, W layout before comparing.
            output_2 = output_2.transpose(1, 2)
        # On failure, the recorded parameters of the video pipeline are shown.
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        """Scripted and eager ``VideoSequential`` should agree (currently skipped)."""
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
Beispiel #26
0
def get_augmenter(augmenter_type: str,
                  image_size: ImageSizeType,
                  dataset_mean: DatasetStatType,
                  dataset_std: DatasetStatType,
                  padding: PaddingInputType = 1. / 8.,
                  pad_if_needed: bool = False,
                  subset_size: int = 2) -> Union[Module, Callable]:
    """Build the augmentation pipeline selected by ``augmenter_type``.

    Args:
        augmenter_type: one of "simple", "fixed", "validation", "test" or "randaugment".
            Matching is case-insensitive and ignores surrounding whitespace.
        image_size: (height, width) image size
        dataset_mean: dataset mean value in CHW
        dataset_std: dataset standard deviation in CHW
        padding: fraction of the image size to pad on each border of the image. A single number
            is used for all four borders. If a sequence (tuple or list) of length 4 is provided,
            it is used to pad left, top, right, bottom borders respectively. If a sequence of
            length 2 is provided, it is used to pad left/right, top/bottom borders, respectively.
            Only the "simple" and "randaugment" pipelines use the padding, but it is validated
            for every augmenter type.
        pad_if_needed: bool flag for RandomCrop "pad_if_needed" option
        subset_size: number of augmentations used in subset (passed to ``RandAugmentNS`` as ``n``)

    Returns: an ``nn.Sequential`` holding the augmentation modules of the selected pipeline

    Raises:
        TypeError: if ``padding`` is neither a number nor a tuple/list
        ValueError: if ``padding`` is a sequence whose length is not 2 or 4
        NotImplementedError: if ``augmenter_type`` is not one of the supported names
    """
    # Explicit checks instead of `assert`: assertions are stripped under `python -O`,
    # which would turn a bad padding into an unrelated IndexError further down.
    if isinstance(padding, (int, float)):
        padding = (padding, padding, padding, padding)
    elif isinstance(padding, (tuple, list)):
        padding = tuple(padding)
    else:
        raise TypeError(
            "padding must be a number or a sequence of 2 or 4 numbers, "
            f"got {type(padding).__name__}")

    if len(padding) == 2:
        # padding of length 2 is used to pad left/right, top/bottom borders, respectively
        # padding of length 4 is used to pad left, top, right, bottom borders respectively
        padding = (padding[0], padding[1], padding[0], padding[1])
    elif len(padding) != 4:
        raise ValueError(
            f"padding must have length 2 or 4, got {len(padding)}")

    # image_size is of shape (h,w); padding values are [left, top, right, bottom] borders,
    # so horizontal borders scale with the width and vertical borders with the height.
    height, width = image_size[0], image_size[1]
    padding = (int(width * padding[0]), int(height * padding[1]),
               int(width * padding[2]), int(height * padding[3]))

    augmenter_type = augmenter_type.strip().lower()

    def normalize():
        # Normalization layer shared by every pipeline; built on demand so nothing is
        # constructed for an unsupported augmenter type.
        return K.Normalize(
            mean=torch.tensor(dataset_mean, dtype=torch.float32),
            std=torch.tensor(dataset_std, dtype=torch.float32))

    def padded_random_crop():
        # Reflect-padded random crop shared by the "simple" and "randaugment" pipelines.
        return K.RandomCrop(size=image_size,
                            padding=padding,
                            pad_if_needed=pad_if_needed,
                            padding_mode='reflect')

    if augmenter_type == "simple":
        return nn.Sequential(
            padded_random_crop(),
            K.RandomHorizontalFlip(p=0.5),
            normalize(),
        )

    if augmenter_type == "fixed":
        return nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomVerticalFlip(p=0.2),
            K.RandomResizedCrop(size=image_size,
                                scale=(0.8, 1.0),
                                ratio=(1., 1.)),
            RandomAugmentation(p=0.5,
                               augmentation=F.GaussianBlur2d(
                                   kernel_size=(3, 3),
                                   sigma=(1.5, 1.5),
                                   border_type='constant')),
            K.ColorJitter(contrast=(0.75, 1.5)),
            # additive Gaussian noise
            K.RandomErasing(p=0.1),
            # Multiply
            K.RandomAffine(degrees=(-25., 25.),
                           translate=(0.2, 0.2),
                           scale=(0.8, 1.2),
                           shear=(-8., 8.)),
            normalize(),
        )

    if augmenter_type in ("validation", "test"):
        # Evaluation pipelines only normalize.
        return nn.Sequential(normalize())

    if augmenter_type == "randaugment":
        return nn.Sequential(
            padded_random_crop(),
            K.RandomHorizontalFlip(p=0.5),
            RandAugmentNS(n=subset_size, m=10),
            normalize(),
        )

    raise NotImplementedError(
        f"\"{augmenter_type}\" is not a supported augmenter type")
def default(val, def_val):
    """Return ``val`` unless it is ``None``, in which case return ``def_val``.

    Only ``None`` triggers the fallback; other falsy values such as ``0`` or
    ``""`` are returned unchanged.
    """
    if val is not None:
        return val
    return def_val

# augmentation utils

class RandomApply(nn.Module):
    """Module wrapper that applies ``fn`` to its input only some of the time.

    On each forward call a number is drawn with Python's ``random.random()``;
    if it is greater than ``p`` the input is returned untouched, otherwise
    ``fn(x)`` is returned.
    """

    def __init__(self, fn, p):
        """Store the wrapped callable ``fn`` and the threshold ``p``."""
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        """Return ``fn(x)`` or ``x`` depending on a fresh random draw."""
        skip = random.random() > self.p
        return x if skip else self.fn(x)


# Default SimCLR augmentation pipeline, producing square crops of
# image_size x image_size pixels.
image_size = 256
DEFAULT_AUG = nn.Sequential(*[
    RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
    augs.RandomGrayscale(p=0.2),
    augs.RandomHorizontalFlip(),
    RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
    augs.RandomResizedCrop((image_size, image_size)),
])
# Normalization is currently disabled; the step that was commented out is:
# color.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))



if __name__ == '__main__':
    meter = AverageMeter()