def custom_augment(img):
    """Augment a single image (H, W, C) or a batch (B, H, W, C) of uint8-able
    arrays with the 32x32 training pipeline and return a stacked tensor.

    Raises:
        ValueError: if ``img`` is neither 3- nor 4-dimensional.
    """
    pipeline = Compose([
        wrap2solt,
        slc.Stream([
            slt.ResizeTransform(resize_to=(32, 32), interpolation='bilinear'),
            slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
            slt.RandomFlip(axis=1, p=0.5),
            slt.RandomRotate(rotation_range=(-5, 5), p=0.5),
            slt.PadTransform(pad_to=36),
            slt.CropTransform(crop_size=32, crop_mode='r'),
            slt.ImageAdditiveGaussianNoise(p=1.0),
        ]),
        unpack_solt,
        # NOTE(review): norm_mean_std is resolved at module scope — verify it exists.
        ApplyTransform(norm_mean_std),
    ])
    if len(img.shape) == 3:
        batch = np.expand_dims(img, axis=0)
    elif len(img.shape) == 4:
        batch = img
    else:
        raise ValueError('Expect num of dims 3 or 4, but got {}'.format(
            len(img.shape)))
    augmented = []
    for idx in range(batch.shape[0]):
        sample = batch[idx, :].astype(np.uint8)
        sample, _ = pipeline((sample, 0))
        augmented.append(sample)
    return torch.stack(augmented, dim=0)
def get_landmark_transform(config):
    """Build the landmark training pipeline: geometric augmentation, pad/crop
    to ``config.dataset.crop_size``, photometric noise, then conversion to
    hourglass ground-truth heatmaps and tensors.
    """
    # Random projective distortion composed from affine pieces.
    projection = slt.RandomProjection(affine_transforms=slc.Stream([
        slt.RandomScale(range_x=(0.8, 1.3), p=1),
        slt.RandomRotate(rotation_range=(-180, 180), p=1),
        slt.RandomShear(range_x=(-0.1, 0.1), range_y=(0, 0), p=0.5),
        slt.RandomShear(range_y=(-0.1, 0.1), range_x=(0, 0), p=0.5),
    ]), v_range=(1e-5, 2e-3), p=0.8)
    # Exactly one of these photometric corruptions is sampled per call;
    # the empty Stream() acts as a "no corruption" branch.
    noise = slc.SelectiveStream([
        slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
        slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
        slc.Stream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.05),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
        ]),
        slc.Stream([
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        ]),
        slc.Stream(),
    ])
    return transforms.Compose([
        slc.Stream([
            slt.RandomFlip(p=0.5, axis=1),
            slt.RandomScale(range_x=(0.8, 1.2), p=1),
            slt.RandomRotate(rotation_range=(-180, 180), p=0.2),
            projection,
            slt.PadTransform(int(config.dataset.crop_size * 1.4), padding='z'),
            slt.CropTransform(config.dataset.crop_size, crop_mode='r'),
            noise,
            slt.ImageGammaCorrection(p=1, gamma_range=(0.5, 1.5)),
        ]),
        SOLTtoHourGlassGSinput(downsample=4, sigma=3),
        ApplyTransformByIndex(transform=dwutils.npg2tens, ids=[0, 1]),
    ])
def custom_augment(img):
    """Augment a two-channel image (H, W, 2) or batch (B, H, W, 2): each
    channel is augmented independently and the results are concatenated
    channel-wise, then stacked into a batch tensor.

    Fixes over the original:
      * ``np.expand_dims`` instead of the non-existent ``img.expand_dims``
        ndarray method (AttributeError on any 3-D input).
      * iterate over ``imgs.shape[0]`` (batch size), not ``img.shape[0]``
        (which is the image HEIGHT when a single 3-D image is passed).
      * the transform pipeline is built once, not once per sample.
    """
    if len(img.shape) == 3:
        imgs = np.expand_dims(img, axis=0)  # promote single image to a batch of 1
    else:
        imgs = img
    tr = Compose([
        wrap2solt,
        slc.Stream([
            slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
            slt.RandomRotate(p=1, rotation_range=(-10, 10)),
            slt.PadTransform(pad_to=int(STD_SZ[0] * 1.05)),
            slt.CropTransform(crop_size=STD_SZ[0], crop_mode='r'),
            slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        ]),
        unpack_solt,
        ApplyTransform(Normalize((0.5, ), (0.5, ))),
    ])
    out_imgs = []
    for b in range(imgs.shape[0]):
        # Augment each of the two channels independently.
        img1 = imgs[b, :, :, 0:1].astype(np.uint8)
        img2 = imgs[b, :, :, 1:2].astype(np.uint8)
        img1, _ = tr((img1, 0))
        img2, _ = tr((img2, 0))
        out_imgs.append(torch.cat((img1, img2), dim=0))
    return torch.stack(out_imgs, dim=0)
def init_mnist_transforms():
    """Return ``(train_trf, test_trf)`` pipelines for 64x64 inputs.

    Train: resize, random scale/shear/rotate, pad-to-70 + random 64-crop,
    additive Gaussian noise, then normalization to [-1, 1].
    Test: resize only, then the same normalization.
    """
    normalize = ApplyTransform(Normalize((0.5, ), (0.5, )))
    augmentations = slc.Stream([
        slt.ResizeTransform(resize_to=(64, 64), interpolation='bilinear'),
        slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
        slt.RandomShear(range_x=(-0.05, 0.05), p=0.5),
        slt.RandomRotate(rotation_range=(-10, 10), p=0.5),
        slt.PadTransform(pad_to=70),
        slt.CropTransform(crop_size=64, crop_mode='r'),
        slt.ImageAdditiveGaussianNoise(p=1.0),
    ])
    train_trf = Compose([wrap2solt, augmentations, unpack_solt, normalize])
    test_trf = Compose([
        wrap2solt,
        slt.ResizeTransform(resize_to=(64, 64), interpolation='bilinear'),
        unpack_solt,
        normalize,
    ])
    return train_trf, test_trf
def test_2x2_pad_to_20x20_center_crop_2x2(pad_size, crop_size, img_2x2, mask_2x2):
    """Padding and cropping back to the original 2x2 size must be a no-op
    for images, masks and keypoints alike."""
    kp_array = np.array([[0, 0], [0, 1], [1, 1], [1, 0]]).reshape((4, 2))
    keypoints = sld.KeyPoints(kp_array, 2, 2)
    container = sld.DataContainer((img_2x2, mask_2x2, keypoints), 'IMP')
    pipeline = slc.Stream([
        slt.PadTransform(pad_to=pad_size),
        slt.CropTransform(crop_size=crop_size),
    ])
    result = pipeline(container)
    out_img, out_mask, out_kpts = result[0][0], result[1][0], result[2][0]
    # Shapes are unchanged...
    assert out_img.shape[:2] == (2, 2)
    assert out_mask.shape[:2] == (2, 2)
    assert (out_kpts.H == 2) and (out_kpts.W == 2)
    # ...and so is the content.
    assert np.array_equal(out_img, img_2x2)
    assert np.array_equal(out_mask, mask_2x2)
    assert np.array_equal(out_kpts.data, kp_array)
def init_mnist_cifar_transforms(n_channels=1, stage='train'):
    """Return the transform for MNIST (1 channel) or CIFAR (3 channels).

    Args:
        n_channels: 1 or 3; selects the dataset-specific normalization stats.
        stage: 'train' returns the augmented pipeline, anything else the
            test-time pipeline.

    Raises:
        ValueError: for an unsupported channel count.
    """
    if n_channels == 1:
        norm_mean_std = Normalize((0.1307, ), (0.3081, ))
    elif n_channels == 3:
        norm_mean_std = Normalize((0.4914, 0.4822, 0.4465),
                                  (0.247, 0.243, 0.261))
    else:
        raise ValueError("Not support channels of {}".format(n_channels))
    train_trf = Compose([
        wrap2solt,
        slc.Stream([
            slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
            slt.RandomShear(range_x=(-0.05, 0.05), p=0.5),
            slt.RandomRotate(rotation_range=(-5, 5), p=0.5),
            slt.PadTransform(pad_to=34),
            slt.CropTransform(crop_size=32, crop_mode='r'),
        ]),
        unpack_solt,
        ApplyTransform(norm_mean_std),
    ])
    if stage == 'train':
        return train_trf
    return Compose([
        wrap2solt,
        slt.PadTransform(pad_to=32),
        unpack_solt,
        ApplyTransform(norm_mean_std),
    ])
def init_loader(metadata, args, snapshots_root):
    """Build a sequential DataLoader over the OA progression dataset using a
    test-time-augmentation transform (pad/crop to 700, resize to 310,
    gs->rgb, normalize, five-crop at 300)."""
    mean_vector, std_vector = session.init_mean_std(snapshots_root, None, None, None)
    normalize = tv_transforms.Normalize(mean_vector.tolist(), std_vector.tolist())
    spatial = slc.Stream([
        slt.PadTransform(pad_to=(700, 700), padding='z'),
        slt.CropTransform(crop_size=(700, 700), crop_mode='c'),
        slt.ResizeTransform(resize_to=(310, 310), interpolation='bicubic'),
        slt.ImageColorTransform(mode='gs2rgb'),
    ], interpolation='bicubic')
    tta_trf = tv_transforms.Compose([
        img_labels2solt,
        spatial,
        unpack_solt_data,
        partial(apply_by_index, transform=tv_transforms.ToTensor(), idx=0),
        partial(apply_by_index, transform=normalize, idx=0),
        partial(apply_by_index, transform=partial(five_crop, size=300), idx=0),
    ])
    dataset = OAProgressionDataset(dataset=args.dataset_root,
                                   split=metadata,
                                   trf=tta_trf)
    return DataLoader(dataset,
                      batch_size=args.bs,
                      sampler=SequentialSampler(dataset),
                      num_workers=args.n_threads)
def get_wrist_fracture_transformation(crop_size):
    """Training transform for wrist-fracture data: flip + projective warp,
    zero-pad to 256 and random-crop to ``crop_size``, photometric noise,
    gamma jitter, then conversion back to (image, target) tensors."""
    projection = slt.RandomProjection(affine_transforms=slc.Stream([
        slt.RandomScale(range_x=(0.8, 1.2), p=1),
        slt.RandomShear(range_x=(-0.1, 0.1), p=0.5),
        slt.RandomShear(range_y=(-0.1, 0.1), p=0.5),
        slt.RandomRotate(rotation_range=(-10, 10), p=1),
    ]), v_range=(1e-5, 5e-4), p=0.8)
    noise = slc.SelectiveStream([
        slc.SelectiveStream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(11, )),
        ]),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
    ])
    return transforms.Compose([
        SplitDataToFunction(wrap_img_target_solt),
        slc.Stream([
            slt.RandomFlip(p=1, axis=1),
            projection,
            slt.PadTransform(pad_to=(256, 256), padding='z'),
            slt.CropTransform(crop_size, crop_mode='r'),
            noise,
            slt.ImageGammaCorrection(p=1, gamma_range=(0.5, 1.5)),
        ]),
        DataToFunction(solt_to_img_target),
        ApplyByIndex(transforms.ToTensor(), 0),
    ])
def init_data_processing():
    """Estimate dataset mean/std, assemble train/val transforms, store them
    in the global KVS and persist the session pickle."""
    kvs = GlobalKVS()
    train_augs = init_train_augs()
    dataset = OAProgressionDataset(dataset=kvs['args'].dataset_root,
                                   split=kvs['metadata'],
                                   trf=train_augs)
    mean_vector, std_vector = init_mean_std(snapshots_dir=kvs['args'].snapshots,
                                            dataset=dataset,
                                            batch_size=kvs['args'].bs,
                                            n_threads=kvs['args'].n_threads)
    print(colored('====> ', 'red') + 'Mean:', mean_vector)
    print(colored('====> ', 'red') + 'Std:', std_vector)
    normalize = tv_transforms.Normalize(mean_vector.tolist(), std_vector.tolist())
    # Train: augmentations followed by per-image normalization.
    train_trf = tv_transforms.Compose([
        train_augs,
        partial(apply_by_index, transform=normalize, idx=0),
    ])
    # Val: deterministic resize + center crop, no augmentation.
    val_trf = tv_transforms.Compose([
        img_labels2solt,
        slc.Stream([
            slt.ResizeTransform((310, 310)),
            slt.CropTransform(crop_size=(300, 300), crop_mode='c'),
            slt.ImageColorTransform(mode='gs2rgb'),
        ], interpolation='bicubic'),
        unpack_solt_data,
        partial(apply_by_index, transform=tv_transforms.ToTensor(), idx=0),
        partial(apply_by_index, transform=normalize, idx=0),
    ])
    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                              'session.pkl'))
def init_train_augs():
    """Training augmentations: zero-pad/center-crop to 700, resize to 310,
    noise + rotation, random 300-crop, gamma jitter, gs->rgb, to tensor."""
    augmentations = slc.Stream([
        slt.PadTransform(pad_to=(700, 700)),
        slt.CropTransform(crop_size=(700, 700), crop_mode='c'),
        slt.ResizeTransform((310, 310)),
        slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
        slt.RandomRotate(p=1, rotation_range=(-10, 10)),
        slt.CropTransform(crop_size=(300, 300), crop_mode='r'),
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        slt.ImageColorTransform(mode='gs2rgb'),
    ], interpolation='bicubic', padding='z')
    return transforms.Compose([
        img_labels2solt,
        augmentations,
        unpack_solt_data,
        partial(apply_by_index, transform=transforms.ToTensor(), idx=0),
    ])
def __init__(self):
    """Set up equivalent 64x64 random-crop operations in three libraries
    (solt, imgaug, Augmentor) for benchmarking."""
    # solt: random-position crop to 64x64
    self.solt_stream = slc.Stream(
        [slt.CropTransform(crop_size=(64, 64), crop_mode="r")])
    # imgaug equivalent
    self.imgaug_transform = iaa.CropToFixedSize(width=64, height=64)
    # Augmentor equivalent (non-centred crop, always applied)
    self.augmentor_op = Operations.Crop(probability=1, width=64, height=64,
                                        centre=False)
def get_landmark_transform_kneel(config):
    """KNEEL-style landmark training pipeline driven by ``config``:
    optional geometric warping, zero-pad + random crop, a single sampled
    photometric corruption, gamma jitter and optional cutout."""
    cutout = slt.ImageCutOut(
        cutout_size=(int(config.dataset.cutout * config.dataset.augs.crop.crop_x),
                     int(config.dataset.cutout * config.dataset.augs.crop.crop_y)),
        p=0.5)
    # 70% chance of the geometric branch, 30% chance of identity.
    geometric = slc.SelectiveStream([
        slc.Stream([
            slt.RandomFlip(p=0.5, axis=1),
            slt.RandomProjection(affine_transforms=slc.Stream([
                slt.RandomScale(range_x=(0.9, 1.1), p=1),
                slt.RandomRotate(rotation_range=(-90, 90), p=1),
                # NOTE(review): two identical shear transforms are applied on
                # purpose-unknown grounds; kept as-is — confirm intent.
                slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
                slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
            ]), v_range=(1e-5, 2e-3), p=0.5),
        ]),
        slc.Stream()
    ], probs=[0.7, 0.3])
    pad_crop = slc.Stream([
        slt.PadTransform((config.dataset.augs.pad.pad_x,
                          config.dataset.augs.pad.pad_y), padding='z'),
        slt.CropTransform((config.dataset.augs.crop.crop_x,
                           config.dataset.augs.crop.crop_y), crop_mode='r'),
    ])
    # One corruption is drawn per call; empty Stream() = no corruption.
    noise = slc.SelectiveStream([
        slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
        slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
        slc.Stream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.05),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
        ]),
        slc.Stream([
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        ]),
        slc.Stream()
    ], n=1)
    return transforms.Compose([
        slc.Stream(),
        geometric,
        pad_crop,
        noise,
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        cutout if config.dataset.use_cutout else slc.Stream(),
        DataToFunction(solt_to_img_target),
        ApplyByIndex(transforms.ToTensor(), 0),
    ])
def test_crop_or_cutout_size_are_too_big(img_2x2, cutout_crop_size):
    """Both CropTransform and ImageCutOut must reject a size larger than
    the 2x2 input image."""
    dc = sld.DataContainer((img_2x2, ), 'I')
    for transform in (slt.CropTransform(crop_size=cutout_crop_size),
                      slt.ImageCutOut(p=1, cutout_size=cutout_crop_size)):
        with pytest.raises(ValueError):
            transform(dc)
def test_different_crop_modes(crop_mode, img_2x2, mask_2x2):
    """An unknown crop mode must raise at construction; valid modes must
    produce 2x2 outputs for both image and mask after pad+crop."""
    if crop_mode == 'd':
        # Invalid mode: the constructor itself must reject it.
        with pytest.raises(ValueError):
            slt.CropTransform(crop_size=2, crop_mode=crop_mode)
        return
    pipeline = slc.Stream([
        slt.PadTransform(pad_to=20),
        slt.CropTransform(crop_size=2, crop_mode=crop_mode),
    ])
    result = pipeline(sld.DataContainer((img_2x2, mask_2x2), 'IM'))
    for item in result.data:
        assert item.shape[0] == 2
        assert item.shape[1] == 2
def init_transforms(nc=1):
    """Return ``(train_trf, test_trf, custom_augment)`` for 32x32 inputs.

    Args:
        nc: number of input channels (1 -> MNIST stats, 3 -> CIFAR stats).

    Raises:
        ValueError: for an unsupported channel count.

    Fix over the original: the nested ``custom_augment`` rebuilt an exact
    copy of the training pipeline on every single call; it now closes over
    the pipeline built once here, which is behaviorally identical but
    avoids the per-call construction cost.
    """
    if nc == 1:
        norm_mean_std = Normalize((0.1307, ), (0.3081, ))
    elif nc == 3:
        norm_mean_std = Normalize((0.4914, 0.4822, 0.4465),
                                  (0.247, 0.243, 0.261))
    else:
        raise ValueError("Not support channels of {}".format(nc))
    train_trf = Compose([
        wrap2solt,
        slc.Stream([
            slt.ResizeTransform(resize_to=(32, 32), interpolation='bilinear'),
            slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
            slt.RandomShear(range_x=(-0.05, 0.05), p=0.5),
            slt.RandomRotate(rotation_range=(-10, 10), p=0.5),
            slt.PadTransform(pad_to=36),
            slt.CropTransform(crop_size=32, crop_mode='r'),
            slt.ImageAdditiveGaussianNoise(p=1.0),
        ]),
        unpack_solt,
        ApplyTransform(norm_mean_std),
    ])
    test_trf = Compose([
        wrap2solt,
        slt.ResizeTransform(resize_to=(32, 32), interpolation='bilinear'),
        unpack_solt,
        ApplyTransform(norm_mean_std),
    ])

    def custom_augment(img):
        """Apply the (already-built) training pipeline to a single image."""
        img_tr, _ = train_trf((img, 0))
        return img_tr

    return train_trf, test_trf, custom_augment
def __init__(self):
    """Set up equivalent crop-to-64 + resize-to-512 pipelines in three
    libraries (Augmentor, imgaug, solt) for benchmarking."""
    # Augmentor: operations are applied in insertion order.
    self.augmentor_pipeline = Pipeline()
    for operation in (
            Operations.Crop(probability=1, width=64, height=64, centre=False),
            Operations.Resize(probability=1, width=512, height=512,
                              resample_filter="BILINEAR")):
        self.augmentor_pipeline.add_operation(operation)
    # imgaug equivalent.
    self.imgaug_transform = iaa.Sequential([
        iaa.CropToFixedSize(width=64, height=64),
        iaa.Scale(size=512, interpolation="linear"),
    ])
    # solt equivalent.
    self.solt_stream = slc.Stream([
        slt.CropTransform(crop_size=(64, 64), crop_mode="r"),
        slt.ResizeTransform(resize_to=(512, 512)),
    ])
def init_augs():
    """Build the landmark training augmentation pipeline from CLI args and
    register it in the global KVS under 'train_trf'."""
    kvs = GlobalKVS()
    args = kvs['args']
    cutout = slt.ImageCutOut(cutout_size=(int(args.cutout * args.crop_x),
                                          int(args.cutout * args.crop_y)),
                             p=0.5)
    # roughly +/- 1.3 pixels of keypoint jitter
    jitter = slt.KeypointsJitter(dx_range=(-0.003, 0.003),
                                 dy_range=(-0.003, 0.003))
    # 70% chance of the geometric branch, 30% identity.
    geometric = slc.SelectiveStream([
        slc.Stream([
            slt.RandomFlip(p=0.5, axis=1),
            slt.RandomProjection(affine_transforms=slc.Stream([
                slt.RandomScale(range_x=(0.8, 1.3), p=1),
                slt.RandomRotate(rotation_range=(-90, 90), p=1),
                slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1), p=0.5),
            ]), v_range=(1e-5, 2e-3), p=0.5),
            slt.RandomScale(range_x=(0.5, 2.5), p=0.5),
        ]),
        slc.Stream()
    ], probs=[0.7, 0.3])
    # Exactly one corruption per call; empty Stream() = none.
    noise = slc.SelectiveStream([
        slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
        slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
        slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
        slc.Stream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.05),
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
        ]),
        slc.Stream([
            slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
        ]),
        slc.Stream()
    ], n=1)
    ppl = tvt.Compose([
        jitter if args.use_target_jitter else slc.Stream(),
        geometric,
        slc.Stream([
            slt.PadTransform((args.pad_x, args.pad_y), padding='z'),
            slt.CropTransform((args.crop_x, args.crop_y), crop_mode='r'),
        ]),
        noise,
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        cutout if args.use_cutout else slc.Stream(),
        partial(solt2torchhm, downsample=None, sigma=None),
    ])
    kvs.update('train_trf', ppl)
def get_landmark_transform_kneel(config):
    """KNEEL landmark training pipeline with extra occlusion-style
    corruptions (colored side padding, triangular masks, low-visibility
    overlay, sub-sample/up-scale), optional keypoint jitter, geometric
    warping, pad + random crop, photometric noise and optional cutout.

    Fix over the original: the function ended with a bare ``return``
    followed by ``ppl`` as a separate (dead) statement, so it always
    returned ``None``. It now returns the pipeline.
    """
    cutout = slt.ImageCutOut(
        cutout_size=(int(config.dataset.cutout * config.dataset.augs.crop.crop_x),
                     int(config.dataset.cutout * config.dataset.augs.crop.crop_y)),
        p=0.5)
    # plus-minus 1.3 pixels
    jitter = slt.KeypointsJitter(dx_range=(-0.003, 0.003),
                                 dy_range=(-0.003, 0.003))
    ppl = transforms.Compose([
        ColorPaddingWithSide(p=0.05, pad_size=10, side=SIDES.RANDOM,
                             color=(50,100)),
        TriangularMask(p=0.025, arm_lengths=(100, 50), side=SIDES.RANDOM,
                       color=(50,100)),
        TriangularMask(p=0.025, arm_lengths=(50, 100), side=SIDES.RANDOM,
                       color=(50,100)),
        LowVisibilityTransform(p=0.05, alpha=0.15, bgcolor=(50,100)),
        SubSampleUpScale(p=0.01),
        jitter if config.dataset.augs.use_target_jitter else slc.Stream(),
        slc.SelectiveStream([
            slc.Stream([
                slt.RandomFlip(p=0.5, axis=1),
                slt.RandomProjection(affine_transforms=slc.Stream([
                    slt.RandomScale(range_x=(0.9, 1.1), p=1),
                    slt.RandomRotate(rotation_range=(-90, 90), p=1),
                    slt.RandomShear(range_x=(-0.1, 0.1), range_y=(-0.1, 0.1),
                                    p=0.5),
                ]), v_range=(1e-5, 2e-3), p=0.5),
            ]),
            slc.Stream()
        ], probs=[0.7, 0.3]),
        slc.Stream([
            slt.PadTransform((config.dataset.augs.pad.pad_x,
                              config.dataset.augs.pad.pad_y), padding='z'),
            slt.CropTransform((config.dataset.augs.crop.crop_x,
                               config.dataset.augs.crop.crop_y), crop_mode='r'),
        ]),
        slc.SelectiveStream([
            slt.ImageSaltAndPepper(p=1, gain_range=0.01),
            slt.ImageBlur(p=1, blur_type='g', k_size=(3, 5)),
            slt.ImageBlur(p=1, blur_type='m', k_size=(3, 5)),
            slt.ImageAdditiveGaussianNoise(p=1, gain_range=0.5),
            slc.Stream([
                slt.ImageSaltAndPepper(p=1, gain_range=0.05),
                slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
            ]),
            slc.Stream([
                slt.ImageBlur(p=0.5, blur_type='m', k_size=(3, 5)),
                slt.ImageSaltAndPepper(p=1, gain_range=0.01),
            ]),
            slc.Stream()
        ], n=1),
        slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        cutout if config.dataset.use_cutout else slc.Stream(),
        partial(solt2torchhm, downsample=None, sigma=None),
    ])
    return ppl
def init_train_augs(crop_mode='r', pad_mode='r'):
    """Pad to PAD_TO, random horizontal flip, crop to CROP_SIZE, to tensor.

    Args:
        crop_mode: solt crop mode ('r' = random).
        pad_mode: solt padding mode ('r' = reflective).
    """
    spatial = slc.Stream([
        slt.PadTransform(pad_to=(PAD_TO, PAD_TO)),
        slt.RandomFlip(p=0.5, axis=1),  # horizontal flip
        slt.CropTransform(crop_size=(CROP_SIZE, CROP_SIZE),
                          crop_mode=crop_mode),
    ], padding=pad_mode)
    return transforms.Compose([
        img_labels2solt,
        spatial,
        unpack_solt_data,
        partial(apply_by_index, transform=transforms.ToTensor(), idx=0),
    ])
def init_data_processing(ds):
    """Estimate mean/std for ``ds``, assemble train/val transforms, store
    them in the global KVS and persist the session pickle."""
    kvs = GlobalKVS()
    # random crop, reflective padding
    train_augs = init_train_augs(crop_mode='r', pad_mode='r')
    dataset = ImageClassificationDataset(ds,
                                         split=kvs['metadata'],
                                         color_space=kvs['args'].color_space,
                                         transformations=train_augs)
    mean_vector, std_vector = trnsfs.init_mean_std(
        dataset=dataset,
        batch_size=kvs['args'].bs,
        n_threads=kvs['args'].n_threads,
        save_mean_std=kvs['args'].snapshots + '/' + kvs['args'].dataset_name,
        color_space=kvs['args'].color_space)
    print('Color space: ', kvs['args'].color_space)
    print(colored('====> ', 'red') + 'Mean:', mean_vector)
    print(colored('====> ', 'red') + 'Std:', std_vector)
    normalize = tv_transforms.Normalize(torch.from_numpy(mean_vector).float(),
                                        torch.from_numpy(std_vector).float())
    train_trf = tv_transforms.Compose([
        train_augs,
        partial(apply_by_index, transform=normalize, idx=0),
    ])
    val_trf = tv_transforms.Compose([
        img_labels2solt,
        slc.Stream([
            slt.PadTransform(pad_to=(PAD_TO, PAD_TO)),
            # center crop
            slt.CropTransform(crop_size=(CROP_SIZE, CROP_SIZE), crop_mode='c'),
        ]),
        unpack_solt_data,
        partial(apply_by_index, transform=tv_transforms.ToTensor(), idx=0),
        partial(apply_by_index, transform=normalize, idx=0),
    ])
    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['args'].dataset_name,
                              kvs['snapshot_name'], 'session.pkl'))
def __init__(self, snapshot_path, mean_std_path, device='cpu', jit_trace=True, logger=None):
    """Load an ensemble of per-fold landmark models and build the inference transform.

    Args:
        snapshot_path: directory containing ``fold_*.pth`` weights and ``session.pkl``.
        mean_std_path: .npy file holding the (mean, std) normalization vectors.
        device: torch device string the models are loaded onto.
        jit_trace: if True, optimize the ensemble with ``torch.jit.trace``.
        logger: optional logger; a default 'Landmark Annotator' logger is created if None.
    """
    if logger is None:
        logger = logging.getLogger('Landmark Annotator')
    self.logger = logger
    # One snapshot per cross-validation fold.
    self.fold_snapshots = glob.glob(os.path.join(snapshot_path, 'fold_*.pth'))
    logger.log(logging.INFO, f'Found {len(self.fold_snapshots)} snapshots to initialize from')
    models = []
    self.device = device
    # Session pickle stores the training-time arguments needed to rebuild the model.
    with open(os.path.join(snapshot_path, 'session.pkl'), 'rb') as f:
        snapshot_session = pickle.load(f)
    logger.log(logging.INFO, 'Read session snapshot')
    snp_args = snapshot_session['args'][0]
    for snp_name in self.fold_snapshots:
        logger.log(logging.INFO, f'Loading {snp_name} to {device}')
        net = init_model_from_args(snp_args)
        snp = torch.load(snp_name, map_location=device)['model']
        net.load_state_dict(snp)
        models.append(net.eval())
    # Wrap all folds into a single averaging/ensembling inference model.
    self.net = NFoldInferenceModel(models).to(self.device)
    self.net.eval()
    # NOTE(review): the message hard-codes "5 folds" but the actual count is
    # len(self.fold_snapshots) — confirm this is always 5.
    logger.log(logging.INFO, f'Loaded 5 folds inference model to {device}')
    if jit_trace:
        logger.log(logging.INFO, 'Optimizing with torch.jit.trace')
        # Dummy batch of 2 RGB images at the training crop size drives the trace.
        dummy = torch.FloatTensor(2, 3, snp_args.crop_x, snp_args.crop_y).to(device=self.device)
        with torch.no_grad():
            self.net = torch.jit.trace(self.net, dummy)
    mean_vector, std_vector = np.load(mean_std_path)
    self.annotator_type = snp_args.annotations
    self.img_spacing = getattr(snp_args, f'{snp_args.annotations}_spacing')
    # Channel-wise normalization applied to items 0 and 1 of the sample tuple.
    norm_trf = partial(normalize_channel_wise, mean=mean_vector, std=std_vector)
    norm_trf = partial(apply_by_index, transform=norm_trf, idx=[0, 1])
    # Inference transform: wrap into solt, zero-pad + center-crop to the
    # training geometry, then unwrap and normalize.
    self.trf = tvt.Compose([
        partial(wrap_slt, annotator_type=self.annotator_type),
        slc.Stream([
            slt.PadTransform((snp_args.pad_x, snp_args.pad_y), padding='z'),
            slt.CropTransform((snp_args.crop_x, snp_args.crop_y), crop_mode='c'),
        ]),
        partial(unwrap_slt, norm_trf=norm_trf),
    ])
def init_data_processing():
    """Estimate landmark-dataset mean/std, wrap the existing train transform
    with normalization, build the validation transform, and register both
    in the global KVS."""
    kvs = GlobalKVS()
    dataset = LandmarkDataset(data_root=kvs['args'].dataset_root,
                              split=kvs['metadata'],
                              hc_spacing=kvs['args'].hc_spacing,
                              lc_spacing=kvs['args'].lc_spacing,
                              transform=kvs['train_trf'],
                              ann_type=kvs['args'].annotations,
                              image_pad=kvs['args'].img_pad)
    stats = init_mean_std(snapshots_dir=os.path.join(kvs['args'].workdir,
                                                     'snapshots'),
                          dataset=dataset,
                          batch_size=kvs['args'].bs,
                          n_threads=kvs['args'].n_threads,
                          n_classes=-1)
    # The helper may or may not return class weights alongside mean/std.
    if len(stats) == 3:
        mean_vector, std_vector, class_weights = stats
    elif len(stats) == 2:
        mean_vector, std_vector = stats
    else:
        raise ValueError('Incorrect format of mean/std/class-weights')
    norm_trf = partial(normalize_channel_wise, mean=mean_vector, std=std_vector)
    train_trf = tvt.Compose([
        kvs['train_trf'],
        partial(apply_by_index, transform=norm_trf, idx=0),
    ])
    val_trf = tvt.Compose([
        slc.Stream([
            slt.PadTransform((kvs['args'].pad_x, kvs['args'].pad_y),
                             padding='z'),
            slt.CropTransform((kvs['args'].crop_x, kvs['args'].crop_y),
                              crop_mode='c'),
        ]),
        partial(solt2torchhm, downsample=None, sigma=None),
        partial(apply_by_index, transform=norm_trf, idx=0),
    ])
    kvs.update('train_trf', train_trf)
    kvs.update('val_trf', val_trf)
def init_train_augmentation_pipeline():
    """Segmentation training pipeline: flip + gamma jitter, pad one pixel
    beyond the crop size, random crop, then image/mask to tensors."""
    kvs = GlobalKVS()
    crop_x, crop_y = kvs['args'].crop_x, kvs['args'].crop_y
    return transforms.Compose([
        img_mask2solt,
        slc.Stream([
            slt.RandomFlip(axis=1, p=0.5),
            slt.ImageGammaCorrection(gamma_range=(0.5, 2), p=0.5),
            slt.PadTransform(pad_to=(crop_x + 1, crop_y + 1)),
            slt.CropTransform(crop_size=(crop_x, crop_y), crop_mode='r'),
        ]),
        solt2img_mask,
        partial(apply_by_index, transform=gs2tens, idx=[0, 1]),
    ])
def test_6x6_pad_to_20x20_center_crop_6x6_kpts_img(img_6x6):
    """Pad to 20x20 then crop back to 6x6 must leave keypoints and image
    untouched (keypoints-first container order)."""
    kp_array = np.array([[0, 0], [0, 5], [1, 3], [2, 0]]).reshape((4, 2))
    keypoints = sld.KeyPoints(kp_array, 6, 6)
    container = sld.DataContainer((keypoints, img_6x6), 'PI')
    pipeline = slc.Stream(
        [slt.PadTransform((20, 20)), slt.CropTransform((6, 6))])
    result = pipeline(container)
    out_kpts, out_img = result[0][0], result[1][0]
    assert out_img.shape[:2] == (6, 6)
    assert (out_kpts.H == 6) and (out_kpts.W == 6)
    assert np.array_equal(out_img, img_6x6)
    assert np.array_equal(out_kpts.data, kp_array)
def custom_augment(img):
    """Augment a two-channel image (H, W, 2): each channel goes through the
    pipeline independently; the results are concatenated channel-wise."""
    pipeline = Compose([
        wrap2solt,
        slc.Stream([
            slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
            slt.RandomRotate(p=1, rotation_range=(-10, 10)),
            slt.PadTransform(pad_to=int(STD_SZ[0] * 1.05)),
            slt.CropTransform(crop_size=STD_SZ[0], crop_mode='r'),
            slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        ]),
        unpack_solt,
        ApplyTransform(Normalize((0.5, ), (0.5, ))),
    ])
    channels = []
    for ch in (0, 1):
        plane = img[:, :, ch:ch + 1].astype(np.uint8)
        plane, _ = pipeline((plane, 0))
        channels.append(plane)
    return torch.cat(channels, dim=0)
def custom_augment(img):
    """Apply the 32x32 training augmentation pipeline to a single image
    and return the resulting tensor."""
    pipeline = Compose([
        wrap2solt,
        slc.Stream([
            slt.ResizeTransform(resize_to=(32, 32), interpolation='bilinear'),
            slt.RandomScale(range_x=(0.9, 1.1), same=False, p=0.5),
            slt.RandomShear(range_x=(-0.05, 0.05), p=0.5),
            slt.RandomRotate(rotation_range=(-10, 10), p=0.5),
            slt.PadTransform(pad_to=36),
            slt.CropTransform(crop_size=32, crop_mode='r'),
            slt.ImageAdditiveGaussianNoise(p=1.0),
        ]),
        unpack_solt,
        # NOTE(review): norm_mean_std is resolved at module scope — verify it exists.
        ApplyTransform(norm_mean_std),
    ])
    transformed, _ = pipeline((img, 0))
    return transformed
def init_transforms():
    """Return ``{'train': ..., 'eval': ...}`` transforms; both normalize
    single-channel inputs to [-1, 1], the train one also augments."""
    normalize = ApplyTransform(Normalize((0.5,), (0.5,)))
    train_trf = Compose([
        wrap2solt,
        slc.Stream([
            slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
            slt.RandomRotate(p=1, rotation_range=(-10, 10)),
            slt.PadTransform(pad_to=int(STD_SZ[0] * 1.05)),
            slt.CropTransform(crop_size=STD_SZ[0], crop_mode='r'),
            slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        ]),
        unpack_solt,
        normalize,
    ])
    eval_trf = Compose([wrap2solt, unpack_solt, normalize])
    return {"train": train_trf, "eval": eval_trf}
def init_binary_segmentation_augs():
    """Binary-segmentation training augmentations (pad, flip, random crop,
    gamma jitter); registers the pipeline in the KVS and returns it."""
    kvs = GlobalKVS()
    args = kvs['args']
    ppl = tvt.Compose([
        img_binary_mask2solt,
        slc.Stream([
            slt.PadTransform(pad_to=(args.pad_x, args.pad_y)),
            slt.RandomFlip(axis=1, p=0.5),
            slt.CropTransform(crop_size=(args.crop_x, args.crop_y),
                              crop_mode='r'),
            slt.ImageGammaCorrection(gamma_range=(args.gamma_min,
                                                  args.gamma_max), p=0.5),
        ]),
        solt2img_binary_mask,
        partial(apply_by_index, transform=numpy2tens, idx=[0, 1]),
    ])
    kvs.update('train_trf', ppl)
    return ppl
def test_3x3_pad_to_20x20_center_crop_3x3_shape_stayes_unchanged(
        img_3x3, mask_3x3):
    """Pad to 20x20 then crop back to 3x3 must preserve the shapes of the
    image, mask and keypoint grid."""
    kp_array = np.array([[0, 0], [0, 2], [2, 2], [2, 0]]).reshape((4, 2))
    keypoints = sld.KeyPoints(kp_array, 3, 3)
    container = sld.DataContainer((img_3x3, mask_3x3, keypoints), 'IMP')
    pipeline = slc.Stream(
        [slt.PadTransform((20, 20)), slt.CropTransform((3, 3))])
    result = pipeline(container)
    assert result[0][0].shape[:2] == (3, 3)
    assert result[1][0].shape[:2] == (3, 3)
    assert (result[2][0].H == 3) and (result[2][0].W == 3)
def init_transforms(nc=1):
    """Return ``(train_trf, test_trf, custom_augment)`` with [-1, 1]
    normalization for 1- or 3-channel inputs.

    Raises:
        ValueError: for an unsupported channel count.
    """
    if nc == 1:
        norm_mean_std = Normalize((0.5, ), (0.5, ))
    elif nc == 3:
        norm_mean_std = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    else:
        raise ValueError("Not support channels of {}".format(nc))
    train_trf = Compose([
        wrap2solt,
        slc.Stream([
            slt.ImageAdditiveGaussianNoise(p=0.5, gain_range=0.3),
            slt.RandomRotate(p=1, rotation_range=(-10, 10)),
            slt.PadTransform(pad_to=int(STD_SZ[0] * 1.05)),
            slt.CropTransform(crop_size=STD_SZ[0], crop_mode='r'),
            slt.ImageGammaCorrection(p=0.5, gamma_range=(0.5, 1.5)),
        ]),
        unpack_solt,
        ApplyTransform(norm_mean_std),
    ])
    test_trf = Compose([wrap2solt, unpack_solt, ApplyTransform(norm_mean_std)])
    # NOTE(review): `custom_augment` is resolved at module scope — verify a
    # module-level function of that name exists.
    return train_trf, test_trf, custom_augment