Esempio n. 1
0
def init_mnist_cifar_transforms(n_channels=1, stage='train'):
    """Build the solt-based preprocessing pipeline for MNIST/CIFAR-like data.

    Args:
        n_channels: 1 selects single-channel normalisation statistics,
            3 selects three-channel statistics.
        stage: ``'train'`` returns the augmenting pipeline; any other value
            returns the evaluation pipeline.

    Returns:
        A ``Compose`` pipeline.

    Raises:
        ValueError: if ``n_channels`` is neither 1 nor 3.
    """
    if n_channels not in (1, 3):
        raise ValueError("Not support channels of {}".format(n_channels))

    if n_channels == 1:
        mean, std = (0.1307, ), (0.3081, )
    else:
        mean, std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
    normalize = Normalize(mean, std)

    if stage == 'train':
        # Mild random affine jitter, then pad to 34 and randomly crop 32.
        augmentations = slc.Stream([
            slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
            slt.RandomShear(range_x=(-0.05, 0.05), p=0.5),
            slt.RandomRotate(rotation_range=(-5, 5), p=0.5),
            slt.PadTransform(pad_to=34),
            slt.CropTransform(crop_size=32, crop_mode='r'),
        ])
        return Compose([
            wrap2solt,
            augmentations,
            unpack_solt,
            ApplyTransform(normalize),
        ])

    # Evaluation: no randomness, only padding to 32.
    return Compose([
        wrap2solt,
        slt.PadTransform(pad_to=32),
        unpack_solt,
        ApplyTransform(normalize),
    ])
Esempio n. 2
0
def init_mnist_transforms():
    """Return ``(train_trf, test_trf)``; both pad to 32 and normalise.

    The two pipelines are configured identically (no random augmentation)
    but are distinct ``Compose`` objects that share one ``Normalize``.
    """
    normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def _make_pipeline():
        return Compose([
            wrap2solt,
            slt.PadTransform(pad_to=32),
            unpack_solt,
            ApplyTransform(normalize, [0, 1]),
        ])

    return _make_pipeline(), _make_pipeline()
Esempio n. 3
0
def test_reflective_padding_cant_be_applied_to_kpts():
    """Reflective padding (``padding='r'``) on keypoints must raise ValueError."""
    points = np.array([[0, 0], [0, 1], [1, 0], [2, 0]]).reshape((4, 2))
    container = sld.DataContainer((1, sld.KeyPoints(points, 3, 4)), 'LP')
    reflect_pad = slt.PadTransform(pad_to=(10, 10), padding='r')

    with pytest.raises(ValueError):
        reflect_pad(container)
Esempio n. 4
0
def get_landmark_transform(config):
    """Training augmentation pipeline for landmark data.

    Steps, in order: random flip, scale and rotation; a random projection
    built from an inner affine stream; zero-padding to
    ``int(1.4 * config.dataset.crop_size)``; a random crop of
    ``config.dataset.crop_size``; a ``SelectiveStream`` over several
    noise/blur options; gamma correction. Finally the result goes through
    ``SOLTtoHourGlassGSinput`` and items 0 and 1 are converted by
    ``dwutils.npg2tens``.
    """
    crop_size = config.dataset.crop_size

    projection = slt.RandomProjection(
        affine_transforms=slc.Stream([
            slt.RandomScale(range_x=(0.8, 1.3), p=1),
            slt.RandomRotate(rotation_range=(-180, 180), p=1),
            slt.RandomShear(range_x=(-0.1, 0.1), range_y=(0, 0), p=0.5),
            slt.RandomShear(range_y=(-0.1, 0.1), range_x=(0, 0), p=0.5),
        ]),
        v_range=(1e-5, 2e-3),
        p=0.8)

    degradations = slc.SelectiveStream([
        slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
        slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
        slc.Stream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.05),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
        ]),
        slc.Stream([
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        ]),
        slc.Stream(),  # "no degradation" option
    ])

    augment = slc.Stream([
        slt.RandomFlip(p=0.5, axis=1),
        slt.RandomScale(range_x=(0.8, 1.2), p=1),
        slt.RandomRotate(rotation_range=(-180, 180), p=0.2),
        projection,
        slt.PadTransform(pad_to=int(crop_size * 1.4), padding='z'),
        slt.CropTransform(crop_size=crop_size, crop_mode='r'),
        degradations,
        slt.ImageGammaCorrection(p=1, gamma_range=(0.5, 1.5)),
    ])

    return transforms.Compose([
        augment,
        SOLTtoHourGlassGSinput(downsample=4, sigma=3),
        ApplyTransformByIndex(transform=dwutils.npg2tens, ids=[0, 1]),
    ])
Esempio n. 5
0
def init_loader(metadata, args, snapshots_root):
    """Create a sequential DataLoader that applies test-time transforms.

    Each image is zero-padded to 700x700, centre-cropped to 700x700, resized
    to 310x310 (bicubic) and converted ``gs2rgb``. Item 0 is then converted
    with ``ToTensor``, normalised with the stored mean/std, and split with
    ``five_crop(size=300)``.
    """
    mean_vector, std_vector = session.init_mean_std(snapshots_root, None, None, None)
    normalize = tv_transforms.Normalize(mean_vector.tolist(), std_vector.tolist())

    geometry = slc.Stream([
        slt.PadTransform(pad_to=(700, 700), padding='z'),
        slt.CropTransform(crop_size=(700, 700), crop_mode='c'),
        slt.ResizeTransform(resize_to=(310, 310), interpolation='bicubic'),
        slt.ImageColorTransform(mode='gs2rgb'),
    ], interpolation='bicubic')

    # Applied in this order to element 0 of the unpacked sample.
    image_steps = [
        tv_transforms.ToTensor(),
        normalize,
        partial(five_crop, size=300),
    ]
    tta_trf = tv_transforms.Compose(
        [img_labels2solt, geometry, unpack_solt_data]
        + [partial(apply_by_index, transform=step, idx=0) for step in image_steps])

    dataset = OAProgressionDataset(dataset=args.dataset_root,
                                   split=metadata, trf=tta_trf)
    return DataLoader(dataset,
                      batch_size=args.bs,
                      sampler=SequentialSampler(dataset),
                      num_workers=args.n_threads)
Esempio n. 6
0
def init_mnist_transforms():
    """Return ``(train_trf, test_trf)`` working at 64x64 resolution.

    Training: bilinear resize to 64x64, random scale/shear/rotation, pad to
    70, random 64x64 crop, additive Gaussian noise, then normalisation with
    mean 0.5 / std 0.5. Testing: bilinear resize to 64x64 and the same
    normalisation.
    """
    augmentations = slc.Stream([
        slt.ResizeTransform(resize_to=(64, 64), interpolation='bilinear'),
        slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
        slt.RandomShear(range_x=(-0.05, 0.05), p=0.5),
        slt.RandomRotate(rotation_range=(-10, 10), p=0.5),
        slt.PadTransform(pad_to=70),
        slt.CropTransform(crop_size=64, crop_mode='r'),
        slt.ImageAdditiveGaussianNoise(p=1.0),
    ])
    train_trf = Compose([
        wrap2solt,
        augmentations,
        unpack_solt,
        ApplyTransform(Normalize((0.5, ), (0.5, ))),
    ])

    resize = slt.ResizeTransform(resize_to=(64, 64), interpolation='bilinear')
    test_trf = Compose([
        wrap2solt,
        resize,
        unpack_solt,
        ApplyTransform(Normalize((0.5, ), (0.5, ))),
    ])

    return train_trf, test_trf
Esempio n. 7
0
def custom_augment(img):
    """Augment channels 0 and 1 of an image (or batch) independently.

    Args:
        img: a 3-D array (a single image) or a 4-D batch whose first axis is
            the batch axis. The indexing below assumes a channels-last layout
            with at least two channels (``[..., 0:1]`` and ``[..., 1:2]``) —
            TODO confirm with callers.

    Returns:
        A tensor stacked along dim 0 over batch items. Each item is the two
        augmented channels concatenated along dim 0.
    """
    if len(img.shape) == 3:
        # ``expand_dims`` is a numpy module function; ndarray has no such
        # method, so ``img.expand_dims(...)`` always raised AttributeError.
        imgs = np.expand_dims(img, axis=0)
    else:
        imgs = img

    # Build the pipeline once per call instead of once per batch item.
    tr = Compose([
        wrap2solt,
        slc.Stream([
            slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
            slt.RandomRotate(p=1, rotation_range=(-10, 10)),
            slt.PadTransform(pad_to=int(STD_SZ[0] * 1.05)),
            slt.CropTransform(crop_size=STD_SZ[0], crop_mode='r'),
            slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        ]), unpack_solt,
        ApplyTransform(Normalize((0.5, ), (0.5, )))
    ])

    out_imgs = []
    # Iterate over the batch axis of ``imgs`` (not ``img``). For 3-D input,
    # ``img.shape[0]`` is a spatial size, not the batch size.
    for b in range(imgs.shape[0]):
        img1 = imgs[b, :, :, 0:1].astype(np.uint8)
        img2 = imgs[b, :, :, 1:2].astype(np.uint8)

        # The pipeline is invoked separately per channel, so random
        # parameters are presumably drawn independently for each channel.
        img1, _ = tr((img1, 0))
        img2, _ = tr((img2, 0))

        out_imgs.append(torch.cat((img1, img2), dim=0))
    return torch.stack(out_imgs, dim=0)
Esempio n. 8
0
def get_wrist_fracture_transformation(crop_size):
    """Training pipeline for wrist-fracture images.

    The pipeline flips, applies a random projection built from scale, shear
    and rotation, zero-pads to 256x256, randomly crops ``crop_size``, applies
    one of the noise options, applies gamma correction, and converts item 0
    to a tensor.
    """
    affine = slc.Stream([
        slt.RandomScale(range_x=(0.8, 1.2), p=1),
        slt.RandomShear(range_x=(-0.1, 0.1), p=0.5),
        slt.RandomShear(range_y=(-0.1, 0.1), p=0.5),
        slt.RandomRotate(rotation_range=(-10, 10), p=1),
    ])

    noise = slc.SelectiveStream([
        slc.SelectiveStream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(11, )),
        ]),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
    ])

    augment = slc.Stream([
        # NOTE(review): p=1 means the flip is always applied, so every sample
        # is mirrored. Confirm this is intended rather than a random flip.
        slt.RandomFlip(p=1, axis=1),
        slt.RandomProjection(affine_transforms=affine,
                             v_range=(1e-5, 5e-4),
                             p=0.8),
        # NOTE(review): the pad size is fixed at 256; a crop_size larger than
        # 256 would not be covered by this padding.
        slt.PadTransform(pad_to=(256, 256), padding='z'),
        slt.CropTransform(crop_size, crop_mode='r'),
        noise,
        slt.ImageGammaCorrection(p=1, gamma_range=(0.5, 1.5)),
    ])

    return transforms.Compose([
        SplitDataToFunction(wrap_img_target_solt),
        augment,
        DataToFunction(solt_to_img_target),
        ApplyByIndex(transforms.ToTensor(), 0),
    ])
Esempio n. 9
0
def test_padding_img_mask_3x4_5x5(img_3x4, mask_3x4):
    """Padding an image+mask container to (5, 5) gives 5x5 image and mask."""
    container = sld.DataContainer((img_3x4, mask_3x4), 'IM')
    padded = slt.PadTransform((5, 5))(container)

    for idx in (0, 1):
        item = padded[idx][0]
        assert item.shape[0] == 5 and item.shape[1] == 5
Esempio n. 10
0
def test_padding_img_mask_2x2_3x3(img_2x2, mask_2x2):
    """Padding an image+mask container to (3, 3) gives 3x3 image and mask."""
    container = sld.DataContainer((img_2x2, mask_2x2), 'IM')
    padded = slt.PadTransform((3, 3))(container)

    for idx in (0, 1):
        item = padded[idx][0]
        assert item.shape[0] == 3 and item.shape[1] == 3
Esempio n. 11
0
    def custom_augment(img):
        """Augment a single image or every image of a batch.

        ``img`` must be 3-D (one image) or 4-D (a batch along axis 0). Each
        image is cast to uint8 and passed through the augmentation pipeline.
        The transformed images are stacked along dim 0.

        Raises:
            ValueError: if ``img`` is neither 3-D nor 4-D.
        """
        augment = Compose([
            wrap2solt,
            slc.Stream([
                slt.ResizeTransform(resize_to=(32, 32),
                                    interpolation='bilinear'),
                slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
                slt.RandomFlip(axis=1, p=0.5),
                slt.RandomRotate(rotation_range=(-5, 5), p=0.5),
                slt.PadTransform(pad_to=36),
                slt.CropTransform(crop_size=32, crop_mode='r'),
                slt.ImageAdditiveGaussianNoise(p=1.0)
            ]),
            unpack_solt,
            ApplyTransform(norm_mean_std)
        ])

        n_dims = len(img.shape)
        if n_dims not in (3, 4):
            raise ValueError('Expect num of dims 3 or 4, but got {}'.format(
                n_dims))
        batch = np.expand_dims(img, axis=0) if n_dims == 3 else img

        results = []
        for sample in batch:
            transformed, _ = augment((sample.astype(np.uint8), 0))
            results.append(transformed)

        return torch.stack(results, dim=0)
Esempio n. 12
0
def test_2x2_pad_to_20x20_center_crop_2x2(pad_size, crop_size, img_2x2,
                                          mask_2x2):
    """Pad then crop (default crop mode) must give back the original 2x2 data."""
    corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]]).reshape((4, 2))
    keypoints = sld.KeyPoints(corners, 2, 2)
    container = sld.DataContainer((img_2x2, mask_2x2, keypoints), 'IMP')

    pad_then_crop = slc.Stream([
        slt.PadTransform(pad_to=pad_size),
        slt.CropTransform(crop_size=crop_size),
    ])
    out = pad_then_crop(container)
    out_img, out_mask, out_kpts = out[0][0], out[1][0], out[2][0]

    # Shapes are restored ...
    assert out_img.shape[0] == 2 and out_img.shape[1] == 2
    assert out_mask.shape[0] == 2 and out_mask.shape[1] == 2
    assert out_kpts.H == 2 and out_kpts.W == 2

    # ... and so is the content.
    assert np.array_equal(out_img, img_2x2)
    assert np.array_equal(out_mask, mask_2x2)
    assert np.array_equal(out_kpts.data, corners)
Esempio n. 13
0
def test_pad_does_not_change_the_data_when_the_image_and_the_mask_are_big(
        pad_size, pad_type, img_3x3, mask_3x3):
    """Inputs already large enough for ``pad_size`` must pass through unchanged."""
    pad = slt.PadTransform(pad_to=pad_size, padding=pad_type)
    result = pad(sld.DataContainer((img_3x3, mask_3x3), 'IM'))

    expected = (img_3x3, mask_3x3)
    for idx in range(2):
        np.testing.assert_array_equal(result.data[idx], expected[idx])
Esempio n. 14
0
def get_landmark_transform_kneel(config):
    """Training augmentation pipeline for landmark data (KNEEL configuration).

    With probability 0.7 a flip plus random projection is applied, otherwise
    nothing. Then comes zero-padding to ``augs.pad`` and a random crop of
    ``augs.crop``, one degradation from a ``SelectiveStream`` (n=1), and
    gamma correction. An optional cut-out is applied when
    ``config.dataset.use_cutout`` is set. Finally the result is converted via
    ``solt_to_img_target`` and item 0 to a tensor.
    """
    dataset_cfg = config.dataset
    augs = dataset_cfg.augs
    crop_xy = (augs.crop.crop_x, augs.crop.crop_y)

    cutout = slt.ImageCutOut(
        cutout_size=(int(dataset_cfg.cutout * crop_xy[0]),
                     int(dataset_cfg.cutout * crop_xy[1])),
        p=0.5)

    affine = slc.Stream([
        slt.RandomScale(range_x=(0.9, 1.1), p=1),
        slt.RandomRotate(rotation_range=(-90, 90), p=1),
        slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
        slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
    ])
    geometric = slc.SelectiveStream(
        [
            slc.Stream([
                slt.RandomFlip(p=0.5, axis=1),
                slt.RandomProjection(affine_transforms=affine,
                                     v_range=(1e-5, 2e-3),
                                     p=0.5),
            ]),
            slc.Stream(),
        ],
        probs=[0.7, 0.3])

    pad_and_crop = slc.Stream([
        slt.PadTransform((augs.pad.pad_x, augs.pad.pad_y), padding='z'),
        slt.CropTransform(crop_xy, crop_mode='r'),
    ])

    degradations = slc.SelectiveStream([
        slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
        slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
        slc.Stream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.05),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
        ]),
        slc.Stream([
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        ]),
        slc.Stream(),
    ], n=1)

    return transforms.Compose([
        slc.Stream(),
        geometric,
        pad_and_crop,
        degradations,
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        cutout if dataset_cfg.use_cutout else slc.Stream(),
        DataToFunction(solt_to_img_target),
        ApplyByIndex(transforms.ToTensor(), 0),
    ])
Esempio n. 15
0
def init_transforms(nc=1):
    """Build train/test pipelines plus a single-image augmentation callable.

    Args:
        nc: number of channels; 1 or 3 selects the normalisation statistics.

    Returns:
        ``(train_trf, test_trf, custom_augment)``. ``custom_augment(img)``
        runs ``img`` through the same augmentation as ``train_trf`` (with a
        dummy label 0) and returns only the transformed image.

    Raises:
        ValueError: if ``nc`` is neither 1 nor 3.
    """
    if nc == 1:
        norm_mean_std = Normalize((0.1307, ), (0.3081, ))
    elif nc == 3:
        norm_mean_std = Normalize((0.4914, 0.4822, 0.4465),
                                  (0.247, 0.243, 0.261))
    else:
        raise ValueError("Not support channels of {}".format(nc))

    def _build_train_pipeline():
        # Resize to 32, random scale/shear/rotate, pad to 36, random crop 32,
        # additive noise, normalise.
        return Compose([
            wrap2solt,
            slc.Stream([
                slt.ResizeTransform(resize_to=(32, 32),
                                    interpolation='bilinear'),
                slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
                slt.RandomShear(range_x=(-0.05, 0.05), p=0.5),
                slt.RandomRotate(rotation_range=(-10, 10), p=0.5),
                slt.PadTransform(pad_to=36),
                slt.CropTransform(crop_size=32, crop_mode='r'),
                slt.ImageAdditiveGaussianNoise(p=1.0)
            ]),
            unpack_solt,
            ApplyTransform(norm_mean_std)
        ])

    train_trf = _build_train_pipeline()

    test_trf = Compose([
        wrap2solt,
        slt.ResizeTransform(resize_to=(32, 32), interpolation='bilinear'),
        unpack_solt,
        ApplyTransform(norm_mean_std)
    ])

    # Previously this pipeline was duplicated verbatim and rebuilt on every
    # custom_augment call. It is now built once. It is a separate instance
    # from train_trf so the two users do not share transform state objects.
    augment_trf = _build_train_pipeline()

    def custom_augment(img):
        img_tr, _ = augment_trf((img, 0))
        return img_tr

    return train_trf, test_trf, custom_augment
Esempio n. 16
0
def init_augs():
    """Build the training augmentation pipeline and store it in the global KVS.

    The pipeline is saved under the ``'train_trf'`` key. Keypoint jitter and
    cut-out are enabled by ``args.use_target_jitter`` and ``args.use_cutout``.
    """
    kvs = GlobalKVS()
    args = kvs['args']

    cutout = slt.ImageCutOut(cutout_size=(int(args.cutout * args.crop_x),
                                          int(args.cutout * args.crop_y)),
                             p=0.5)
    # plus-minus 1.3 pixels (note carried over from the original author)
    jitter = slt.KeypointsJitter(dx_range=(-0.003, 0.003),
                                 dy_range=(-0.003, 0.003))

    affine = slc.Stream([
        slt.RandomScale(range_x=(0.8, 1.3), p=1),
        slt.RandomRotate(rotation_range=(-90, 90), p=1),
        slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
    ])
    geometric = slc.SelectiveStream([
        slc.Stream([
            slt.RandomFlip(p=0.5, axis=1),
            slt.RandomProjection(affine_transforms=affine,
                                 v_range=(1e-5, 2e-3),
                                 p=0.5),
            slt.RandomScale(range_x=(0.5, 2.5), p=0.5),
        ]),
        slc.Stream(),
    ], probs=[0.7, 0.3])

    pad_and_crop = slc.Stream([
        slt.PadTransform((args.pad_x, args.pad_y), padding='z'),
        slt.CropTransform((args.crop_x, args.crop_y), crop_mode='r'),
    ])

    degradations = slc.SelectiveStream([
        slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
        slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
        slc.Stream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.05),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
        ]),
        slc.Stream([
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        ]),
        slc.Stream(),
    ], n=1)

    pipeline = tvt.Compose([
        jitter if args.use_target_jitter else slc.Stream(),
        geometric,
        pad_and_crop,
        degradations,
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        cutout if args.use_cutout else slc.Stream(),
        partial(solt2torchhm, downsample=None, sigma=None),
    ])
    kvs.update('train_trf', pipeline)
Esempio n. 17
0
def get_landmark_transform_kneel(config):
    """Training augmentation pipeline for landmark data (KNEEL configuration).

    The pipeline starts with several custom low-probability transforms
    (colour padding, triangular masks, low-visibility blend,
    sub-sample/up-scale). Next come optional keypoint jitter, a
    flip/projection stage applied with probability 0.7, zero-padding and a
    random crop, one degradation (n=1), gamma correction and optional
    cut-out. The result is converted with ``solt2torchhm``.
    """
    dataset_cfg = config.dataset
    augs = dataset_cfg.augs
    crop_xy = (augs.crop.crop_x, augs.crop.crop_y)

    cutout = slt.ImageCutOut(cutout_size=(int(dataset_cfg.cutout * crop_xy[0]),
                                          int(dataset_cfg.cutout * crop_xy[1])),
                             p=0.5)
    # plus-minus 1.3 pixels (note carried over from the original author)
    jitter = slt.KeypointsJitter(dx_range=(-0.003, 0.003), dy_range=(-0.003, 0.003))

    custom_transforms = [
        ColorPaddingWithSide(p=0.05, pad_size=10, side=SIDES.RANDOM, color=(50, 100)),
        TriangularMask(p=0.025, arm_lengths=(100, 50), side=SIDES.RANDOM, color=(50, 100)),
        TriangularMask(p=0.025, arm_lengths=(50, 100), side=SIDES.RANDOM, color=(50, 100)),
        LowVisibilityTransform(p=0.05, alpha=0.15, bgcolor=(50, 100)),
        SubSampleUpScale(p=0.01),
    ]

    affine = slc.Stream([
        slt.RandomScale(range_x=(0.9, 1.1), p=1),
        slt.RandomRotate(rotation_range=(-90, 90), p=1),
        slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
    ])
    geometric = slc.SelectiveStream([
        slc.Stream([
            slt.RandomFlip(p=0.5, axis=1),
            slt.RandomProjection(affine_transforms=affine, v_range=(1e-5, 2e-3), p=0.5),
        ]),
        slc.Stream(),
    ], probs=[0.7, 0.3])

    pad_and_crop = slc.Stream([
        slt.PadTransform((augs.pad.pad_x, augs.pad.pad_y), padding='z'),
        slt.CropTransform(crop_xy, crop_mode='r'),
    ])

    degradations = slc.SelectiveStream([
        slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
        slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
        slc.Stream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.05),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
        ]),
        slc.Stream([
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        ]),
        slc.Stream(),
    ], n=1)

    return transforms.Compose(custom_transforms + [
        jitter if augs.use_target_jitter else slc.Stream(),
        geometric,
        pad_and_crop,
        degradations,
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        cutout if dataset_cfg.use_cutout else slc.Stream(),
        partial(solt2torchhm, downsample=None, sigma=None),
    ])
Esempio n. 18
0
def init_train_augs(crop_mode='r', pad_mode='r'):
    """Training pipeline: pad, random horizontal flip, crop, ToTensor on item 0.

    Args:
        crop_mode: forwarded to ``CropTransform`` (``'r'`` is used for a
            random crop elsewhere in this file).
        pad_mode: forwarded as the ``padding`` option of the solt ``Stream``.
    """
    geometric = slc.Stream(
        [
            slt.PadTransform(pad_to=(PAD_TO, PAD_TO)),
            slt.RandomFlip(p=0.5, axis=1),  # horizontal flip
            slt.CropTransform(crop_size=(CROP_SIZE, CROP_SIZE),
                              crop_mode=crop_mode),
        ],
        padding=pad_mode)
    to_tensor = partial(apply_by_index, transform=transforms.ToTensor(), idx=0)

    return transforms.Compose([
        img_labels2solt,
        geometric,
        unpack_solt_data,
        to_tensor,
    ])
Esempio n. 19
0
def init_data_processing(ds):
    """Compute normalisation statistics and register train/val transforms.

    Builds an ``ImageClassificationDataset`` over ``ds`` with the training
    augmentations and computes (or loads) the per-channel mean/std. It then
    stores ``'train_trf'`` and ``'val_trf'`` in the global KVS and pickles the
    session to ``<snapshots>/<dataset_name>/<snapshot_name>/session.pkl``.
    """
    kvs = GlobalKVS()
    args = kvs['args']

    train_augs = init_train_augs(
        crop_mode='r', pad_mode='r')  # random crop, reflective padding

    dataset = ImageClassificationDataset(ds,
                                         split=kvs['metadata'],
                                         color_space=args.color_space,
                                         transformations=train_augs)

    mean_vector, std_vector = trnsfs.init_mean_std(
        dataset=dataset,
        batch_size=args.bs,
        n_threads=args.n_threads,
        # os.path.join instead of manual '/' concatenation: portable, and
        # no doubled separator when ``snapshots`` ends with a slash.
        save_mean_std=os.path.join(args.snapshots, args.dataset_name),
        color_space=args.color_space)

    print('Color space: ', args.color_space)

    print(colored('====> ', 'red') + 'Mean:', mean_vector)
    print(colored('====> ', 'red') + 'Std:', std_vector)

    norm_trf = tv_transforms.Normalize(
        torch.from_numpy(mean_vector).float(),
        torch.from_numpy(std_vector).float())

    train_trf = tv_transforms.Compose(
        [train_augs,
         partial(apply_by_index, transform=norm_trf, idx=0)])

    val_trf = tv_transforms.Compose([
        img_labels2solt,
        slc.Stream([
            slt.PadTransform(pad_to=(PAD_TO, PAD_TO)),
            slt.CropTransform(crop_size=(CROP_SIZE, CROP_SIZE),
                              crop_mode='c'),  # center crop
        ]),
        unpack_solt_data,
        partial(apply_by_index, transform=tv_transforms.ToTensor(), idx=0),
        partial(apply_by_index, transform=norm_trf, idx=0)
    ])

    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(
        os.path.join(args.snapshots, args.dataset_name,
                     kvs['snapshot_name'], 'session.pkl'))
Esempio n. 20
0
def test_matrix_transforms_state_reset(img_5x5, ignore_state, pipeline):
    """Repeated calls must re-sample the rotation matrix most of the time.

    Two identical containers are pushed through the same transform (alone or
    inside a Stream) 50 times. More than half of the iterations must yield a
    different ``transform_matrix_corrected`` and a different output image.
    """
    n_iter = 50
    if pipeline:
        ppl = slc.Stream([
            slt.RandomRotate(rotation_range=(-180, 180),
                             p=1,
                             ignore_state=ignore_state),
            slt.PadTransform(pad_to=(10, 10)),
        ])
    else:
        ppl = slt.RandomRotate(rotation_range=(-180, 180),
                               p=1,
                               ignore_state=ignore_state)

    def _matrix_state():
        # In the Stream case the rotation is the first transform.
        owner = ppl.transforms[0] if pipeline else ppl
        return owner.state_dict['transform_matrix_corrected']

    img_test = img_5x5.copy()
    img_test[0, 0] = 1

    trf_not_eq = 0
    imgs_not_eq = 0
    random.seed(42)
    try:
        for _ in range(n_iter):
            dc1 = sld.DataContainer((img_test.copy(), ), 'I')
            dc2 = sld.DataContainer((img_test.copy(), ), 'I')

            dc1_res = ppl(dc1).data[0].squeeze()
            trf_state1 = _matrix_state()

            dc2_res = ppl(dc2).data[0].squeeze()
            trf_state2 = _matrix_state()

            if not np.array_equal(trf_state1, trf_state2):
                trf_not_eq += 1

            if not np.array_equal(dc1_res, dc2_res):
                imgs_not_eq += 1
    finally:
        # Restore non-deterministic seeding even when the loop raises, so a
        # failure here does not leave later tests with a fixed global seed.
        random.seed(None)

    assert trf_not_eq > n_iter // 2
    assert imgs_not_eq > n_iter // 2
Esempio n. 21
0
    def __init__(self, snapshot_path, mean_std_path, device='cpu', jit_trace=True, logger=None):
        """Load every ``fold_*.pth`` snapshot into an n-fold inference model.

        Args:
            snapshot_path: directory with ``fold_*.pth`` checkpoints and a
                ``session.pkl`` holding the training session (``'args'``).
            mean_std_path: file read with ``np.load`` and unpacked into
                ``(mean_vector, std_vector)`` for normalisation.
            device: torch device the models are moved to.
            jit_trace: when true, the ensemble is traced with
                ``torch.jit.trace`` on a dummy batch.
            logger: logger to use; defaults to ``'Landmark Annotator'``.
        """
        if logger is None:
            logger = logging.getLogger('Landmark Annotator')

        self.logger = logger

        self.fold_snapshots = glob.glob(os.path.join(snapshot_path, 'fold_*.pth'))
        n_folds = len(self.fold_snapshots)
        logger.log(logging.INFO, f'Found {n_folds} snapshots to initialize from')
        models = []
        self.device = device
        # SECURITY: pickle.load (and torch.load below) can execute arbitrary
        # code from the file; only load snapshots from trusted locations.
        with open(os.path.join(snapshot_path, 'session.pkl'), 'rb') as f:
            snapshot_session = pickle.load(f)
        logger.log(logging.INFO, 'Read session snapshot')

        snp_args = snapshot_session['args'][0]

        for snp_name in self.fold_snapshots:
            logger.log(logging.INFO, f'Loading {snp_name} to {device}')
            net = init_model_from_args(snp_args)
            snp = torch.load(snp_name, map_location=device)['model']
            net.load_state_dict(snp)
            models.append(net.eval())

        self.net = NFoldInferenceModel(models).to(self.device)
        self.net.eval()
        # Report the real number of folds instead of a hard-coded "5".
        logger.log(logging.INFO, f'Loaded {n_folds} folds inference model to {device}')
        if jit_trace:
            logger.log(logging.INFO, 'Optimizing with torch.jit.trace')
            # Zero-filled dummy batch. torch.FloatTensor(...) returned
            # uninitialised memory. dtype is pinned to float32 to match it.
            dummy = torch.zeros(2, 3, snp_args.crop_x, snp_args.crop_y,
                                dtype=torch.float32, device=self.device)
            with torch.no_grad():
                self.net = torch.jit.trace(self.net, dummy)
        mean_vector, std_vector = np.load(mean_std_path)

        self.annotator_type = snp_args.annotations
        self.img_spacing = getattr(snp_args, f'{snp_args.annotations}_spacing')

        norm_trf = partial(normalize_channel_wise, mean=mean_vector, std=std_vector)
        norm_trf = partial(apply_by_index, transform=norm_trf, idx=[0, 1])

        self.trf = tvt.Compose([
            partial(wrap_slt, annotator_type=self.annotator_type),
            slc.Stream([
                slt.PadTransform((snp_args.pad_x, snp_args.pad_y), padding='z'),
                slt.CropTransform((snp_args.crop_x, snp_args.crop_y), crop_mode='c'),
            ]),
            partial(unwrap_slt, norm_trf=norm_trf),
        ])
Esempio n. 22
0
def test_matrix_transforms_use_cache_for_different_dc_items_raises_error(
        img_5x5, mask_3x4, pipeline):
    """With ``ignore_state=False``, items of different sizes must raise ValueError."""
    container = sld.DataContainer((img_5x5, mask_3x4), 'IM')

    rotate = slt.RandomRotate(rotation_range=(-180, 180),
                              p=1,
                              ignore_state=False)
    ppl = (slc.Stream([rotate, slt.PadTransform(pad_to=(10, 10))])
           if pipeline else rotate)

    with pytest.raises(ValueError):
        ppl(container)
Esempio n. 23
0
def test_pad_to_20x20_img_mask_keypoints_3x3_kpts_first(img_3x3, mask_3x3):
    """Pad 3x3 data to 20x20 with keypoints first; keypoints shift by (8, 8)."""
    corners = np.array([[0, 0], [0, 2], [2, 2], [2, 0]]).reshape((4, 2))
    container = sld.DataContainer(
        (sld.KeyPoints(corners, 3, 3), img_3x3, mask_3x3), 'PIM')

    out = slt.PadTransform((20, 20))(container)
    kpts_out, img_out, mask_out = out[0][0], out[1][0], out[2][0]

    assert mask_out.shape[0] == 20 and mask_out.shape[1] == 20
    assert img_out.shape[0] == 20 and img_out.shape[1] == 20
    assert kpts_out.H == 20 and kpts_out.W == 20

    # Expected: [[8, 8], [8, 10], [10, 10], [10, 8]].
    assert np.array_equal(kpts_out.data, corners + 8)
Esempio n. 24
0
def init_data_processing():
    """Compute normalisation statistics and register landmark train/val transforms.

    ``init_mean_std`` may return ``(mean, std)`` or ``(mean, std,
    class_weights)``. Only mean and std are used; any other length raises
    ValueError. The existing ``'train_trf'`` entry is wrapped with
    normalisation of item 0. A centre-crop validation pipeline is stored as
    ``'val_trf'``.
    """
    kvs = GlobalKVS()
    args = kvs['args']
    base_trf = kvs['train_trf']

    dataset = LandmarkDataset(data_root=args.dataset_root,
                              split=kvs['metadata'],
                              hc_spacing=args.hc_spacing,
                              lc_spacing=args.lc_spacing,
                              transform=base_trf,
                              ann_type=args.annotations,
                              image_pad=args.img_pad)

    stats = init_mean_std(snapshots_dir=os.path.join(args.workdir, 'snapshots'),
                          dataset=dataset,
                          batch_size=args.bs,
                          n_threads=args.n_threads,
                          n_classes=-1)

    if len(stats) not in (2, 3):
        raise ValueError('Incorrect format of mean/std/class-weights')
    mean_vector, std_vector = stats[0], stats[1]

    normalize_first = partial(
        apply_by_index,
        transform=partial(normalize_channel_wise, mean=mean_vector, std=std_vector),
        idx=0)

    train_trf = tvt.Compose([base_trf, normalize_first])

    val_geometry = slc.Stream([
        slt.PadTransform((args.pad_x, args.pad_y), padding='z'),
        slt.CropTransform((args.crop_x, args.crop_y), crop_mode='c'),
    ])
    val_trf = tvt.Compose([
        val_geometry,
        partial(solt2torchhm, downsample=None, sigma=None),
        normalize_first,
    ])

    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
Esempio n. 25
0
def init_train_augs():
    """Training pipeline for 700x700-padded images reduced to 300x300 crops.

    The pipeline pads and centre-crops to 700x700, resizes to 310x310, adds
    noise, rotates, takes a random 300x300 crop, applies gamma correction and
    converts grayscale to RGB, all with bicubic interpolation and zero
    padding. Item 0 is then converted to a tensor.
    """
    steps = [
        slt.PadTransform(pad_to=(700, 700)),
        slt.CropTransform(crop_size=(700, 700), crop_mode='c'),
        slt.ResizeTransform((310, 310)),
        slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
        slt.RandomRotate(p=1, rotation_range=(-10, 10)),
        slt.CropTransform(crop_size=(300, 300), crop_mode='r'),
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        slt.ImageColorTransform(mode='gs2rgb'),
    ]
    augment = slc.Stream(steps, interpolation='bicubic', padding='z')

    return transforms.Compose([
        img_labels2solt,
        augment,
        unpack_solt_data,
        partial(apply_by_index, transform=transforms.ToTensor(), idx=0),
    ])
Esempio n. 26
0
def init_train_augmentation_pipeline():
    """Image+mask training pipeline: flip, gamma, pad by one pixel, random crop.

    The crop size comes from ``args.crop_x``/``args.crop_y`` in the global KVS.
    Items 0 and 1 are converted with ``gs2tens`` at the end.
    """
    args = GlobalKVS()['args']
    crop_size = (args.crop_x, args.crop_y)
    pad_size = (args.crop_x + 1, args.crop_y + 1)

    augment = slc.Stream([
        slt.RandomFlip(axis=1, p=0.5),
        slt.ImageGammaCorrection(gamma_range=(0.5, 2), p=0.5),
        slt.PadTransform(pad_to=pad_size),
        slt.CropTransform(crop_size=crop_size, crop_mode='r'),
    ])

    return transforms.Compose([
        img_mask2solt,
        augment,
        solt2img_mask,
        partial(apply_by_index, transform=gs2tens, idx=[0, 1]),
    ])
Esempio n. 27
0
def test_6x6_pad_to_20x20_center_crop_6x6_kpts_img(img_6x6):
    """Pad 6x6 keypoints+image to 20x20 and crop back to 6x6; data is unchanged."""
    kpts_data = np.array([[0, 0], [0, 5], [1, 3], [2, 0]]).reshape((4, 2))
    container = sld.DataContainer((sld.KeyPoints(kpts_data, 6, 6), img_6x6), 'PI')

    pad_then_crop = slc.Stream([
        slt.PadTransform((20, 20)),
        slt.CropTransform((6, 6)),
    ])
    out = pad_then_crop(container)
    out_kpts, out_img = out[0][0], out[1][0]

    assert out_img.shape[0] == 6 and out_img.shape[1] == 6
    assert out_kpts.H == 6 and out_kpts.W == 6

    assert np.array_equal(out_img, img_6x6)
    assert np.array_equal(out_kpts.data, kpts_data)
Esempio n. 28
0
def test_different_crop_modes(crop_mode, img_2x2, mask_2x2):
    """Crop mode ``'d'`` is rejected; other modes crop padded 2x2 data back to 2x2."""
    if crop_mode == 'd':
        with pytest.raises(ValueError):
            slt.CropTransform(crop_size=2, crop_mode=crop_mode)
        return

    stream = slc.Stream([
        slt.PadTransform(pad_to=20),
        slt.CropTransform(crop_size=2, crop_mode=crop_mode),
    ])
    result = stream(sld.DataContainer((img_2x2, mask_2x2), 'IM'))

    for item in result.data:
        assert item.shape[0] == 2
        assert item.shape[1] == 2
Esempio n. 29
0
def custom_augment(img):
    """Augment channels 0 and 1 of ``img`` separately and concatenate them.

    Each channel slice is cast to uint8 and passed (with a dummy label 0)
    through the same pipeline. The two results are concatenated along dim 0.
    """
    pipeline = Compose([
        wrap2solt,
        slc.Stream([
            slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
            slt.RandomRotate(p=1, rotation_range=(-10, 10)),
            slt.PadTransform(pad_to=int(STD_SZ[0] * 1.05)),
            slt.CropTransform(crop_size=STD_SZ[0], crop_mode='r'),
            slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        ]),
        unpack_solt,
        ApplyTransform(Normalize((0.5, ), (0.5, ))),
    ])

    augmented = []
    for channel in (0, 1):
        plane = img[:, :, channel:channel + 1].astype(np.uint8)
        transformed, _ = pipeline((plane, 0))
        augmented.append(transformed)

    return torch.cat(tuple(augmented), dim=0)
Esempio n. 30
0
    def custom_augment(img):
        """Apply the 32x32 training augmentation to one image.

        A fresh pipeline is built on every call. ``(img, 0)`` is passed
        through it and only the transformed image is returned; the dummy
        label is discarded.
        """
        steps = [
            slt.ResizeTransform(resize_to=(32, 32), interpolation='bilinear'),
            slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
            slt.RandomShear(range_x=(-0.05, 0.05), p=0.5),
            slt.RandomRotate(rotation_range=(-10, 10), p=0.5),
            slt.PadTransform(pad_to=36),
            slt.CropTransform(crop_size=32, crop_mode='r'),
            slt.ImageAdditiveGaussianNoise(p=1.0),
        ]
        pipeline = Compose([
            wrap2solt,
            slc.Stream(steps),
            unpack_solt,
            ApplyTransform(norm_mean_std),
        ])

        transformed, _label = pipeline((img, 0))
        return transformed