def __init__(
    self,
    net,
    image_size,
    hidden_layer=-2,
    projection_size=256,
    projection_hidden_size=4096,
    augment_fn=None,
    moving_average_decay=0.99,
):
    """Assemble the augmentation, online encoder, EMA updater and predictor.

    Args:
        net: Network handed to ``NetWrapper`` to form the online encoder.
        image_size: Side length used for the default random-resized crop
            and for the mock batch sent through ``forward``.
        hidden_layer: Passed to ``NetWrapper`` as ``layer``.
        projection_size: Given to ``NetWrapper`` and used as the first two
            size arguments of the predictor.
        projection_hidden_size: Given to ``NetWrapper`` and the predictor.
        augment_fn: Optional replacement for the default augmentation; the
            default pipeline is used only when this is ``None``.
        moving_average_decay: Value passed to ``EMA``.
    """
    super().__init__()

    # SimCLR-style default pipeline; always built, used when augment_fn is None.
    simclr_stages = [
        RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        augs.RandomGrayscale(p=0.2),
        augs.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        augs.RandomResizedCrop((image_size, image_size)),
        color.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        ),
    ]
    fallback_aug = nn.Sequential(*simclr_stages)
    self.augment = default(augment_fn, fallback_aug)

    self.online_encoder = NetWrapper(
        net,
        projection_size,
        projection_hidden_size,
        layer=hidden_layer,
    )
    # The target encoder is not created here; only its EMA updater is.
    self.target_encoder = None
    self.target_ema_updater = EMA(moving_average_decay)

    self.online_predictor = MultiLayerPerceptron(
        projection_size,
        projection_size,
        projection_hidden_size,
    )

    # One mock forward pass so that singleton parameters get instantiated.
    mock_batch = torch.randn(2, 3, image_size, image_size)
    self.forward(mock_batch)
Esempio n. 2
0
    def __init__(self, net, image_size=32,
                 layer_name_list=None,
                 projection_size=256,
                 projection_hidden_size=4096,
                 augment_fn=None,
                 moving_average_decay=0.99,
                 device_='cuda',
                 number_of_classes=10,
                 mean_data=torch.tensor([0.485, 0.456, 0.406]),
                 std_data=torch.tensor([0.229, 0.224, 0.225])):
        """Build the augmentation, online encoder and three predictors on ``device_``.

        Args:
            net: Network handed to ``NetWrapper``.
            image_size: Side length for the default crop and the mock batch.
            layer_name_list: Passed to ``NetWrapper`` as ``layer_name_list``.
                ``None`` (the default) means ``[-2]``; a fresh list is made
                per call so instances never share one default list object.
            projection_size: Projection width for ``NetWrapper`` and the
                predictors.
            projection_hidden_size: Hidden width for ``NetWrapper`` and the
                main predictor (the two auxiliary predictors use 512).
            augment_fn: Optional replacement for the default augmentation.
            moving_average_decay: Value passed to ``EMA``.
            device_: Device the encoder, predictors and mock batch are
                moved to.
            number_of_classes: Accepted for interface compatibility; not
                read by this constructor.
            mean_data: Mean passed to the default ``Normalize`` stage.
            std_data: Std passed to the default ``Normalize`` stage.
        """
        super().__init__()

        # Avoid the shared-mutable-default pitfall: build the default here.
        if layer_name_list is None:
            layer_name_list = [-2]

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
            augs.Normalize(mean=mean_data, std=std_data)
        )

        self.augment = default(augment_fn, DEFAULT_AUG)
        self.device = device_

        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer_name_list=layer_name_list).to(self.device)
        # The target encoder is not created here; only its EMA updater is.
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size).to(self.device)
        self.online_predictor1 = MLP(projection_size, projection_size, 512).to(self.device)
        self.online_predictor2 = MLP(projection_size, projection_size, 512).to(self.device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size).to(self.device))
Esempio n. 3
0
    def __init__(self, opt):
        """Wrap a dataset and build the augmentation pipeline described by ``opt``.

        Reads ``opt['dataset']``, ``opt['crop_size']`` and ``opt['normalize']``
        directly, and ``key1`` / ``key2`` / ``for_sr`` through ``opt_get`` with
        the defaults ``'hq'``, ``'lq'`` and ``False``.
        """
        super().__init__()
        self.wrapped_dataset = create_dataset(opt['dataset'])
        self.cropped_img_size = opt['crop_size']
        self.key1 = opt_get(opt, ['key1'], 'hq')
        self.key2 = opt_get(opt, ['key2'], 'lq')
        # When set, color alterations and blurs are disabled.
        for_sr = opt_get(opt, ['for_sr'], False)

        # Geometric stages are always present.
        stages = [
            augs.RandomHorizontalFlip(),
            augs.RandomResizedCrop(
                (self.cropped_img_size, self.cropped_img_size)),
        ]

        # Photometric stages are skipped in for_sr mode.
        if not for_sr:
            stages += [
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
                augs.RandomGrayscale(p=0.2),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            ]

        if opt['normalize']:
            # The paper calls for normalization. Most datasets/models in this
            # repo don't use this. Recommend setting true if you want to train
            # exactly like the paper.
            channel_mean = torch.tensor([0.485, 0.456, 0.406])
            channel_std = torch.tensor([0.229, 0.224, 0.225])
            stages.append(augs.Normalize(mean=channel_mean, std=channel_std))

        self.aug = nn.Sequential(*stages)
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        project_hidden=True,
        project_dim=128,
        augment_both=True,
        use_nt_xent_loss=False,
        augment_fn=None,
        use_bilinear=False,
        use_momentum=False,
        momentum_value=0.999,
        key_encoder=None,
        temperature=0.1,
        fp16=False,
    ):
        """Store the learner configuration and run one warm-up forward pass.

        Args:
            net: Network wrapped by ``OutputHiddenLayer``.
            image_size: Side length for the default crop and warm-up tensor.
            hidden_layer: Passed to ``OutputHiddenLayer`` as ``layer``.
            project_hidden: Stored as ``self.project_hidden``.
            project_dim: Stored as ``self.project_dim``.
            augment_both: Stored as ``self.augment_both``.
            use_nt_xent_loss: Stored as ``self.use_nt_xent_loss``.
            augment_fn: Optional replacement for the default augmentation.
            use_bilinear: Stored as ``self.use_bilinear``.
            use_momentum: Stored as ``self.use_momentum``.
            momentum_value: Value passed to ``EMA``.
            key_encoder: Stored as ``self.key_encoder``.
            temperature: Stored as ``self.temperature``.
            fp16: When true, the warm-up tensor is converted with ``.half()``.
        """
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer)

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
        )

        self.augment = default(augment_fn, DEFAULT_AUG)

        self.augment_both = augment_both

        self.temperature = temperature
        self.use_nt_xent_loss = use_nt_xent_loss

        self.project_hidden = project_hidden
        self.projection = None
        self.project_dim = project_dim

        self.use_bilinear = use_bilinear
        self.bilinear_w = None

        self.use_momentum = use_momentum
        self.ema_updater = EMA(momentum_value)
        self.key_encoder = key_encoder

        # for accumulating queries and keys across calls
        self.queries = None
        self.keys = None

        self.fp16 = fp16

        # Build the warm-up tensor on the device that already holds ``net``
        # instead of assuming the default CUDA device, so CPU-only hosts and
        # nets placed on another GPU index work. A net without parameters
        # falls back to the previous behavior ("cuda").
        first_param = next(net.parameters(), None)
        if first_param is not None:
            init_device = first_param.device
        else:
            init_device = torch.device("cuda")

        # send a mock image tensor to instantiate parameters
        init = torch.randn(1, 3, image_size, image_size, device=init_device)
        if self.fp16:
            init = init.half()
        self.forward(init)
Esempio n. 5
0
 def __init__(self, opt):
     """Wrap the dataset named in ``opt`` and build the augmentation helpers.

     ``self.aug`` holds the color jitter / grayscale / blur stages and
     ``self.rrc`` is a ``RandomSharedRegionCrop`` configured from
     ``opt['latent_multiple']`` and the optional ``jitter_range`` (default 0).
     """
     super().__init__()
     self.wrapped_dataset = create_dataset(opt['dataset'])

     # Photometric stages only; cropping is handled by ``self.rrc``.
     jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8)
     grayscale = augs.RandomGrayscale(p=0.2)
     blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
     self.aug = nn.Sequential(jitter, grayscale, blur)

     latent_multiple = opt['latent_multiple']
     jitter_range = opt_get(opt, ['jitter_range'], 0)
     self.rrc = RandomSharedRegionCrop(latent_multiple, jitter_range)
Esempio n. 6
0
def default_aug(image_size: Tuple[int, int] = (360, 360)) -> nn.Module:
    """Return the default augmentation pipeline as an ``nn.Sequential``.

    Stages, in order: color jitter, vertical flip, horizontal flip, an
    occasionally applied 3x3 Gaussian blur, a random resized crop to
    ``image_size`` with scale ``(0.5, 1)``, and a final ``Normalize``.
    """
    stages = [
        aug.ColorJitter(contrast=0.1, brightness=0.1, saturation=0.1, p=0.8),
        aug.RandomVerticalFlip(),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (0.5, 0.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size, scale=(0.5, 1)),
        aug.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        ),
    ]
    pipeline = nn.Sequential(*stages)
    return pipeline
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_size = 256,
        projection_hidden_size = 2048,
        augment_fn = None,
        augment_fn2 = None,
        moving_average_decay = 0.99,
        ppm_num_layers = 1,
        ppm_gamma = 2,
        # The paper uses 0.7, but that leads to nearly all positive hits.
        # Need clarification on how the coordinates are normalized before
        # distance calculation.
        distance_thres = 0.1,
        similarity_temperature = 0.3,
        alpha = 1.
    ):
        """Set up augmentations, the online encoder, EMA updater and PPM module.

        ``augment_fn`` falls back to the default SimCLR-style pipeline and
        ``augment_fn2`` falls back to whatever ``augment1`` ended up being.
        The wrapper is moved to the device reported by
        ``get_module_device(net)`` and one mock forward pass is run.
        """
        super().__init__()

        # Default SimCLR-style augmentation, used when augment_fn is None.
        simclr_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size)),
            augs.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            ),
        ]
        fallback_aug = nn.Sequential(*simclr_stages)

        self.augment1 = default(augment_fn, fallback_aug)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer=hidden_layer,
        )

        # The target encoder is not created here; only its EMA updater is.
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        # Scalar hyper-parameters kept for later use.
        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature
        self.alpha = alpha

        self.propagate_pixels = PPM(
            chan=projection_size,
            num_layers=ppm_num_layers,
            gamma=ppm_gamma,
        )

        # Put the wrapper on the same device as the wrapped network.
        device = get_module_device(net)
        self.to(device)

        # One mock forward pass so that singleton parameters get instantiated.
        mock_batch = torch.randn(2, 3, image_size, image_size, device=device)
        self.forward(mock_batch)
Esempio n. 8
0
def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Return the default augmentation pipeline as an ``nn.Sequential``.

    The input is first resized to ``image_size``, then passed through color
    jitter, grayscale, horizontal flip, an occasionally applied 3x3 Gaussian
    blur, a random resized crop back to ``image_size`` and a ``Normalize``.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406])
    channel_std = torch.tensor([0.229, 0.224, 0.225])
    stages = [
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size),
        aug.Normalize(mean=channel_mean, std=channel_std),
    ]
    return nn.Sequential(*stages)
Esempio n. 9
0
 def __init__(self, opt):
     """Wrap the dataset named in ``opt`` and build two augmentation stacks.

     ``self.aug`` holds the photometric stages (color jitter, grayscale,
     blur); ``self.rrc`` holds the geometric ones (horizontal flip and a
     random resized crop to ``opt['crop_size']``).
     """
     super().__init__()
     self.wrapped_dataset = create_dataset(opt['dataset'])
     self.cropped_img_size = opt['crop_size']
     self.includes_labels = opt['includes_labels']

     # Photometric stages.
     self.aug = nn.Sequential(
         RandomApply(augs.ColorJitter(0.4, 0.4, 0.4, 0.2), p=0.8),
         augs.RandomGrayscale(p=0.2),
         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
     )

     # Geometric stages.
     flip = augs.RandomHorizontalFlip()
     crop = augs.RandomResizedCrop(
         (self.cropped_img_size, self.cropped_img_size))
     self.rrc = nn.Sequential(flip, crop)
Esempio n. 10
0
 def __init__(
     self,
     model,
     imageSize,
     embeddingLayer=-2,
     projectionDim=256,
     projectionHiddenDim=4096,
     emaDecay=0.99,
 ):
     """Create the augmentation, online/target encoders, predictor and EMA.

     The target encoder starts as a deep copy of the freshly built online
     encoder.
     """
     super(BYOL, self).__init__()

     # Default SimCLR augmentations.
     cropSize = (imageSize, imageSize)
     augmentStages = [
         RandomApply(augmentation.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
         augmentation.RandomGrayscale(p=0.2),
         augmentation.RandomHorizontalFlip(),
         RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
         augmentation.RandomResizedCrop(cropSize),
         color.Normalize(
             mean=torch.tensor([0.485, 0.456, 0.406]),
             std=torch.tensor([0.229, 0.224, 0.225]),
         ),
     ]
     self.augment = nn.Sequential(*augmentStages)

     # Models, predictor and EMA helper.
     self.onlineEncoder = ModelWrapper(
         model, projectionDim, projectionHiddenDim, embeddingLayer)
     self.onlinePredictor = MLP(
         projectionDim, projectionDim, projectionHiddenDim)
     self.targetEncoder = copy.deepcopy(self.onlineEncoder)
     self.targetEMA = EMA(emaDecay)
Esempio n. 11
0
    def __init__(self,
                 encoder,
                 predictor,
                 image_size,
                 hidden_layer=-2,
                 projection_size=256,
                 projection_hidden_size=4096,
                 augment_fn=None,
                 augment_fn2=None,
                 moving_average_decay=0.99,
                 use_momentum=True,
                 device=None):
        """Wire a ready-made encoder and predictor into the learner.

        Args:
            encoder: Stored as ``self.online_encoder``.
            predictor: Stored as ``self.online_predictor``.
            image_size: Side length of the mock batch sent through
                ``forward``.
            hidden_layer: Accepted for interface compatibility; not read here.
            projection_size: Accepted for interface compatibility; not read
                here.
            projection_hidden_size: Accepted for interface compatibility; not
                read here.
            augment_fn: Optional replacement for the default augmentation.
            augment_fn2: Optional second augmentation; defaults to
                ``augment1``.
            moving_average_decay: Value passed to ``EMA``.
            use_momentum: Stored as ``self.use_momentum``.
            device: Device the wrapper and mock batch are placed on. When
                ``None``, the historical default (CUDA device index 2) is
                used if that device exists; otherwise the device of the
                encoder's first parameter is used, or CPU if it has none.
        """
        super().__init__()

        # Default augmentation: color jitter, grayscale and blur only. Flip,
        # resized crop and normalization are intentionally not part of it.
        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = encoder

        self.use_momentum = use_momentum
        # The target encoder is not created here; only its EMA updater is.
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = predictor

        # Resolve the target device. The original code pinned everything to
        # CUDA index 2, which fails on hosts with fewer than three GPUs; keep
        # that default where it is valid and fall back to the encoder's own
        # device elsewhere.
        if device is not None:
            device = torch.device(device)
        elif torch.cuda.device_count() > 2:
            device = torch.device(2)
        else:
            first_param = next(encoder.parameters(), None)
            if first_param is not None:
                device = first_param.device
            else:
                device = torch.device("cpu")
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
Esempio n. 12
0
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        projection_size=256,
        projection_hidden_size=4096,
        moving_average_decay=0.99,
        use_momentum=True,
        structural_mlp=False,
    ):
        """Build the online encoder, augmentation stack, EMA updater and predictor.

        After construction the wrapper is moved to the device reported by
        ``get_module_device(net)`` and ``forward`` is called once with two
        mock image batches.
        """
        super().__init__()

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer=hidden_layer,
            use_structural_mlp=structural_mlp,
        )

        # SimCLR-style augmentation stack.
        crop_size = (image_size, image_size)
        self.aug = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop(crop_size),
        )

        self.use_momentum = use_momentum
        # The target encoder is not created here; only its EMA updater is.
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(
            projection_size, projection_size, projection_hidden_size)

        # Put the wrapper on the same device as the wrapped network.
        device = get_module_device(net)
        self.to(device)

        # Two mock batches so that singleton parameters get instantiated.
        mock_a = torch.randn(2, 3, image_size, image_size, device=device)
        mock_b = torch.randn(2, 3, image_size, image_size, device=device)
        self.forward(mock_a, mock_b)
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        project_hidden=True,
        project_dim=128,
        augment_both=True,
        use_nt_xent_loss=False,
        augment_fn=None,
        use_bilinear=False,
        use_momentum=False,
        momentum_value=0.999,
        key_encoder=None,
        temperature=0.1,
        batch_size=128,
    ):
        """Store the learner configuration and run one warm-up forward pass.

        The warm-up input is a pair: a tuple of three random image tensors
        and the tensor ``torch.tensor([1])``.
        """
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer)

        # Default augmentation: flips, solarize, posterize, sharpness,
        # equalize, a small rotation, an occasional blur and a rarely
        # applied resized crop. Color jitter and grayscale are left out.
        geometric_and_tone_stages = [
            augs.RandomHorizontalFlip(),
            augs.RandomVerticalFlip(),
            augs.RandomSolarize(),
            augs.RandomPosterize(),
            augs.RandomSharpness(),
            augs.RandomEqualize(),
            augs.RandomRotation(degrees=8.0),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size), p=0.1),
        ]
        fallback_aug = nn.Sequential(*geometric_and_tone_stages)

        # Batch size and image height/width kept as plain attributes.
        self.b = batch_size
        self.h = self.w = image_size
        self.augment = default(augment_fn, fallback_aug)

        self.augment_both = augment_both

        self.temperature = temperature
        self.use_nt_xent_loss = use_nt_xent_loss

        self.project_hidden = project_hidden
        self.projection = None
        self.project_dim = project_dim

        self.use_bilinear = use_bilinear
        self.bilinear_w = None

        self.use_momentum = use_momentum
        self.ema_updater = EMA(momentum_value)
        self.key_encoder = key_encoder

        # Accumulators for queries and keys across calls.
        self.queries = None
        self.keys = None

        # Mock input used to instantiate parameters.
        mock_images = tuple(
            torch.randn(1, 3, image_size, image_size) for _ in range(3)
        )
        mock_target = torch.tensor([1])
        self.forward((mock_images, mock_target))
Esempio n. 14
0
    def __init__(self,
                 net,
                 image_size,
                 hidden_layer_pixel=-2,
                 hidden_layer_instance=-2,
                 projection_size=256,
                 projection_hidden_size=2048,
                 augment_fn=None,
                 augment_fn2=None,
                 prob_rand_hflip=0.25,
                 moving_average_decay=0.99,
                 ppm_num_layers=1,
                 ppm_gamma=2,
                 distance_thres=0.7,
                 similarity_temperature=0.3,
                 alpha=1.,
                 use_pixpro=True,
                 cutout_ratio_range=(0.6, 0.8),
                 cutout_interpolate_mode='nearest',
                 coord_cutout_interpolate_mode='bilinear'):
        """Set up augmentations, encoders and the pixel/instance level heads.

        ``self.propagate_pixels`` is only created when ``use_pixpro`` is true.
        The wrapper is moved to the device reported by
        ``get_module_device(net)`` and one mock forward pass is run.
        """
        super().__init__()

        # Default photometric augmentation, used when augment_fn is None.
        photometric_stages = [
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomSolarize(p=0.5),
            augs.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225]),
            ),
        ]
        fallback_aug = nn.Sequential(*photometric_stages)

        self.augment1 = default(augment_fn, fallback_aug)
        self.augment2 = default(augment_fn2, self.augment1)
        self.prob_rand_hflip = prob_rand_hflip

        encoder_kwargs = dict(
            net=net,
            projection_size=projection_size,
            projection_hidden_size=projection_hidden_size,
            layer_pixel=hidden_layer_pixel,
            layer_instance=hidden_layer_instance,
        )
        self.online_encoder = NetWrapper(**encoder_kwargs)

        # The target encoder is not created here; only its EMA updater is.
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature
        self.alpha = alpha

        self.use_pixpro = use_pixpro

        if use_pixpro:
            self.propagate_pixels = PPM(
                chan=projection_size,
                num_layers=ppm_num_layers,
                gamma=ppm_gamma,
            )

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # Instance-level predictor.
        self.online_predictor = MLP(
            projection_size, projection_size, projection_hidden_size)

        # Put the wrapper on the same device as the wrapped network.
        device = get_module_device(net)
        self.to(device)

        # One mock forward pass so that singleton parameters get instantiated.
        mock_batch = torch.randn(2, 3, image_size, image_size, device=device)
        self.forward(mock_batch)
Esempio n. 15
0
    def __init__(
        self,
        net,
        image_size,
        hidden_layer_pixel=-2,
        hidden_layer_instance=-2,
        instance_projection_size=256,
        instance_projection_hidden_size=2048,
        pix_projection_size=256,
        pix_projection_hidden_size=2048,
        augment_fn=None,
        augment_fn2=None,
        prob_rand_hflip=0.25,
        moving_average_decay=0.99,
        ppm_num_layers=1,
        ppm_gamma=2,
        distance_thres=0.7,
        similarity_temperature=0.3,
        cutout_ratio_range=(0.6, 0.8),
        cutout_interpolate_mode='nearest',
        coord_cutout_interpolate_mode='bilinear',
        max_latent_dim=None  # When set, this is the number of stochastically extracted pixels from the latent to extract. Must have an integer square root.
    ):
        """Set up augmentations, encoders and the pixel/instance level heads.

        Args:
            net: Network handed to ``NetWrapper``; its device (as reported by
                ``get_module_device``) is where the wrapper is moved.
            image_size: Side length of the mock batch sent through
                ``forward``.
            hidden_layer_pixel: Passed to ``NetWrapper`` as ``layer_pixel``.
            hidden_layer_instance: Passed to ``NetWrapper`` as
                ``layer_instance``.
            instance_projection_size: Given to ``NetWrapper`` and used as the
                first two size arguments of the instance predictor.
            instance_projection_hidden_size: Given to ``NetWrapper`` and the
                instance predictor.
            pix_projection_size: Given to ``NetWrapper`` and used as the
                ``chan`` of ``PPM``.
            pix_projection_hidden_size: Given to ``NetWrapper``.
            augment_fn: Optional replacement for the default augmentation.
            augment_fn2: Optional second augmentation; defaults to
                ``augment1``.
            prob_rand_hflip: Stored as ``self.prob_rand_hflip``.
            moving_average_decay: Value passed to ``EMA``.
            ppm_num_layers: Passed to ``PPM`` as ``num_layers``.
            ppm_gamma: Passed to ``PPM`` as ``gamma``.
            distance_thres: Stored as ``self.distance_thres``.
            similarity_temperature: Stored as ``self.similarity_temperature``.
            cutout_ratio_range: Stored as ``self.cutout_ratio_range``.
            cutout_interpolate_mode: Stored as
                ``self.cutout_interpolate_mode``.
            coord_cutout_interpolate_mode: Stored as
                ``self.coord_cutout_interpolate_mode``.
            max_latent_dim: Optional; when given it must be a perfect square.
                ``None`` (the default) skips the check.

        Raises:
            AssertionError: If ``max_latent_dim`` is given and is not a
                perfect square.
        """
        super().__init__()

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.6, 0.6, 0.6, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomSolarize(p=0.5),
            # Normalize left out because it should be done at the model level.
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)
        self.prob_rand_hflip = prob_rand_hflip

        self.online_encoder = NetWrapper(
            net=net,
            instance_projection_size=instance_projection_size,
            instance_projection_hidden_size=instance_projection_hidden_size,
            pix_projection_size=pix_projection_size,
            pix_projection_hidden_size=pix_projection_hidden_size,
            layer_pixel=hidden_layer_pixel,
            layer_instance=hidden_layer_instance)

        # The target encoder is not created here; only its EMA updater is.
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature

        # This requirement is due to the way that these are processed, not a
        # hard requirement. Only validate when a value was actually supplied:
        # the default of None previously crashed inside math.sqrt.
        if max_latent_dim is not None:
            assert math.sqrt(max_latent_dim) == int(math.sqrt(max_latent_dim)), \
                "max_latent_dim must have an integer square root"
        self.max_latent_dim = max_latent_dim

        self.propagate_pixels = PPM(chan=pix_projection_size,
                                    num_layers=ppm_num_layers,
                                    gamma=ppm_gamma)

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # instance level predictor
        self.online_predictor = MLP(instance_projection_size,
                                    instance_projection_size,
                                    instance_projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
Esempio n. 16
0
def get_augmenter(augmenter_type: str,
                  image_size: ImageSizeType,
                  dataset_mean: DatasetStatType,
                  dataset_std: DatasetStatType,
                  padding: PaddingInputType = 1. / 8.,
                  pad_if_needed: bool = False,
                  subset_size: int = 2) -> Union[Module, Callable]:
    """Build the augmentation pipeline selected by ``augmenter_type``.

    Args:
        augmenter_type: one of "simple", "fixed", "validation", "test" or
            "randaugment" (case-insensitive, surrounding whitespace ignored)
        image_size: (height, width) image size
        dataset_mean: dataset mean value in CHW
        dataset_std: dataset standard deviation in CHW
        padding: fraction of the image size to pad on each border. A single
            number pads all four borders equally. A sequence (tuple or list)
            of length 4 pads left, top, right, bottom borders respectively.
            A sequence of length 2 pads left/right, top/bottom borders,
            respectively.
        pad_if_needed: bool flag for RandomCrop "pad_if_needed" option
        subset_size: number of augmentations used in subset

    Returns: nn.Module for Kornia augmentation or Callable for torchvision transform

    Raises:
        TypeError: if ``padding`` is neither a number nor a tuple/list.
        ValueError: if a ``padding`` sequence does not have length 2 or 4.
        NotImplementedError: if ``augmenter_type`` is not supported.

    """
    # Normalize ``padding`` to a 4-tuple of fractions (left, top, right, bottom).
    if isinstance(padding, (tuple, list)):
        padding = tuple(padding)
    elif isinstance(padding, (int, float)) and not isinstance(padding, bool):
        padding = (padding, padding, padding, padding)
    else:
        raise TypeError(
            "padding must be a number or a tuple/list of length 2 or 4, "
            f"got {type(padding).__name__}")

    if len(padding) == 2:
        # padding of length 2 is used to pad left/right, top/bottom borders, respectively
        # padding of length 4 is used to pad left, top, right, bottom borders respectively
        padding = (padding[0], padding[1], padding[0], padding[1])
    elif len(padding) != 4:
        raise ValueError(
            f"padding must have length 2 or 4, got length {len(padding)}")

    # image_size is of shape (h,w); padding values is [left, top, right, bottom] borders
    height, width = image_size[0], image_size[1]
    padding = (int(width * padding[0]), int(height * padding[1]),
               int(width * padding[2]), int(height * padding[3]))

    def _normalize():
        # Built lazily so that an unsupported augmenter type is reported
        # before the dataset statistics are touched.
        return K.Normalize(
            mean=torch.tensor(dataset_mean, dtype=torch.float32),
            std=torch.tensor(dataset_std, dtype=torch.float32))

    augmenter_type = augmenter_type.strip().lower()

    if augmenter_type == "simple":
        return nn.Sequential(
            K.RandomCrop(size=image_size,
                         padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            _normalize(),
        )

    elif augmenter_type == "fixed":
        return nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomVerticalFlip(p=0.2),
            K.RandomResizedCrop(size=image_size,
                                scale=(0.8, 1.0),
                                ratio=(1., 1.)),
            RandomAugmentation(p=0.5,
                               augmentation=F.GaussianBlur2d(
                                   kernel_size=(3, 3),
                                   sigma=(1.5, 1.5),
                                   border_type='constant')),
            K.ColorJitter(contrast=(0.75, 1.5)),
            # additive Gaussian noise
            K.RandomErasing(p=0.1),
            # Multiply
            K.RandomAffine(degrees=(-25., 25.),
                           translate=(0.2, 0.2),
                           scale=(0.8, 1.2),
                           shear=(-8., 8.)),
            _normalize(),
        )

    elif augmenter_type in ["validation", "test"]:
        return nn.Sequential(_normalize(), )

    elif augmenter_type == "randaugment":
        return nn.Sequential(
            K.RandomCrop(size=image_size,
                         padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            RandAugmentNS(n=subset_size, m=10),
            _normalize(),
        )

    else:
        raise NotImplementedError(
            f"\"{augmenter_type}\" is not a supported augmenter type")
def default(val, def_val):
    """Return ``val`` unless it is ``None``, in which case return ``def_val``."""
    if val is None:
        return def_val
    return val

# augmentation utils

class RandomApply(nn.Module):
    """Apply ``fn`` to the input unless a fresh uniform draw exceeds ``p``."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        # Draw once per call; the input passes through untouched when the
        # draw is strictly greater than ``p``.
        skip = random.random() > self.p
        return x if skip else self.fn(x)


# Module-level default SimCLR-style augmentation for 256x256 crops.
# A final color.Normalize stage (mean [0.485, 0.456, 0.406],
# std [0.229, 0.224, 0.225]) is deliberately left disabled.
image_size = 256
DEFAULT_AUG = nn.Sequential(*[
    RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
    augs.RandomGrayscale(p=0.2),
    augs.RandomHorizontalFlip(),
    RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
    augs.RandomResizedCrop((image_size, image_size)),
])


# Script entry point: only creates an AverageMeter when run directly.
if __name__ == "__main__":
    meter = AverageMeter()
Esempio n. 18
0
    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=None,
        fq_dict_size=256,
        attn_layers=None,
    ):
        """Build the networks, their averaged copies, optimizers and augmentation.

        Args:
            image_size: Passed to the generators, the discriminator and the
                contrastive learner; also sizes the default augmentation.
            latent_dim: Passed to the style vectorizers and generators.
            style_depth: Passed to the style vectorizers.
            network_capacity: Passed to the generators and discriminator.
            transparent: Passed to the generators and discriminator.
            fp16: When true, models and optimizers go through
                ``amp.initialize`` with ``opt_level="O2"``; also forwarded to
                the contrastive learner.
            cl_reg: When true, ``self.D_cl`` is a ``ContrastiveLearner``
                around ``self.D``; otherwise it is ``None``.
            augment_fn: Optional replacement for the default augmentation.
            steps: Stored as ``self.steps``.
            lr: Learning rate for both optimizers.
            fq_layers: Passed to the discriminator. ``None`` (the default)
                means an empty list, created fresh per instance.
            fq_dict_size: Passed to the discriminator.
            attn_layers: Passed to the generators and discriminator. ``None``
                (the default) means an empty list, created fresh per instance.
        """
        super().__init__()

        # Avoid the shared-mutable-default pitfall: every instance gets its
        # own list rather than one list object bound at definition time.
        fq_layers = [] if fq_layers is None else fq_layers
        attn_layers = [] if attn_layers is None else attn_layers

        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent,
                           attn_layers=attn_layers)
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        # Second vectorizer/generator pair with gradients disabled.
        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent,
                            attn_layers=attn_layers)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(
            self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 [self.S, self.G, self.D, self.SE, self.GE],
                 [self.G_opt, self.D_opt],
                 opt_level="O2")

        # experimental contrastive loss discriminator regularization
        if augment_fn is not None:
            self.augment_fn = augment_fn
        else:
            self.augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(augs.RandomAffine(degrees=0,
                                              translate=(0.25, 0.25),
                                              shear=(15, 15)),
                            p=0.3),
                RandomApply(nn.Sequential(
                    augs.RandomRotation(180),
                    augs.CenterCrop(size=(image_size, image_size))),
                            p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )

        self.D_cl = (ContrastiveLearner(self.D,
                                        image_size,
                                        augment_fn=self.augment_fn,
                                        fp16=fp16,
                                        hidden_layer="flatten")
                     if cl_reg else None)