def get_transforms(mode="train", target_size=128):
    """Build the batchgenerators transform pipeline for one phase.

    :param mode: "train" adds mirroring plus elastic/rotation/scaling
        augmentation; "val" and "test" use a deterministic center crop + resize.
    :param target_size: edge length (pixels) of the output patches.
    :return: Compose of the selected transforms, ending with NumpyToTensor.
    """
    transforms = []
    if mode == "train":
        transforms = [
            ResizeTransform(target_size=target_size, order=1),
            MirrorTransform(axes=(1,)),
            SpatialTransform(
                patch_size=(target_size, target_size),
                random_crop=False,
                patch_center_dist_from_border=target_size // 2,
                do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                do_rotation=True, p_rot_per_sample=0.5,
                angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                scale=(0.5, 1.9), p_scale_per_sample=0.5,
                border_mode_data="nearest", border_mode_seg="nearest",
            ),
        ]
    elif mode in ("val", "test"):
        # Validation and test share the same deterministic preprocessing.
        transforms = [
            CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
        ]
    transforms.append(NumpyToTensor())
    return Compose(transforms)
def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs):
    """Create a multi-threaded train/val/test batch generation and augmentation pipeline.

    :param cf: config providing da_kwargs, patch_size, dim, roi_items,
        class_specific_seg and n_workers.
    :param patient_data: dictionary containing one dictionary per patient in
        the train/test subset.
    :param do_aug: whether to perform data augmentation (training) or use a
        deterministic center crop (validation/testing).
    :return: multithreaded_generator
    """
    # Batch generator is the first element of the pipeline.
    data_gen = BatchGenerator(cf, patient_data, **kwargs)

    transforms = []
    if do_aug:
        da = cf.da_kwargs
        if da["mirror"]:
            transforms.append(Mirror(axes=da['mirror_axes']))
        transforms.append(SpatialTransform(
            patch_size=cf.patch_size[:cf.dim],
            patch_center_dist_from_border=da['rand_crop_dist'],
            do_elastic_deform=da['do_elastic_deform'],
            alpha=da['alpha'], sigma=da['sigma'],
            do_rotation=da['do_rotation'],
            angle_x=da['angle_x'], angle_y=da['angle_y'], angle_z=da['angle_z'],
            do_scale=da['do_scale'], scale=da['scale'],
            random_crop=da['random_crop']))
    else:
        # No augmentation: crop deterministically to the network patch size.
        transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    # Detection targets are derived from the segmentation on the fly.
    transforms.append(ConvertSegToBoundingBoxCoordinates(
        cf.dim, cf.roi_items, False, cf.class_specific_seg))

    return MultiThreadedAugmenter(data_gen, Compose(transforms),
                                  num_processes=cf.n_workers,
                                  seeds=range(cf.n_workers))
def crop_transform_center():
    """Crop patches from the exact center of the image.

    Demo: feeds batches of the camera() test image through a
    CenterCropTransform inside a MultiThreadedAugmenter and plots the
    first augmented batch.
    """
    loader = my_data_loader.DataLoader(camera(), 4)
    pipeline = Compose([CenterCropTransform(crop_size=(128, 128))])
    augmenter = MultiThreadedAugmenter(loader, pipeline, 4, 2)
    my_data_loader.plot_batch(next(augmenter))
def create_data_gen_pipeline(cf, cities=None, data_split='train', do_aug=True,
                             random=True, n_batches=None):
    """Create a multi-threaded train/val/test batch generation and augmentation pipeline.

    :param cf: config providing batch, path and augmentation settings.
    :param cities: list of strings or None.
    :param data_split: which split to draw batches from.
    :param do_aug: whether to perform data augmentation (training) or not
        (validation/testing).
    :param random: bool, whether to draw random batches or go through data linearly.
    :param n_batches: optional cap on the number of batches.
    :return: multithreaded_generator
    """
    data_gen = BatchGenerator(cities=cities, batch_size=cf.batch_size,
                              data_dir=cf.data_dir,
                              label_density=cf.label_density,
                              data_split=data_split,
                              resolution=cf.resolution,
                              gt_instances=cf.gt_instances,
                              n_batches=n_batches, random=random)

    transforms = []
    if do_aug:
        da = cf.da_kwargs
        transforms.append(MirrorTransform(axes=(3,)))
        transforms.append(SpatialTransform(
            patch_size=cf.patch_size[-2:],
            patch_center_dist_from_border=da['rand_crop_dist'],
            do_elastic_deform=da['do_elastic_deform'],
            alpha=da['alpha'], sigma=da['sigma'],
            do_rotation=da['do_rotation'],
            angle_x=da['angle_x'], angle_y=da['angle_y'], angle_z=da['angle_z'],
            do_scale=da['do_scale'], scale=da['scale'],
            random_crop=da['random_crop'],
            border_mode_data=da['border_mode_data'],
            border_mode_seg=da['border_mode_seg'],
            border_cval_seg=da['border_cval_seg']))
    else:
        transforms.append(CenterCropTransform(crop_size=cf.patch_size[-2:]))

    # Gamma augmentation and the loss mask are applied in every mode.
    transforms.append(GammaTransform(cf.da_kwargs['gamma_range'],
                                     invert_image=False, per_channel=True,
                                     retain_stats=cf.da_kwargs['gamma_retain_stats'],
                                     p_per_sample=cf.da_kwargs['p_gamma']))
    transforms.append(AddLossMask(cf.ignore_label))
    if cf.label_switches is not None:
        transforms.append(StochasticLabelSwitches(cf.name2trainId,
                                                  cf.label_switches))

    return MultiThreadedAugmenter(data_gen, Compose(transforms),
                                  num_processes=cf.n_workers,
                                  seeds=range(cf.n_workers))
def create_data_gen_pipeline(patient_data, cf, test_pids=None, do_aug=True):
    """Create a multi-threaded train/val/test batch generation and augmentation pipeline.

    :param patient_data: dictionary containing one dictionary per patient in
        the train/test subset.
    :param cf: configuration object.
    :param test_pids: (optional) list of test patient ids; selects the test
        generator and forces single-worker loading.
    :param do_aug: (optional) whether to perform data augmentation (training)
        or not (validation/testing).
    :return: multithreaded_generator
    """
    if test_pids is None:
        data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size,
                                  pre_crop_size=cf.pre_crop_size, dim=cf.dim)
    else:
        data_gen = TestGenerator(patient_data, batch_size=cf.batch_size,
                                 n_batches=None,
                                 pre_crop_size=cf.pre_crop_size,
                                 test_pids=test_pids, dim=cf.dim)
        cf.n_workers = 1  # testing is run with a single worker

    transforms = []
    if do_aug:
        da = cf.da_kwargs
        transforms.append(Mirror(axes=(2, 3)))
        transforms.append(SpatialTransform(
            patch_size=cf.patch_size[:cf.dim],
            patch_center_dist_from_border=da['rand_crop_dist'],
            do_elastic_deform=da['do_elastic_deform'],
            alpha=da['alpha'], sigma=da['sigma'],
            do_rotation=da['do_rotation'],
            angle_x=da['angle_x'], angle_y=da['angle_y'], angle_z=da['angle_z'],
            do_scale=da['do_scale'], scale=da['scale'],
            random_crop=da['random_crop']))
    else:
        transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    transforms.append(ConvertSegToOnehotTransform(classes=(0, 1, 2)))
    transforms.append(TransposeChannels())

    return MultiThreadedAugmenter(data_gen, Compose(transforms),
                                  num_processes=cf.n_workers,
                                  seeds=range(cf.n_workers))
def get_transforms(mode="train", n_channels=1, target_size=128, add_resize=False,
                   add_noise=False, mask_type="", batch_size=16, rotate=True,
                   elastic_deform=True, rnd_crop=False, color_augment=True):
    """Build the transform pipeline for one phase.

    Fixes vs. the previous version: removed dead code (a no-op
    ``noise_list += []`` and a redundant initial list assignment) and the
    ``tranform``/``transform`` naming typo. Behavior is unchanged.

    :param mode: "train" (full augmentation) or "val" (deterministic crop);
        any other value yields only NumpyToTensor.
    :param n_channels: channel count used for minimum-size padding.
    :param target_size: output patch edge length in pixels.
    :param add_resize: unused; kept for interface compatibility.
    :param add_noise: if True, appends the mask-type noise transforms.
    :param mask_type: "gaussian" enables extra Gaussian noise (train mode only).
    :param batch_size: batch dimension restored after the spatial transform.
    :param rotate, elastic_deform, rnd_crop, color_augment: toggles for the
        corresponding augmentations.
    :return: Compose of the selected transforms, ending with NumpyToTensor.
    """
    transform_list = []
    noise_list = []
    if mode == "train":
        transform_list = [
            FillupPadTransform(min_size=(n_channels, target_size + 5, target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1, concatenate_list=True),
            MirrorTransform(axes=(2,)),
            # Fold channels into the batch dim so SpatialTransform augments 2D slices.
            ReshapeTransform(new_shape=(1, -1, "h", "w")),
            SpatialTransform(patch_size=(target_size, target_size),
                             random_crop=rnd_crop,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=elastic_deform,
                             alpha=(0., 100.), sigma=(10., 13.),
                             do_rotation=rotate,
                             angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                             scale=(0.9, 1.2),
                             border_mode_data="nearest", border_mode_seg="nearest"),
            ReshapeTransform(new_shape=(batch_size, -1, "h", "w")),
        ]
        if color_augment:
            transform_list += [
                BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1))
            ]
        transform_list += [
            GaussianNoiseTransform(noise_variance=(0., 0.05)),
            ClipValueRange(min=-1.5, max=1.5),
        ]
        if mask_type == "gaussian":
            noise_list += [GaussianNoiseTransform(noise_variance=(0., 0.2))]
    elif mode == "val":
        transform_list = [
            FillupPadTransform(min_size=(n_channels, target_size + 5, target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1, concatenate_list=True),
            CenterCropTransform(crop_size=(target_size, target_size)),
            ClipValueRange(min=-1.5, max=1.5),
            # Keep an un-noised copy of the input for reconstruction targets.
            CopyTransform({"data": "data_clean"}, copy=True),
        ]
    if add_noise:
        transform_list = transform_list + noise_list
    transform_list.append(NumpyToTensor())
    return Compose(transform_list)
def _class_weights(background, prostate, tumor):
    """Return normalized inverse-frequency class weights.

    Each weight is 1 / (relative_frequency ** 0.25); the three weights are
    normalized to sum to 1. Order: (background, prostate, tumor).
    """
    total = float(background + prostate + tumor)
    raw = [1 / (count / total) ** 0.25 for count in (background, prostate, tumor)]
    norm = sum(raw)
    return tuple(w / norm for w in raw)


def _make_augmentation(patch_size, prob, alpha, sigma, scale, zoom_range,
                       contrast_range, random_crop, mirror):
    """Build (final_transform, center_crop) for one training stage.

    final_transform applies SpatialTransform with probability ``prob``
    (center crop otherwise), followed by resample / contrast / brightness
    augmentation and, optionally, mirroring. center_crop is the plain crop
    used for validation.
    """
    spatial_transform = SpatialTransform(
        patch_size, np.array(patch_size) // 2,
        do_elastic_deform=True, alpha=alpha, sigma=sigma,
        do_rotation=True, angle_z=(-np.pi / 2., np.pi / 2.),
        do_scale=True, scale=scale,
        border_mode_data='constant', border_cval_data=0, order_data=3,
        random_crop=random_crop)
    other_transforms = [ResampleTransform(zoom_range=zoom_range),
                        ContrastAugmentationTransform(contrast_range, True),
                        BrightnessTransform(0.0, 0.1, True)]
    if mirror:
        other_transforms.append(Mirror((2, 3)))
    sometimes_spatial_transforms = RndTransform(
        spatial_transform, prob=prob,
        alternative_transform=CenterCropTransform(patch_size))
    sometimes_other_transforms = RndTransform(Compose(other_transforms), prob=1.0)
    final_transform = Compose([sometimes_spatial_transforms,
                               sometimes_other_transforms])
    return final_transform, CenterCropTransform(patch_size)


def main():
    """Run cross-validated training of the UNet on the prostate dataset.

    Fixes vs. the previous version: Python-2 print statements converted to
    print() calls, ``np.float`` (removed in NumPy 1.24) replaced by the
    builtin float, and ``//`` used to keep the original integer division.
    """
    # assign global args
    global args
    args = parser.parse_args()

    # make a folder for the experiment (ignore "already exists")
    general_folder_name = args.output_path
    try:
        os.mkdir(general_folder_name)
    except OSError:
        pass

    # create train/val/test split; patients in the test split are never seen
    # during the whole training
    train_idx, val_idx, test_idx = CreateTrainValTestSplit(
        HistoFile_path=args.input_file, num_splits=args.num_splits,
        num_test_folds=args.num_test_folds, num_val_folds=args.num_val_folds,
        seed=args.seed)
    IDs = train_idx + val_idx
    print('size of training set {}'.format(len(train_idx)))
    print('size of validation set {}'.format(len(val_idx)))
    print('size of test set {}'.format(len(test_idx)))

    # data loading
    Data = ProstataData(args.input_file)  # for details on this class see README

    # train and validate
    for cv in range(args.cv_start, args.cv_end):
        best_epoch = 0
        train_loss = []
        val_loss = []

        # define patients for training and validation
        train_idx, val_idx = split_training(IDs, len_val=62, cv=cv,
                                            cv_runs=args.cv_number)
        oversampling_factor, Slices_total, Natural_probability_tu_slice, \
            Natural_probability_PRO_slice = get_oversampling(
                Data, train_idx=sorted(train_idx), Batch_Size=args.b,
                patch_size=args.patch_size)
        # Integer division: preserves the Python-2 semantics of the original.
        training_batches = Slices_total // args.b
        lr = args.lr
        base_lr = args.lr
        args.seed += 1
        print('train_idx', train_idx, len(train_idx))
        print('val_idx', val_idx, len(val_idx))

        # get class frequencies
        print('calculating class frequencie')
        Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC, \
            Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2, \
            ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std \
            = get_class_frequencies(Data, train_idx, patch_size=args.patch_size)
        print(ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std)
        print('ADC', Tumor_frequencie_ADC, Prostate_frequencie_ADC,
              Background_frequencie_ADC)
        print('T2', Tumor_frequencie_T2, Prostate_frequencie_T2,
              Background_frequencie_T2)

        # np.float was removed in NumPy 1.24 -- use the builtin float
        all_ADC = float(Background_frequencie_ADC + Prostate_frequencie_ADC
                        + Tumor_frequencie_ADC)
        all_T2 = float(Background_frequencie_T2 + Prostate_frequencie_T2
                       + Tumor_frequencie_T2)
        print(all_ADC)
        print(all_T2)

        weight_ADC = _class_weights(Background_frequencie_ADC,
                                    Prostate_frequencie_ADC,
                                    Tumor_frequencie_ADC)
        print('Weights ADC', weight_ADC[0], weight_ADC[1], weight_ADC[2])
        weight_T2 = _class_weights(Background_frequencie_T2,
                                   Prostate_frequencie_T2,
                                   Tumor_frequencie_T2)
        print('Weights T2', weight_T2[0], weight_T2[1], weight_T2[2])

        # define model
        Net = UNetPytorch(in_shape=(3, args.patch_size[0], args.patch_size[1]))
        Net_Name = 'UNetPytorch'
        model = Net.cuda()

        # model parameter
        optimizer = torch.optim.Adam(model.parameters(), lr,
                                     weight_decay=args.weight_decay)
        criterion_ADC = CrossEntropyLoss2d(
            weight=torch.FloatTensor(weight_ADC)).cuda()
        criterion_T2 = CrossEntropyLoss2d(
            weight=torch.FloatTensor(weight_T2)).cuda()

        # new folder name for cv
        folder_name = general_folder_name + '/CV_{}'.format(cv)
        try:
            os.mkdir(folder_name)
        except OSError:
            pass
        checkpoint_file = folder_name + '/checkpoint_' + '{}.pth.tar'.format(
            Net_Name)

        # augmentation: progressively weaker augmentation at epochs 0/30/50
        for epoch in range(args.epochs):
            torch.manual_seed(args.seed + epoch + cv)
            np.random.seed(epoch + cv)
            np.random.shuffle(train_idx)
            if epoch == 0:
                final_transform, Center_Crop = _make_augmentation(
                    args.patch_size, args.p,
                    alpha=(100., 450.), sigma=(13., 17.), scale=(0.75, 1.25),
                    zoom_range=(0.7, 1.3), contrast_range=(0.75, 1.25),
                    random_crop=True, mirror=True)
            if epoch == 30:
                final_transform, Center_Crop = _make_augmentation(
                    args.patch_size, args.p,
                    alpha=(0., 250.), sigma=(11., 14.), scale=(0.85, 1.15),
                    zoom_range=(0.8, 1.2), contrast_range=(0.85, 1.15),
                    random_crop=True, mirror=False)
            if epoch == 50:
                final_transform, Center_Crop = _make_augmentation(
                    args.patch_size, args.p,
                    alpha=(0., 150.), sigma=(10., 12.), scale=(0.85, 1.15),
                    zoom_range=(0.9, 1.1), contrast_range=(0.95, 1.05),
                    random_crop=False, mirror=False)

            train_loader = BatchGenerator(
                Data, BATCH_SIZE=args.b, split_idx=train_idx, seed=args.seed,
                ProbabilityTumorSlices=oversampling_factor, epoch=epoch,
                ADC_mean=ADC_mean, ADC_std=ADC_std, BVAL_mean=BVAL_mean,
                BVAL_std=BVAL_std, T2_mean=T2_mean, T2_std=T2_std)
            val_loader = BatchGenerator(
                Data, BATCH_SIZE=0, split_idx=val_idx, seed=args.seed,
                ProbabilityTumorSlices=None, epoch=epoch, test=True,
                ADC_mean=ADC_mean, ADC_std=ADC_std, BVAL_mean=BVAL_mean,
                BVAL_std=BVAL_std, T2_mean=T2_mean, T2_std=T2_std)

            # train on training set
            train_losses = train(train_loader=train_loader, model=model,
                                 optimizer=optimizer,
                                 criterion_ADC=criterion_ADC,
                                 criterion_T2=criterion_T2,
                                 final_transform=final_transform,
                                 workers=args.workers, seed=args.seed,
                                 training_batches=training_batches)
            train_loss.append(train_losses)

            # evaluate on validation set
            val_losses = validate(val_loader=val_loader, model=model,
                                  folder_name=folder_name,
                                  criterion_ADC=criterion_ADC,
                                  criterion_T2=criterion_T2,
                                  split_ixs=val_idx, epoch=epoch, workers=1,
                                  Center_Crop=Center_Crop, seed=args.seed)
            val_loss.append(val_losses)

            # write TrainingsCSV to folder name
            TrainingsCSV = pd.DataFrame({
                'train_loss': train_loss,
                'val_loss': val_loss
            })
            TrainingsCSV.to_csv(folder_name + '/TrainingsCSV.csv')

            # <= keeps ties: the new loss is already contained in val_loss
            if val_losses <= min(val_loss):
                best_epoch = epoch
                print('best epoch', epoch)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    filename=checkpoint_file)
            optimizer, lr = adjust_learning_rate(optimizer, base_lr, epoch)

            # delete all output except best epoch
            # NOTE(review): called every epoch here; confirm the original
            # indentation placed this inside the epoch loop.
            clear_image_data(folder_name, best_epoch, epoch)