def generate_kornia_transforms(image_size=224, resize=256, mean=None, std=None, include_jitter=False):
    """Build train/val Kornia augmentation pipelines.

    Args:
        image_size: final (square) crop size.
        resize: intermediate (square) resize applied first.
        mean, std: per-channel normalization stats; when empty/None they
            fall back to (0.5, 0.5, 0.5) / (0.1, 0.1, 0.1).
        include_jitter: when True, add ColorJitter to the train pipeline.

    Returns:
        dict with 'train' and 'val' ``nn.Sequential`` pipelines, moved to
        CUDA when available.
    """
    # FIX: previously used mutable list defaults (mean=[], std=[]); None is
    # the safe equivalent and preserves the same falsy-check behavior.
    mean = torch.tensor(mean) if mean else torch.tensor([0.5, 0.5, 0.5])
    std = torch.tensor(std) if std else torch.tensor([0.1, 0.1, 0.1])
    use_cuda = torch.cuda.is_available()  # queried once, reused below
    if use_cuda:
        mean = mean.cuda()
        std = std.cuda()
    train_transforms = [G.Resize((resize, resize))]
    if include_jitter:
        train_transforms.append(
            K.ColorJitter(brightness=0.4, contrast=0.4,
                          saturation=0.4, hue=0.1))
    train_transforms.extend([
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        K.RandomRotation(90),
        K.RandomResizedCrop((image_size, image_size)),
        K.Normalize(mean, std),
    ])
    val_transforms = [
        G.Resize((resize, resize)),
        K.CenterCrop((image_size, image_size)),
        K.Normalize(mean, std),
    ]
    transforms = dict(train=nn.Sequential(*train_transforms),
                      val=nn.Sequential(*val_transforms))
    if use_cuda:
        for k in transforms:
            transforms[k] = transforms[k].cuda()
    return transforms
def __init__(self, net, layer_name_list=['avgpool'], image_size=32,
             projection_size=256, projection_hidden_size=4096,
             augment_fn=None, moving_average_decay=0.99, device_='cuda',
             number_of_classes=10,
             mean_data=torch.tensor([0.485, 0.456, 0.406]),
             std_data=torch.tensor([0.229, 0.224, 0.225])):
    """BYOL-style setup: online encoder + lazily-created EMA target + predictor MLP.

    NOTE(review): `layer_name_list` and the tensor defaults are mutable
    defaults evaluated once at def time; they are not mutated here, but
    callers should not modify them in place.
    """
    super().__init__()
    # Default augmentation: flip + random resized crop + normalization.
    DEFAULT_AUG = nn.Sequential(
        augs.RandomHorizontalFlip(),
        augs.RandomResizedCrop((image_size, image_size)),
        augs.Normalize(mean=mean_data, std=std_data))
    self.augment = default(augment_fn, DEFAULT_AUG)  # caller override wins
    self.device = device_
    self.online_encoder = NetWrapper(net, projection_size,
                                     projection_hidden_size,
                                     layer_name_list=layer_name_list).to(
        self.device)
    # Target encoder is created lazily as an EMA copy of the online encoder.
    self.target_encoder = None
    self.target_ema_updater = EMA(moving_average_decay)
    self.online_predictor = MLP(projection_size, projection_size,
                                projection_hidden_size).to(self.device)
    # send a mock image tensor to instantiate singleton parameters
    self.forward(torch.randn(2, 3, image_size, image_size).to(self.device))
def __init__(self, net, image_size, hidden_layer=-2, projection_size=256, projection_hidden_size=4096, augment_fn=None, moving_average_decay=0.99): super().__init__() # default SimCLR augmentation DEFAULT_AUG = nn.Sequential( RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), augs.RandomGrayscale(p=0.2), augs.RandomHorizontalFlip(), RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), augs.RandomResizedCrop((image_size, image_size)), color.Normalize(mean=torch.tensor( [0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])) ) self.augment = default(augment_fn, DEFAULT_AUG) self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer) self.target_encoder = None self.target_ema_updater = EMA(moving_average_decay) self.online_predictor = MultiLayerPerceptron(projection_size, projection_size, projection_hidden_size) # send a mock image tensor to instantiate singleton parameters self.forward(torch.randn(2, 3, image_size, image_size))
def default_train_transforms():
    """Default training transforms.

    Prefers Kornia (tensor/GPU-friendly) pipelines when available and not
    running under FLASH_TESTING=1; otherwise falls back to torchvision.
    """
    image_size = ImageClassificationData.image_size
    use_kornia = _KORNIA_AVAILABLE and not os.getenv("FLASH_TESTING", "0") == "1"
    if use_kornia:
        # Better approach as all transforms are applied on tensor directly
        batch_device_transform = nn.Sequential(
            K.Normalize(torch.tensor([0.485, 0.456, 0.406]),
                        torch.tensor([0.229, 0.224, 0.225])),
        )
        return {
            "to_tensor_transform": torchvision.transforms.ToTensor(),
            "post_tensor_transform": nn.Sequential(
                K.RandomResizedCrop(image_size),
                K.RandomHorizontalFlip()),
            "per_batch_transform_on_device": batch_device_transform,
        }
    from torchvision import transforms as T  # noqa F811
    return {
        "pre_tensor_transform": nn.Sequential(
            T.RandomResizedCrop(image_size),
            T.RandomHorizontalFlip()),
        "to_tensor_transform": torchvision.transforms.ToTensor(),
        "post_tensor_transform": T.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225]),
    }
def __init__(self, opt):
    """Dataset wrapper that attaches BYOL-style augmentations.

    Reads dataset, crop size, item keys and flags from ``opt``.
    """
    super().__init__()
    self.wrapped_dataset = create_dataset(opt['dataset'])
    self.cropped_img_size = opt['crop_size']
    self.key1 = opt_get(opt, ['key1'], 'hq')
    self.key2 = opt_get(opt, ['key2'], 'lq')
    # When set, color alterations and blurs are disabled.
    for_sr = opt_get(opt, ['for_sr'], False)
    pipeline = [
        augs.RandomHorizontalFlip(),
        augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size)),
    ]
    if not for_sr:
        pipeline += [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        ]
    if opt['normalize']:
        # The paper calls for normalization. Most datasets/models in this repo don't use this.
        # Recommend setting true if you want to train exactly like the paper.
        pipeline.append(
            augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                           std=torch.tensor([0.229, 0.224, 0.225])))
    self.aug = nn.Sequential(*pipeline)
def __init__(
    self,
    net,
    image_size,
    hidden_layer=-2,
    project_hidden=True,
    project_dim=128,
    augment_both=True,
    use_nt_xent_loss=False,
    augment_fn=None,
    use_bilinear=False,
    use_momentum=False,
    momentum_value=0.999,
    key_encoder=None,
    temperature=0.1,
    fp16=False,
):
    """Contrastive learner over `net`'s hidden-layer output; lazy projection
    heads are instantiated via a mock forward pass at the end."""
    super().__init__()
    self.net = OutputHiddenLayer(net, layer=hidden_layer)
    # Default SimCLR-style augmentations (no normalization step here).
    DEFAULT_AUG = nn.Sequential(
        RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        augs.RandomGrayscale(p=0.2),
        augs.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        augs.RandomResizedCrop((image_size, image_size)),
    )
    self.augment = default(augment_fn, DEFAULT_AUG)  # caller override wins
    self.augment_both = augment_both
    self.temperature = temperature
    self.use_nt_xent_loss = use_nt_xent_loss
    self.project_hidden = project_hidden
    self.projection = None       # lazily instantiated on first forward
    self.project_dim = project_dim
    self.use_bilinear = use_bilinear
    self.bilinear_w = None       # lazily instantiated on first forward
    self.use_momentum = use_momentum
    self.ema_updater = EMA(momentum_value)
    self.key_encoder = key_encoder
    # for accumulating queries and keys across calls
    self.queries = None
    self.keys = None
    self.fp16 = fp16
    # send a mock image tensor to instantiate parameters
    # NOTE(review): hard-codes device="cuda" — constructing this module
    # requires a GPU; confirm this is intended.
    init = torch.randn(1, 3, image_size, image_size, device="cuda")
    if self.fp16:
        init = init.half()
    self.forward(init)
def default_aug(image_size: Tuple[int, int] = (360, 360)) -> nn.Module:
    """Default augmentation pipeline: color jitter, flips, occasional blur,
    random resized crop and ImageNet normalization."""
    steps = [
        aug.ColorJitter(contrast=0.1, brightness=0.1, saturation=0.1, p=0.8),
        aug.RandomVerticalFlip(),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (0.5, 0.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size, scale=(0.5, 1)),
        aug.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        ),
    ]
    return nn.Sequential(*steps)
def __init__(
    self,
    net,
    image_size,
    hidden_layer = -2,
    projection_size = 256,
    projection_hidden_size = 2048,
    augment_fn = None,
    augment_fn2 = None,
    moving_average_decay = 0.99,
    ppm_num_layers = 1,
    ppm_gamma = 2,
    distance_thres = 0.1, # the paper uses 0.7, but that leads to nearly all positive hits. need clarification on how the coordinates are normalized before distance calculation.
    similarity_temperature = 0.3,
    alpha = 1.
):
    """Pixel-level contrastive setup: two augmentation streams, online
    encoder with lazily-created EMA target, and a pixel propagation module."""
    super().__init__()
    # default SimCLR augmentation
    DEFAULT_AUG = nn.Sequential(
        RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        augs.RandomGrayscale(p=0.2),
        augs.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        augs.RandomResizedCrop((image_size, image_size)),
        augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                       std=torch.tensor([0.229, 0.224, 0.225]))
    )
    self.augment1 = default(augment_fn, DEFAULT_AUG)
    # Second stream falls back to the first when not provided.
    self.augment2 = default(augment_fn2, self.augment1)
    self.online_encoder = NetWrapper(net, projection_size,
                                     projection_hidden_size,
                                     layer=hidden_layer)
    # Target encoder is created lazily as an EMA copy of the online one.
    self.target_encoder = None
    self.target_ema_updater = EMA(moving_average_decay)
    self.distance_thres = distance_thres
    self.similarity_temperature = similarity_temperature
    self.alpha = alpha
    # Pixel propagation module operating on projection channels.
    self.propagate_pixels = PPM(
        chan = projection_size,
        num_layers = ppm_num_layers,
        gamma = ppm_gamma
    )
    # get device of network and make wrapper same device
    device = get_module_device(net)
    self.to(device)
    # send a mock image tensor to instantiate singleton parameters
    self.forward(torch.randn(2, 3, image_size, image_size, device=device))
def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """SimCLR-style default augmentation: resize, jitter, grayscale, flip,
    occasional blur, random resized crop, ImageNet normalization."""
    steps = [
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size),
        aug.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        ),
    ]
    return nn.Sequential(*steps)
def __init__(self, opt):
    """Dataset wrapper splitting augmentation into a photometric stage
    (`aug`) and a geometric flip+crop stage (`rrc`)."""
    super().__init__()
    self.wrapped_dataset = create_dataset(opt['dataset'])
    self.cropped_img_size = opt['crop_size']
    self.includes_labels = opt['includes_labels']
    # Photometric augmentations.
    self.aug = nn.Sequential(
        RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
        augs.RandomGrayscale(p=0.2),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
    )
    # Geometric: flip then random resized crop to the target size.
    self.rrc = nn.Sequential(
        augs.RandomHorizontalFlip(),
        augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size)),
    )
def __init__(self, viz: bool = False):
    """Geometric-only augmentation module; `viz` toggles visualization mode."""
    super().__init__()
    self.viz = viz
    '''self.geometric = [
        K.augmentation.RandomAffine(60., p=0.75),
    ]'''
    # Geometric pipeline; photometric ops are kept disabled below.
    self.augmentations = nn.Sequential(
        augmentation.RandomRotation(degrees=30.),
        augmentation.RandomPerspective(distortion_scale=0.4),
        augmentation.RandomResizedCrop((224, 224)),
        augmentation.RandomHorizontalFlip(p=0.5),
        augmentation.RandomVerticalFlip(p=0.5),
        # K.augmentation.GaussianBlur((3, 3), (0.1, 2.0), p=1.0),
        # K.augmentation.ColorJitter(0.01, 0.01, 0.01, 0.01, p=0.25),
    )
    # Undoes dataset normalization (e.g. for visualization output).
    self.denorm = augmentation.Denormalize(Tensor(DATASET_IMAGE_MEAN),
                                           Tensor(DATASET_IMAGE_STD))
def n_patches(x, n, transform, shape=(64, 64, 3), scale=(0.2, 0.8)):
    """Sample `n` randomly-resized-cropped, transformed patches from `x`.

    Args:
        x: image tensor (CHW, converted to HWC numpy) or numpy array.
        n: number of patches to sample.
        transform: callable applied to each cropped patch.
        shape: (h, w, c) patch shape; a trailing 0 channel value triggers a
            random square size in [64, 128).
        scale: RandomResizedCrop area scale range.

    Returns:
        Concatenation (dim=0) of the transformed patches.
    """
    if shape[-1] == 0:
        # Sentinel: pick a random square patch size in [64, 128).
        side = int(np.random.uniform(64, 128))  # FIX: crop size must be int
        shape = (side, side, 3)
    # FIX: size must be an (h, w) tuple; `(shape[0])` was a bare scalar.
    crop = K.RandomResizedCrop(size=(shape[0], shape[1]), scale=scale,
                               ratio=(0.7, 1.3))
    if torch.is_tensor(x):
        x = x.numpy().transpose(1, 2, 0)
    patches = []
    for _ in range(n):
        patches.append(transform(crop(x)))
    return torch.cat(patches, dim=0)
def __init__(self, model, imageSize, embeddingLayer=-2, projectionDim=256,
             projectionHiddenDim=4096, emaDecay=0.99):
    """BYOL: online encoder/predictor plus a deep-copied, EMA-tracked
    target encoder."""
    super(BYOL, self).__init__()
    # Default SimCLR augmentations
    simclr_steps = [
        RandomApply(augmentation.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        augmentation.RandomGrayscale(p=0.2),
        augmentation.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        augmentation.RandomResizedCrop((imageSize, imageSize)),
        color.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                        std=torch.tensor([0.229, 0.224, 0.225])),
    ]
    self.augment = nn.Sequential(*simclr_steps)
    # Initialize models, predictors and EMA
    self.onlineEncoder = ModelWrapper(model, projectionDim,
                                      projectionHiddenDim, embeddingLayer)
    self.onlinePredictor = MLP(projectionDim, projectionDim,
                               projectionHiddenDim)
    self.targetEncoder = copy.deepcopy(self.onlineEncoder)
    self.targetEMA = EMA(emaDecay)
def __init__(self, net, image_size=32, layer_name_list=[-2], projection_size=256, projection_hidden_size=4096, augment_fn=None, moving_average_decay=0.99, device_='cuda', number_of_classes=10, mean_data=torch.tensor([0.485, 0.456, 0.406]), std_data=torch.tensor([0.229, 0.224, 0.225])): super().__init__() # default SimCLR augmentation DEFAULT_AUG = nn.Sequential( RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), augs.RandomGrayscale(p=0.2), augs.RandomHorizontalFlip(), RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), augs.RandomResizedCrop((image_size, image_size)), augs.Normalize(mean=mean_data, std=std_data)) self.augment = default(augment_fn, DEFAULT_AUG) self.device = device_ self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer_name_list=layer_name_list).to( self.device) self.target_encoder = None self.target_ema_updater = EMA(moving_average_decay) self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size).to(self.device) self.online_predictor1 = MLP(projection_size, projection_size, 512).to(self.device) self.online_predictor2 = MLP(projection_size, projection_size, 512).to(self.device) # send a mock image tensor to instantiate singleton parameters self.forward(torch.randn(2, 3, image_size, image_size).to(self.device))
def __init__(self,
             N_TFMS: int,
             MAGN: int,
             mean: Union[tuple, list, torch.tensor],
             std: Union[tuple, list, torch.tensor],
             transform_list: list = None,
             use_resize: int = None,
             image_size: tuple = None,
             use_mix: int = None,
             mix_p: float = .5):
    """RandAugment-style pipeline configuration.

    Args:
        N_TFMS: number of transforms sampled per call.
        MAGN: global magnitude passed to the default transform list.
        mean, std: normalization stats; converted to tensors if needed.
        transform_list: explicit transform pool; defaults to kornia_list(MAGN).
        use_resize: index into [RandomResizedCrop, RandomCrop, CenterCrop];
            values >= 3 leave `self.resize` unset (random choice elsewhere —
            TODO confirm against the forward pass).
        image_size: (h, w); required when use_resize or use_mix is set.
        use_mix: when not None, enables the CutMix/MixUp list.
        mix_p: probability associated with the mix transforms.
    """
    super().__init__()
    self.N_TFMS, self.MAGN = N_TFMS, MAGN
    self.use_mix, self.mix_p = use_mix, mix_p
    self.image_size = image_size
    # Normalize stats may arrive as tuple/list; coerce to tensors.
    if not isinstance(mean, torch.Tensor): mean = torch.Tensor(mean)
    if not isinstance(std, torch.Tensor): std = torch.Tensor(std)
    if self.use_mix is not None:
        # Batch-level mixing transforms (always applied with p=1 here).
        self.mix_list = [
            K.RandomCutMix(self.image_size[0], self.image_size[1], p=1),
            K.RandomMixUp(p=1)
        ]
    self.use_resize = use_resize
    if use_resize is not None:
        assert len(
            image_size
        ) == 2, 'Invalid `image_size`. Must be a tuple of form (h, w)'
        self.resize_list = [
            K.RandomResizedCrop(image_size),
            K.RandomCrop(image_size),
            K.CenterCrop(image_size)
        ]
        if self.use_resize < 3:
            self.resize = self.resize_list[use_resize]
    self.normalize = K.Normalize(mean, std)
    self.transform_list = transform_list
    if transform_list is None: self.transform_list = kornia_list(MAGN)
def transform(x):
    # NOTE(review): `shape` and `stride` are free variables captured from an
    # enclosing scope not visible here. Also, the `transform(...)` call below
    # resolves to *this* function as written (infinite recursion); it likely
    # shadowed an outer `transform` in the original context — confirm.
    spatial_jitter = K.RandomResizedCrop(size=shape[:2],
                                         scale=(0.7, 0.9),
                                         ratio=(0.7, 1.3))
    import time
    t0 = time.time()
    # Result is unused — presumably a leftover timing comparison vs.
    # extract_tensor_patches below.
    x1 = x.unfold(2, 64, 32).unfold(3, 64, 32)
    t1 = time.time()
    x = kornia.contrib.extract_tensor_patches(x,
                                              window_size=shape[:2],
                                              stride=stride[:2])
    t2 = time.time()
    # Debug timing output for the two patch-extraction approaches.
    print(t2 - t1, t1 - t0)
    T, N, C = x.shape[:3]
    # Jitter each patch, apply transform, regroup patches into the batch dim.
    x = transform(spatial_jitter(x.flatten(0, 1))).view(
        T, N * C, *x.shape[3:])
    return x
def __init__(
    self,
    net,
    image_size,
    hidden_layer=-2,
    projection_size=256,
    projection_hidden_size=4096,
    moving_average_decay=0.99,
    use_momentum=True,
    structural_mlp=False,
):
    """BYOL learner; note the mock forward at the end passes *two* image
    batches, so forward() takes a pair of inputs."""
    super().__init__()
    self.online_encoder = NetWrapper(net, projection_size,
                                     projection_hidden_size,
                                     layer=hidden_layer,
                                     use_structural_mlp=structural_mlp)
    # SimCLR-style augmentations (no normalization step).
    augmentations = [
        RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        augs.RandomGrayscale(p=0.2),
        augs.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        augs.RandomResizedCrop((image_size, image_size))]
    self.aug = nn.Sequential(*augmentations)
    self.use_momentum = use_momentum
    # Target encoder is created lazily as an EMA copy of the online one.
    self.target_encoder = None
    self.target_ema_updater = EMA(moving_average_decay)
    self.online_predictor = MLP(projection_size, projection_size,
                                projection_hidden_size)
    # get device of network and make wrapper same device
    device = get_module_device(net)
    self.to(device)
    # send a mock image tensor to instantiate singleton parameters
    self.forward(torch.randn(2, 3, image_size, image_size, device=device),
                 torch.randn(2, 3, image_size, image_size, device=device))
def get_frame_transform(frame_transform_str, img_size, cuda=True):
    """Build a list of frame transforms selected by substrings in
    `frame_transform_str` ('gray', 'crop', 'cj', 'flip').

    Without 'crop', a plain resize to (img_size, img_size) is used instead.
    """
    selected = []
    if 'gray' in frame_transform_str:
        selected.append(K.RandomGrayscale(p=1))
    if 'crop' in frame_transform_str:
        selected.append(
            K.RandomResizedCrop(img_size, scale=(0.8, 0.95),
                                ratio=(0.7, 1.3)))
    else:
        selected.append(
            kornia.geometry.transform.Resize((img_size, img_size)))
    if 'cj' in frame_transform_str:
        _cj = 0.1
        # K.RandomGrayscale(p=0.2) intentionally left out here
        selected.append(K.ColorJitter(_cj, _cj, _cj, _cj))
    if 'flip' in frame_transform_str:
        selected.append(K.RandomHorizontalFlip())
    return selected
def patch_grid(x, transform, shape=(64, 64, 3), stride=[1.0, 1.0]):
    """Extract a grid of jittered, transformed patches from a video tensor.

    Args:
        x: input tensor of frames.
        transform: callable applied to the jittered patches.
        shape: (h, w, c) patch window shape.
        stride: [lo, hi] fractional stride range; an effective stride is
            sampled uniformly from it and scaled by the patch size.

    Returns:
        Tensor of shape (T, N*C, ...) with patches folded into the batch dim.
    """
    # Sample an effective stride as a random fraction of the patch size.
    frac = np.random.random() * (stride[1] - stride[0]) + stride[0]
    stride = [int(shape[0] * frac), int(shape[1] * frac), shape[2]]
    spatial_jitter = K.RandomResizedCrop(size=shape[:2], scale=(0.7, 0.9),
                                         ratio=(0.7, 1.3))
    # FIX: removed leftover debug code — an expensive, unused
    # x.unfold(...).unfold(...) computation and timing prints around the
    # extraction call (plus a commented-out pdb breakpoint).
    x = kornia.contrib.extract_tensor_patches(
        x, window_size=shape[:2], stride=stride[:2])
    T, N, C = x.shape[:3]
    # Jitter each patch, apply transform, regroup patches into the batch dim.
    x = transform(spatial_jitter(x.flatten(0, 1))).view(T, N * C,
                                                        *x.shape[3:])
    return x
def __init__(self, args=None, vis=None):
    """TimeCycle model: encoder, similarity heads, losses, patch/frame
    augmentation pipelines, and best-effort visdom visualization."""
    super(TimeCycle, self).__init__()
    self.args = args

    if args is not None:
        # Pull hyperparameters from args, with defaults.
        self.kldv_coef = getattr(args, 'kldv_coef', 0)
        self.xent_coef = getattr(args, 'xent_coef', 0)
        self.zero_diagonal = getattr(args, 'zero_diagonal', 0)
        self.dropout_rate = getattr(args, 'dropout', 0)
        self.featdrop_rate = getattr(args, 'featdrop', 0)
        self.model_type = getattr(args, 'model_type', 'scratch')
        # Accepts either 'temp' or 'temperature' spelling.
        self.temperature = getattr(args, 'temp',
                                   getattr(args, 'temperature', 1))
        self.shuffle = getattr(args, 'shuffle', 0)
        self.xent_weight = getattr(args, 'xent_weight', False)
    else:
        self.kldv_coef = 0
        self.xent_coef = 0
        self.long_coef = 1
        self.skip_coef = 0
        # self.sk_align = False
        # self.sk_targets = True
        self.zero_diagonal = 0
        self.dropout_rate = 0
        self.featdrop_rate = 0
        self.model_type = 'scratch'
        self.temperature = 1
        self.shuffle = False
        self.xent_weight = False

    print('Model temp:', self.temperature)
    # NOTE(review): encoder is unconditionally moved to CUDA — GPU required.
    self.encoder = utils.make_encoder(args).cuda()
    self.infer_dims()
    self.selfsim_fc = self.make_head(depth=self.garg('head_depth', 0))
    self.selfsim_head = self.make_conv3d_head(depth=1)
    self.context_head = self.make_conv3d_head(depth=1)
    # self.selfsim_head = self.make_head([self.enc_hid_dim, 2*self.enc_hid_dim, self.enc_hid_dim])
    # self.context_head = self.make_head([self.enc_hid_dim, 2*self.enc_hid_dim, self.enc_hid_dim])

    import resnet3d, resnet2d

    if self.garg('cal_coef', 0) > 0:
        self.stack_encoder = utils.make_stack_encoder(self.enc_hid_dim)
        # self.aff_encoder = resnet2d.Bottleneck(1, 128,)

    # # assuming no fc pre-training
    # for m in self.modules():
    #     if isinstance(m, nn.Linear):
    #         m.weight.data.normal_(0, 0.01)
    #         m.bias.data.zero_()

    self.edge = getattr(args, 'edgefunc', 'softmax')

    # self.kldv = torch.nn.KLDivLoss(reduction="batchmean")
    self.kldv = torch.nn.KLDivLoss(reduction="batchmean")
    self.xent = torch.nn.CrossEntropyLoss(reduction="none")

    self.target_temp = 1

    # Caches for per-size target tensors.
    self._xent_targets = {}
    self._kldv_targets = {}

    if self.garg('restrict', 0) > 0:
        self.restrict = utils.RestrictAttention(int(args.restrict))
    else:
        self.restrict = None

    self.dropout = torch.nn.Dropout(p=self.dropout_rate, inplace=False)
    self.featdrop = torch.nn.Dropout(p=self.featdrop_rate, inplace=False)

    # Visdom connection is best-effort; falls back to None when unreachable.
    self.viz = visdom.Visdom(port=8095, env='%s_%s' %
                             (getattr(args, 'name', 'test'), ''))  #int(time.time())))
    self.viz.close()

    if not self.viz.check_connection():
        self.viz = None

    if vis is not None:
        self._viz = vis

    p_sz, stride = 64, 32
    # Random jittered patch cropper (p_sz x p_sz).
    self.k_patch = nn.Sequential(
        K.RandomResizedCrop(size=(p_sz, p_sz), scale=(0.7, 0.9),
                            ratio=(0.7, 1.3))
    )

    mmm, sss = torch.Tensor([0.485, 0.456, 0.406]), torch.Tensor([0.229, 0.224, 0.225])

    # Frame-level photometric transform is currently disabled (empty).
    self.k_frame = nn.Sequential(
        # kornia.color.Normalize(mean=-mmm/sss, std=1/sss),
        # K.ColorJitter(0.1, 0.1, 0.1, 0),
        # K.RandomResizedCrop(size=(256, 256), scale=(0.8, 0.9), ratio=(0.7, 1.3)),
        # kornia.color.Normalize(mean=mmm, std=sss)
    )

    # self.k_frame_same = nn.Sequential(
    #     K.RandomResizedCrop(size=(256, 256), scale=(0.8, 1.0), same_on_batch=True)
    # )
    # self.k_frame_same = None
    # Same-on-batch resize/flip/crop so the whole batch is augmented alike.
    self.k_frame_same = nn.Sequential(
        kornia.geometry.transform.Resize(256 + 20),
        K.RandomHorizontalFlip(same_on_batch=True),
        K.RandomCrop((256, 256), same_on_batch=True),
    )

    self.unfold = torch.nn.Unfold((p_sz, p_sz), dilation=1, padding=0,
                                  stride=(stride, stride))
    self.ent_stats = utils.RunningStats(1000)
def get_frame_transform(args, cuda=True):
    """Build the composed frame transform and return a `with_orig` wrapper
    that yields (augmented batch, plainly-normalized first frame)."""
    imsz = args.img_size
    norm_size = kornia.geometry.transform.Resize((imsz, imsz))
    norm_imgs = kornia.color.Normalize(mean=IMG_MEAN, std=IMG_STD)

    tt = []
    fts = args.frame_transforms  #.split(',')

    if 'gray' in fts:
        tt.append(K.RandomGrayscale(p=1))

    if 'crop' in fts:
        tt.append(K.RandomResizedCrop(imsz, scale=(0.8, 0.95),
                                      ratio=(0.7, 1.3)))
    else:
        tt.append(norm_size)

    if 'cj2' in fts:
        # Stronger color jitter variant (with random grayscale).
        _cj = 0.2
        tt += [
            K.RandomGrayscale(p=0.2),
            K.ColorJitter(_cj, _cj, _cj, _cj),
        ]
    elif 'cj' in fts:
        _cj = 0.1
        tt += [
            # K.RandomGrayscale(p=0.2),
            K.ColorJitter(_cj, _cj, _cj, 0),
        ]

    if 'flip' in fts:
        tt += [K.RandomHorizontalFlip()]

    # Patch-level augmentation replaces plain normalization when requested.
    if args.npatch > 1 and args.frame_aug != '':
        tt += [get_frame_aug(args)]
    else:
        tt += [norm_imgs]

    print('Frame transforms:', tt, args.frame_transforms)

    # frame_transform_train = MapTransform(transforms.Compose(tt))
    frame_transform_train = transforms.Compose(tt)
    plain = nn.Sequential(norm_size, norm_imgs)

    def with_orig(x):
        # Optionally move to GPU and rescale integer-range input to [0, 1].
        if cuda:
            x = x.cuda()
        if x.max() > 1 and x.min() >= 0:
            x = x.float()
            x -= x.min()
            x /= x.max()
        # HWC -> CHW when channels arrive last.
        if x.shape[-1] == 3:
            x = x.permute(0, 3, 1, 2)
        # NOTE(review): `(not args.visualize) or True` is always True, so the
        # plain branch is dead — presumably a debugging leftover; confirm.
        patchify = (not args.visualize) or True
        x = (frame_transform_train(x) if patchify else plain(x)).cpu(), \
            plain(x[0:1]).cpu()
        return x

    return with_orig
args.size, args.latent_size, args.n_mlp, channel_multiplier=args.channel_multiplier, constant_input=args.constant_input, ).to(device) g_ema.requires_grad_(False) g_ema.eval() accumulate(g_ema, generator, 0) augment_fn = nn.Sequential( nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)), # zoom out augs.RandomHorizontalFlip(), RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2), RandomApply(augs.RandomRotation(180), p=0.2), augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)), RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1), # zoom in RandomApply(augs.RandomErasing(), p=0.1), ) contrast_learner = ( ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0)) if args.contrastive > 0 else None ) g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) g_optim = th.optim.Adam( generator.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), )
def __init__(
    self,
    image_size,
    latent_dim=512,
    style_depth=8,
    network_capacity=16,
    transparent=False,
    fp16=False,
    cl_reg=False,
    augment_fn=None,
    steps=1,
    lr=1e-4,
    fq_layers=[],
    fq_dict_size=256,
    attn_layers=[],
):
    """StyleGAN-style container: G/D plus EMA copies (SE/GE), optimizers,
    optional Apex AMP, and optional contrastive D regularization.

    NOTE(review): `fq_layers`/`attn_layers` are mutable list defaults —
    not mutated here, but callers should not modify them in place.
    """
    super().__init__()
    self.lr = lr
    self.steps = steps
    self.ema_updater = EMA(0.995)

    self.S = StyleVectorizer(latent_dim, style_depth)
    self.G = Generator(image_size, latent_dim, network_capacity,
                       transparent=transparent, attn_layers=attn_layers)
    self.D = Discriminator(
        image_size,
        network_capacity,
        fq_layers=fq_layers,
        fq_dict_size=fq_dict_size,
        attn_layers=attn_layers,
        transparent=transparent,
    )
    # EMA (shadow) copies of the style vectorizer and generator.
    self.SE = StyleVectorizer(latent_dim, style_depth)
    self.GE = Generator(image_size, latent_dim, network_capacity,
                        transparent=transparent, attn_layers=attn_layers)
    set_requires_grad(self.SE, False)
    set_requires_grad(self.GE, False)

    generator_params = list(self.G.parameters()) + list(
        self.S.parameters())
    self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
    self.D_opt = DiffGrad(self.D.parameters(), lr=self.lr, betas=(0.5, 0.9))

    self._init_weights()
    self.reset_parameter_averaging()
    # NOTE(review): unconditionally moves everything to CUDA — GPU required.
    self.cuda()

    if fp16:
        # Apex AMP O2 mixed precision over all sub-modules/optimizers.
        (self.S, self.G, self.D, self.SE,
         self.GE), (self.G_opt, self.D_opt) = amp.initialize(
             [self.S, self.G, self.D, self.SE, self.GE],
             [self.G_opt, self.D_opt],
             opt_level="O2")

    # experimental contrastive loss discriminator regularization
    if augment_fn is not None:
        self.augment_fn = augment_fn
    else:
        self.augment_fn = nn.Sequential(
            nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(augs.RandomAffine(degrees=0,
                                          translate=(0.25, 0.25),
                                          shear=(15, 15)), p=0.3),
            RandomApply(nn.Sequential(
                augs.RandomRotation(180),
                augs.CenterCrop(size=(image_size, image_size))), p=0.2),
            augs.RandomResizedCrop(size=(image_size, image_size)),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            RandomApply(augs.RandomErasing(), p=0.1),
        )

    # Contrastive learner over D's flattened features (only when cl_reg).
    self.D_cl = (ContrastiveLearner(self.D, image_size,
                                    augment_fn=self.augment_fn, fp16=fp16,
                                    hidden_layer="flatten")
                 if cl_reg else None)
def __init__(
    self,
    net,
    image_size,
    hidden_layer=-2,
    project_hidden=True,
    project_dim=128,
    augment_both=True,
    use_nt_xent_loss=False,
    augment_fn=None,
    use_bilinear=False,
    use_momentum=False,
    momentum_value=0.999,
    key_encoder=None,
    temperature=0.1,
    batch_size=128,
):
    """Contrastive learner variant using geometric/intensity augmentations;
    forward() consumes a ((img, img, img), label) batch (see mock below)."""
    super().__init__()
    self.net = OutputHiddenLayer(net, layer=hidden_layer)
    DEFAULT_AUG = nn.Sequential(
        # RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        # augs.RandomGrayscale(p=0.2),
        augs.RandomHorizontalFlip(),
        augs.RandomVerticalFlip(),
        augs.RandomSolarize(),
        augs.RandomPosterize(),
        augs.RandomSharpness(),
        augs.RandomEqualize(),
        augs.RandomRotation(degrees=8.0),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        # NOTE(review): p=0.1 makes the resized crop itself probabilistic —
        # confirm this is intended rather than a stray kwarg.
        augs.RandomResizedCrop((image_size, image_size), p=0.1),
    )
    self.b = batch_size
    self.h = image_size
    self.w = image_size
    self.augment = default(augment_fn, DEFAULT_AUG)  # caller override wins
    self.augment_both = augment_both
    self.temperature = temperature
    self.use_nt_xent_loss = use_nt_xent_loss
    self.project_hidden = project_hidden
    self.projection = None       # lazily instantiated on first forward
    self.project_dim = project_dim
    self.use_bilinear = use_bilinear
    self.bilinear_w = None       # lazily instantiated on first forward
    self.use_momentum = use_momentum
    self.ema_updater = EMA(momentum_value)
    self.key_encoder = key_encoder
    # for accumulating queries and keys across calls
    self.queries = None
    self.keys = None
    # Mock batch: three views of one image plus a dummy label.
    random_data = (
        (
            torch.randn(1, 3, image_size, image_size),
            torch.randn(1, 3, image_size, image_size),
            torch.randn(1, 3, image_size, image_size),
        ),
        torch.tensor([1]),
    )
    # send a mock image tensor to instantiate parameters
    self.forward(random_data)
class TestVideoSequential:
    """Tests for K.VideoSequential: input-shape validation, per-augmentation
    reproducibility, same-on-frame consistency, and equivalence with a plain
    nn.Sequential when same_on_frame is disabled."""

    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        # Non-5D video inputs must be rejected.
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]), p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]), p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        # Every supported augmentation must be reproducible on video input.
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation, data_format=data_format,
                                     same_on_frame=True)
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
             K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
             K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [K.RandomAffine(360, p=1.0), kornia.color.BgrToRgb()],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
             K.RandomAffine(360, p=0.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
             K.RandomAffine(360, p=1.0),
             K.RandomMixUp(p=1.0)],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        # With same_on_frame=True, all 4 repeated frames must stay identical
        # after augmentation, in both data layouts.
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [K.RandomAffine(360, p=0.0),
             K.ImageSequential(K.RandomAffine(360, p=0.0))],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        # With same_on_frame=False, VideoSequential must match applying the
        # same ops frame-by-frame via nn.Sequential under the same RNG seed.
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)

        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        torch.manual_seed(0)
        if data_format == 'BCTHW':
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            output_2 = output_2.transpose(1, 2)
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        # TorchScript round-trip must preserve outputs (currently skipped).
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
def get_augmenter(augmenter_type: str, image_size: ImageSizeType,
                  dataset_mean: DatasetStatType, dataset_std: DatasetStatType,
                  padding: PaddingInputType = 1. / 8.,
                  pad_if_needed: bool = False,
                  subset_size: int = 2) -> Union[Module, Callable]:
    """
    Args:
        augmenter_type: augmenter type
        image_size: (height, width) image size
        dataset_mean: dataset mean value in CHW
        dataset_std: dataset standard deviation in CHW
        padding: percent of image size to pad on each border of the image.
            If a sequence of length 4 is provided, it is used to pad left,
            top, right, bottom borders respectively. If a sequence of length
            2 is provided, it is used to pad left/right, top/bottom borders,
            respectively.
        pad_if_needed: bool flag for RandomCrop "pad_if_needed" option
        subset_size: number of augmentations used in subset

    Returns:
        nn.Module for Kornia augmentation or Callable for torchvision
        transform

    Raises:
        NotImplementedError: for an unrecognized augmenter_type.
    """
    # A bare float is expanded to uniform padding on all four borders.
    if not isinstance(padding, tuple):
        assert isinstance(padding, float)
        padding = (padding, padding, padding, padding)

    assert len(padding) == 2 or len(padding) == 4
    if len(padding) == 2:
        # padding of length 2 is used to pad left/right, top/bottom borders, respectively
        # padding of length 4 is used to pad left, top, right, bottom borders respectively
        padding = (padding[0], padding[1], padding[0], padding[1])

    # image_size is of shape (h,w); padding values is [left, top, right, bottom] borders
    padding = (int(image_size[1] * padding[0]),
               int(image_size[0] * padding[1]),
               int(image_size[1] * padding[2]),
               int(image_size[0] * padding[3]))

    augmenter_type = augmenter_type.strip().lower()

    if augmenter_type == "simple":
        # Crop (reflect-padded) + horizontal flip + normalize.
        return nn.Sequential(
            K.RandomCrop(size=image_size, padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )
    elif augmenter_type == "fixed":
        # Fixed hand-tuned augmentation chain.
        return nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomVerticalFlip(p=0.2),
            K.RandomResizedCrop(size=image_size, scale=(0.8, 1.0),
                                ratio=(1., 1.)),
            RandomAugmentation(
                p=0.5,
                augmentation=F.GaussianBlur2d(
                    kernel_size=(3, 3),
                    sigma=(1.5, 1.5),
                    border_type='constant')),
            K.ColorJitter(contrast=(0.75, 1.5)),
            # additive Gaussian noise
            K.RandomErasing(p=0.1),
            # Multiply
            K.RandomAffine(degrees=(-25., 25.),
                           translate=(0.2, 0.2),
                           scale=(0.8, 1.2),
                           shear=(-8., 8.)),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )
    elif augmenter_type in ["validation", "test"]:
        # Evaluation pipelines: normalization only.
        return nn.Sequential(
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )
    elif augmenter_type == "randaugment":
        # RandAugment (no-search variant) over a subset of ops.
        return nn.Sequential(
            K.RandomCrop(size=image_size, padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            RandAugmentNS(n=subset_size, m=10),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )
    else:
        raise NotImplementedError(
            f"\"{augmenter_type}\" is not a supported augmenter type")
def default(val, def_val):
    """Return `val` unless it is None, in which case return `def_val`."""
    if val is None:
        return def_val
    return val


# augmentation utils
class RandomApply(nn.Module):
    """Apply `fn` with probability `p`; otherwise pass the input through."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        # Single RNG draw decides between identity and fn, matching p.
        return self.fn(x) if random.random() <= self.p else x


# default SimCLR augmentation
image_size = 256
DEFAULT_AUG = nn.Sequential(
    RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
    augs.RandomGrayscale(p=0.2),
    augs.RandomHorizontalFlip(),
    RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
    augs.RandomResizedCrop((image_size, image_size)))
#color.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])))

if __name__ == '__main__':
    meter = AverageMeter()