Ejemplo n.º 1
0
    def test_random_flips(self, device, dtype):
        """Check that always-on flips move bounding-box corners as expected.

        A 1x3x510x1020 random image and one four-corner box are pushed
        through two ``K.AugmentationSequential`` pipelines, one wrapping
        ``RandomVerticalFlip(p=1.0)`` and one wrapping
        ``RandomHorizontalFlip(p=1.0)``.  The expected corners correspond to
        ``y -> 509 - y`` for the vertical flip and ``x -> 1019 - x`` for the
        horizontal flip.
        """
        tensor_opts = {"device": device, "dtype": dtype}

        def flip_pipeline(flip):
            # Every pipeline gets its own ``data_keys`` list.
            return K.AugmentationSequential(flip,
                                            data_keys=["input", "bbox"],
                                            return_transform=False,
                                            same_on_batch=False)

        image = torch.randn(1, 3, 510, 1020, **tensor_opts)
        corners = torch.tensor(
            [[[355, 10], [660, 10], [660, 250], [355, 250]]], **tensor_opts)

        want_vertical = torch.tensor(
            [[[355, 499], [660, 499], [660, 259], [355, 259]]], **tensor_opts)
        want_horizontal = torch.tensor(
            [[[664, 10], [359, 10], [359, 250], [664, 250]]], **tensor_opts)

        # Build both pipelines before running either one.
        vertical = flip_pipeline(K.RandomVerticalFlip(p=1.0))
        horizontal = flip_pipeline(K.RandomHorizontalFlip(p=1.0))

        # Clones keep each pipeline from seeing the other's inputs.
        got_vertical = vertical(image.clone(), corners.clone())
        got_horizontal = horizontal(image.clone(), corners.clone())

        # Index 1 is the entry matching the "bbox" data key.
        assert_close(got_vertical[1], want_vertical)
        assert_close(got_horizontal[1], want_horizontal)
Ejemplo n.º 2
0
def get_train_augmentation_transforms(
    p_flip_vertical=0.5,
    p_flip_horizontal=0.5,
    max_rotation=10.0,
    max_zoom=1.1,
    max_warp=0.2,
    p_affine=0.75,
    max_lighting=0.2,
    p_lighting=0.75,
):
    """
    Assemble the list of image transforms applied during training.

        p_flip_vertical: probability of a vertical flip
        p_flip_horizontal: probability of a horizontal flip
        max_rotation: passed as ``degrees`` to the random affine transform
        max_zoom: upper bound of the affine ``scale`` range (lower bound is 1.0)
        max_warp: ``distortion_scale`` of the random perspective transform
        p_affine: probability used for both the affine and the perspective
            transform
        max_lighting: value used for both ``brightness`` and ``contrast`` of
            the colour jitter
        p_lighting: probability of the colour jitter

    Returns a list holding, in order: vertical flip, horizontal flip, affine,
    perspective, colour jitter and a trailing ``GetItemTransform``.
    """
    flips = [
        kaug.RandomVerticalFlip(p=p_flip_vertical),
        kaug.RandomHorizontalFlip(p=p_flip_horizontal),
    ]
    geometric = [
        kaug.RandomAffine(p=p_affine, degrees=max_rotation, scale=(1.0, max_zoom)),
        kaug.RandomPerspective(p=p_affine, distortion_scale=max_warp),
    ]
    lighting = kaug.ColorJitter(
        p=p_lighting, brightness=max_lighting, contrast=max_lighting
    )
    # TODO: the kornia transforms work on batches and so by default they
    # add a batch dimension. Until transforms are applied per batch (and only
    # while training), this trailing transform removes the batch dimension
    # again.
    strip_batch_dim = GetItemTransform()
    return flips + geometric + [lighting, strip_batch_dim]
Ejemplo n.º 3
0
def generate_kornia_transforms(image_size=224, resize=256, mean=None, std=None, include_jitter=False):
    """Build train/val transform pipelines as ``nn.Sequential`` modules.

    Args:
        image_size: side length of the square crop produced by both pipelines.
        resize: side length of the square resize applied before cropping.
        mean: per-channel mean for ``K.Normalize``. ``None`` or an empty
            sequence selects ``[0.5, 0.5, 0.5]``.
        std: per-channel std for ``K.Normalize``. ``None`` or an empty
            sequence selects ``[0.1, 0.1, 0.1]``.
        include_jitter: when true, a ``K.ColorJitter`` stage is inserted
            right after the resize in the training pipeline.

    Returns:
        dict with keys ``"train"`` and ``"val"``, each an ``nn.Sequential``.
        When CUDA is available, the statistics and both pipelines are moved
        to the GPU.
    """
    # The defaults are ``None`` rather than ``[]`` so that no mutable object
    # is shared between calls; the truthiness test keeps ``[]`` working too.
    mean = torch.tensor(mean) if mean else torch.tensor([0.5, 0.5, 0.5])
    std = torch.tensor(std) if std else torch.tensor([0.1, 0.1, 0.1])

    # Probe CUDA once and reuse the answer for tensors and modules alike.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        mean = mean.cuda()
        std = std.cuda()

    train_transforms = [G.Resize((resize, resize))]
    if include_jitter:
        train_transforms.append(K.ColorJitter(brightness=0.4, contrast=0.4,
                                              saturation=0.4, hue=0.1))
    train_transforms.extend([
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        K.RandomRotation(90),
        K.RandomResizedCrop((image_size, image_size)),
        K.Normalize(mean, std),
    ])

    val_transforms = [
        G.Resize((resize, resize)),
        K.CenterCrop((image_size, image_size)),
        K.Normalize(mean, std),
    ]

    transforms = dict(train=nn.Sequential(*train_transforms),
                      val=nn.Sequential(*val_transforms))
    if use_cuda:
        for k in transforms:
            transforms[k] = transforms[k].cuda()
    return transforms
Ejemplo n.º 4
0
def default_aug(image_size: Tuple[int, int] = (360, 360)) -> nn.Module:
    """Return the default augmentation pipeline as one ``nn.Sequential``.

    Stages, in order: colour jitter (p=0.8), vertical flip, horizontal flip,
    a 3x3 Gaussian blur wrapped in ``RandomApply`` (p=0.1), a random resized
    crop to ``image_size`` with scale range (0.5, 1), and a per-channel
    normalisation.
    """
    stages = [
        aug.ColorJitter(contrast=0.1, brightness=0.1, saturation=0.1, p=0.8),
        aug.RandomVerticalFlip(),
        aug.RandomHorizontalFlip(),
    ]

    blur = filters.GaussianBlur2d((3, 3), (0.5, 0.5))
    stages.append(RandomApply(blur, p=0.1))

    stages.append(aug.RandomResizedCrop(size=image_size, scale=(0.5, 1)))

    channel_mean = torch.tensor([0.485, 0.456, 0.406])
    channel_std = torch.tensor([0.229, 0.224, 0.225])
    stages.append(aug.Normalize(mean=channel_mean, std=channel_std))

    return nn.Sequential(*stages)
Ejemplo n.º 5
0
 def __init__(self, viz: bool = False):
     """Set up the augmentation pipeline and the de-normalisation op.

     Args:
         viz: flag stored on the instance as ``self.viz``; it is not used
             inside this constructor.
     """
     super().__init__()
     self.viz = viz
     # Disabled alternative, kept for reference. It used to live in a no-op
     # string literal; a comment expresses the same thing without leaving a
     # stray expression statement:
     #   self.geometric = [
     #       K.augmentation.RandomAffine(60., p=0.75),
     #   ]
     self.augmentations = nn.Sequential(
         augmentation.RandomRotation(degrees=30.),
         augmentation.RandomPerspective(distortion_scale=0.4),
         augmentation.RandomResizedCrop((224, 224)),
         augmentation.RandomHorizontalFlip(p=0.5),
         augmentation.RandomVerticalFlip(p=0.5),
         # K.augmentation.GaussianBlur((3, 3), (0.1, 2.0), p=1.0),
         # K.augmentation.ColorJitter(0.01, 0.01, 0.01, 0.01, p=0.25),
     )
     # Built from the dataset-level mean/std constants.
     self.denorm = augmentation.Denormalize(Tensor(DATASET_IMAGE_MEAN), Tensor(DATASET_IMAGE_STD))
Ejemplo n.º 6
0
    def __init__(self, utes, mask, ct, length, opt):
        """Store the training volumes and build the shared spatial augmentation.

        Args:
            utes: input volumes; ``utes.shape[0]`` is recorded as ``num_vols``.
            mask: stored as ``self.mask``.
            ct: stored as ``self.label``.
            length: stored as ``self.length``.
            opt: options object; only ``opt.trainBatchSize`` is read here.
        """
        super(TrainDataset, self).__init__()

        # Raw inputs are stored untouched; ``ct`` is exposed as ``label``.
        self.utes, self.mask, self.label = utes, mask, ct
        self.length = length
        self.num_vols = utes.shape[0]

        self.batch_size = opt.trainBatchSize

        # Every stage is created with same_on_batch=True.
        affine = ka.RandomAffine(45,
                                 translate=(0.1, 0.1),
                                 scale=(0.85, 1.15),
                                 shear=(0.1, 0.1),
                                 same_on_batch=True)
        vertical_flip = ka.RandomVerticalFlip(same_on_batch=True)
        horizontal_flip = ka.RandomHorizontalFlip(same_on_batch=True)
        self.spatial = nn.Sequential(affine, vertical_flip, horizontal_flip)

        self.dim = 2
        self.counter = 0
Ejemplo n.º 7
0
 def __init__(self, probability: float = 0.1):
     """Remember ``probability`` and build a vertical-flip op that uses it."""
     self._probability = probability
     # The op receives the value that was just stored.
     self._operation = aug.RandomVerticalFlip(p=self._probability)
Ejemplo n.º 8
0
class TestVideoSequential:
    """Tests for ``K.VideoSequential`` in both "BCTHW" and "BTCHW" layouts."""

    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6),
                                       (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        """Expect an ``AssertionError`` for inputs that are not 5-dimensional.

        The parametrized shapes have 2, 3, 4 and 6 dimensions.
        """
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                                     data_format=data_format,
                                     same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]),
                        torch.tensor([0.5, 0.5, 0.5]),
                        p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]),
                          torch.tensor([0.5, 0.5, 0.5]),
                          p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        """Run ``reproducibility_test`` on each single augmentation (p=1.0).

        Each augmentation is wrapped in a ``VideoSequential`` with
        ``same_on_frame=True``.
        """
        # One random clip duplicated along the batch axis, scaled by 1/255.
        # It is created BEFORE seeding, so its values are not fixed by the
        # seed below.
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device,
                              dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation,
                                     data_format=data_format,
                                     same_on_frame=True)
        # NOTE(review): reproducibility_test is defined outside this chunk;
        # presumably it checks that stored parameters replay the same output.
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0),
             kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply',
                             [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        """Assert identical input frames stay identical with ``same_on_frame=True``.

        The input is one random frame repeated four times along the frame
        axis; all four output frames must be equal.
        """
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)

        if data_format == 'BCTHW':
            # Frame axis is dim 2 in this layout.
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            # When the pipeline reports return_label, drop the second element.
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            # Frame axis is dim 1 in this layout.
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            # When the pipeline reports return_label, drop the second element.
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        """Compare ``VideoSequential(same_on_frame=False)`` with ``nn.Sequential``.

        The same augmentation objects are applied once through
        ``VideoSequential`` on the 5-D input and once through a plain
        ``torch.nn.Sequential`` on the frames flattened into the batch axis.
        Both runs are seeded with 0 and must give equal outputs.
        """
        aug_list_1 = K.VideoSequential(*augmentations,
                                       data_format=data_format,
                                       same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)

        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device,
                                dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device,
                                dtype=dtype).repeat(1, 4, 1, 1, 1)

        torch.manual_seed(0)
        output_1 = aug_list_1(input)

        # Re-seed so the reference run draws the same random numbers.
        torch.manual_seed(0)
        if data_format == 'BCTHW':
            # Move the frame axis next to the batch axis before flattening.
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            # Restore the original BCTHW layout for the comparison.
            output_2 = output_2.transpose(1, 2)
        # On failure the assertion message shows the pipeline's _params.
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        """Compare eager and ``torch.jit.script`` outputs (currently skipped)."""
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1),
                               same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
Ejemplo n.º 9
0
def get_gpu_transforms(augs: DictConfig, mode: str = '2d') -> dict:
    """Makes GPU augmentations from the augs section of a configuration.

    Parameters
    ----------
    augs : DictConfig
        Augmentation parameters. Fields read here: ``LR``, ``UD``,
        ``degrees``, ``brightness``, ``contrast``, ``saturation``, ``hue``,
        ``color_p``, ``grayscale``, ``normalization.mean`` and
        ``normalization.std``. An augmentation is only added when its
        controlling value is greater than zero.
    mode : str, optional
        If '2d', stacks clip in channels. If 3d, returns 5-D tensor, by default '2d'

    Returns
    -------
    xform : dict
        keys: ['train', 'val', 'test', 'denormalize']. Values: a nn.Sequential.
        'train' applies the Kornia augmentations followed by normalization;
        'val' and 'test' are the same object and only convert and normalize;
        'denormalize' undoes the normalization (and, in '2d' mode, first
        unstacks the clip).
        Example: auged_images = xform['train'](images)
    """
    # input is a tensor of shape N x C x F x H x W
    kornia_transforms = []

    if augs.LR > 0:
        kornia_transforms.append(K.RandomHorizontalFlip(p=augs.LR,
                                                        same_on_batch=False,
                                                        return_transform=False))
    if augs.UD > 0:
        kornia_transforms.append(K.RandomVerticalFlip(p=augs.UD,
                                                      same_on_batch=False,
                                                      return_transform=False))
    if augs.degrees > 0:
        kornia_transforms.append(K.RandomRotation(augs.degrees))

    # Colour jitter is enabled as soon as any one of its four strengths is set.
    if augs.brightness > 0 or augs.contrast > 0 or augs.saturation > 0 or augs.hue > 0:
        kornia_transforms.append(K.ColorJitter(brightness=augs.brightness,
                                               contrast=augs.contrast,
                                               saturation=augs.saturation,
                                               hue=augs.hue,
                                               p=augs.color_p,
                                               same_on_batch=False,
                                               return_transform=False))
    if augs.grayscale > 0:
        kornia_transforms.append(K.RandomGrayscale(p=augs.grayscale))

    # Normalization is kept out of the VideoSequential and applied as its own
    # stage after it.
    norm = NormalizeVideo(mean=augs.normalization.mean,
                          std=augs.normalization.std)

    kornia_transforms = VideoSequential(*kornia_transforms,
                                        data_format='BCTHW',
                                        same_on_frame=True)

    # The pipelines are built once, here; the earlier placeholder lists were
    # dead assignments that only created unused ToFloat() modules.
    train_transforms = [ToFloat(),
                        kornia_transforms,
                        norm]
    val_transforms = [ToFloat(),
                      norm]

    denormalize = []
    if mode == '2d':
        train_transforms.append(StackClipInChannels())
        val_transforms.append(StackClipInChannels())
        # Unstacking comes first so the inverse runs in reverse order.
        denormalize.append(UnstackClip())
    denormalize.append(DenormalizeVideo(mean=augs.normalization.mean,
                                        std=augs.normalization.std))

    train_transforms = nn.Sequential(*train_transforms)
    val_transforms = nn.Sequential(*val_transforms)
    denormalize = nn.Sequential(*denormalize)

    gpu_transforms = dict(train=train_transforms,
                          val=val_transforms,
                          test=val_transforms,
                          denormalize=denormalize)
    log.info('GPU transforms: {}'.format(gpu_transforms))
    return gpu_transforms
    def __init__(
        self,
        net,
        image_size,
        hidden_layer=-2,
        project_hidden=True,
        project_dim=128,
        augment_both=True,
        use_nt_xent_loss=False,
        augment_fn=None,
        use_bilinear=False,
        use_momentum=False,
        momentum_value=0.999,
        key_encoder=None,
        temperature=0.1,
        batch_size=128,
    ):
        """Wrap ``net`` and store the learner's configuration.

        Args:
            net: network wrapped in ``OutputHiddenLayer``.
            image_size: side length of the square images; used for the
                default crop, for ``self.h`` / ``self.w`` and for the mock
                input sent through ``forward`` at the end.
            hidden_layer: passed as ``layer`` to ``OutputHiddenLayer``.
            project_hidden: stored as ``self.project_hidden``.
            project_dim: stored as ``self.project_dim``.
            augment_both: stored as ``self.augment_both``.
            use_nt_xent_loss: stored as ``self.use_nt_xent_loss``.
            augment_fn: augmentation to use; ``default(augment_fn,
                DEFAULT_AUG)`` picks the pipeline built below otherwise
                (presumably when it is ``None`` -- ``default`` is defined
                outside this chunk).
            use_bilinear: stored as ``self.use_bilinear``.
            use_momentum: stored as ``self.use_momentum``.
            momentum_value: passed to ``EMA``.
            key_encoder: stored as ``self.key_encoder``.
            temperature: stored as ``self.temperature``.
            batch_size: stored as ``self.b``.
        """
        super().__init__()
        self.net = OutputHiddenLayer(net, layer=hidden_layer)

        # Default augmentation pipeline; it is built even when augment_fn is
        # supplied.
        DEFAULT_AUG = nn.Sequential(
            # RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            # augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            augs.RandomVerticalFlip(),
            augs.RandomSolarize(),
            augs.RandomPosterize(),
            augs.RandomSharpness(),
            augs.RandomEqualize(),
            augs.RandomRotation(degrees=8.0),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomResizedCrop((image_size, image_size), p=0.1),
        )
        self.b = batch_size
        self.h = image_size
        self.w = image_size
        self.augment = default(augment_fn, DEFAULT_AUG)

        self.augment_both = augment_both

        self.temperature = temperature
        self.use_nt_xent_loss = use_nt_xent_loss

        self.project_hidden = project_hidden
        # Starts as None; presumably created lazily in forward() -- TODO confirm.
        self.projection = None
        self.project_dim = project_dim

        self.use_bilinear = use_bilinear
        # Starts as None; presumably created lazily in forward() -- TODO confirm.
        self.bilinear_w = None

        self.use_momentum = use_momentum
        self.ema_updater = EMA(momentum_value)
        self.key_encoder = key_encoder

        # for accumulating queries and keys across calls
        self.queries = None
        self.keys = None
        # Mock batch: a tuple of three random 1x3xHxW tensors plus a
        # one-element tensor.
        random_data = (
            (
                torch.randn(1, 3, image_size, image_size),
                torch.randn(1, 3, image_size, image_size),
                torch.randn(1, 3, image_size, image_size),
            ),
            torch.tensor([1]),
        )
        # send a mock image tensor to instantiate parameters
        self.forward(random_data)