コード例 #1
0
    def test_random_distributions_2D(self):
        """Check that all 4 possible 2D mirror outcomes occur with roughly equal frequency."""
        loader = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size, number_of_threads_in_multithreaded=None)
        augmenter = SingleThreadedAugmenter(loader, MirrorTransform((0, 1)))

        # Reference images for the four mirror configurations, in the same
        # order as the counters they feed.
        references = [self.cam_left, self.cam_updown, self.cam_updown_left, self.cam]
        counts = np.zeros(shape=(4,))

        for _ in range(self.num_batches):
            batch = next(augmenter)

            for ix in range(self.batch_size):
                sample = batch['data'][ix, :, :, :]
                # Count only the first matching configuration, mirroring the
                # exclusive if/elif semantics.
                for slot, reference in enumerate(references):
                    if (sample == reference).all():
                        counts[slot] += 1
                        break

        self.assertTrue(all(2200 < c < 2800 for c in counts),
                        "2D Images were not mirrored along all axes with equal probability. "
                        "This may also indicate that mirroring is not working")
コード例 #2
0
def main():
    """Training entry point.

    Builds the dataset generators, model, optimizer and loss, adds a
    separate 3D-patch loader with a single-threaded augmentation pipeline,
    and hands everything to the Trainer.  Creates ``args.save`` if missing.
    """
    args = get_arguments()
    # Fix all RNG seeds for reproducible runs.
    utils.reproducibility(args, seed)
    # utils.make_dirs(args.save)
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    # training_generator, val_generator, full_volume, affine = medical_loaders.generate_datasets(args,
    # NOTE(review): dataset path is hard-coded here — consider moving it into args/config.
    training_generator, val_generator, full_volume, affine, dataset = medical_loaders.generate_datasets(args,
                                                                                               path='/data/hejy/MedicalZooPytorch_2cls/datasets')
    model, optimizer = medzoo.create_model(args)

    # Dice loss over 2 classes with equal class weights ([1, 1]).
    criterion = DiceLoss(classes=2, skip_index_after=args.classes, weight = torch.tensor([1, 1]).cuda(), sigmoid_normalization=True)
    # criterion = WeightedCrossEntropyLoss()

    if args.cuda:
        model = model.cuda()
    # model.restore_checkpoint(args.pretrained)
    # Additional 3D patch loader (RibFrac-style, per the class name) wrapped
    # in a single-threaded augmenter with the training transforms.
    dataloader_train = MICCAI2020_RIBFRAC_DataLoader3D(dataset, args.batchSz, args.dim,  num_threads_in_multithreaded=2)
    tr_transforms = get_train_transform(args.dim)
    training_generator_aug = SingleThreadedAugmenter(dataloader_train, tr_transforms,)


    trainer = train.Trainer(args, model, criterion, optimizer, train_data_loader=training_generator,
                            valid_data_loader=val_generator, lr_scheduler=None, dataset = dataset, train_data_loader_aug=training_generator_aug)
    trainer.training()
コード例 #3
0
def create_data_gen_pipeline(patient_data, cf, is_training=True):
    """
    Build the train/val/test batch-generation and augmentation pipeline.

    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
    :param cf: configuration object providing batch size, patch size and augmentation settings.
    :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: a single- or multi-threaded augmenter wrapping the batch generator
    """
    # The batch generator is the first stage of the pipeline.
    data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size, cf=cf)

    if is_training:
        # Training: mirroring, spatial augmentation, random channel dropping.
        pipeline = [
            Mirror(axes=cf.mirror_axes),  #np.arange(cf.dim)
            SpatialTransform(
                patch_size=cf.patch_size[:cf.dim],
                patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                alpha=cf.da_kwargs['alpha'],
                sigma=cf.da_kwargs['sigma'],
                do_rotation=cf.da_kwargs['do_rotation'],
                angle_x=cf.da_kwargs['angle_x'],
                angle_y=cf.da_kwargs['angle_y'],
                angle_z=cf.da_kwargs['angle_z'],
                do_scale=cf.da_kwargs['do_scale'],
                scale=cf.da_kwargs['scale'],
                random_crop=cf.da_kwargs['random_crop']),
            RandomChannelDeleteTransform(cf.droppable_channels,
                                         cf.channel_drop_p),
        ]
    else:
        # Validation/testing: deterministic center crop only.
        pipeline = [CenterCropTransform(crop_size=cf.patch_size[:cf.dim])]

    # Both modes end by converting the segmentation into bounding-box targets.
    pipeline.append(
        ConvertSegToBoundingBoxCoordinates(
            cf.dim,
            get_rois_from_seg_flag=False,
            class_specific_seg_flag=cf.class_specific_seg_flag))

    composed = Compose(pipeline)
    if cf.debugging:
        # Keep everything in the main process while debugging.
        return SingleThreadedAugmenter(data_gen, composed)
    return MultiThreadedAugmenter(data_gen,
                                  composed,
                                  num_processes=cf.n_workers,
                                  seeds=range(cf.n_workers))
コード例 #4
0
    def test_segmentations_2D(self):
        """Test that segmentations are mirrored coherently with images.

        Fix: the original check used ``.all()``, which only marked a sample
        as inconsistent when *every* pixel differed between data and seg.
        A partially inconsistent mirroring (where some pixels happen to stay
        equal, e.g. in symmetric image regions) slipped through.  ``.any()``
        flags a mismatch as soon as a single pixel differs.
        """
        batch_gen = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size, number_of_threads_in_multithreaded=None)
        batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1)))

        equivalent = True

        for b in range(self.num_batches):
            batch = next(batch_gen)
            for ix in range(self.batch_size):
                # seg is expected to track data exactly in this fixture
                # (TODO confirm against self.y_2D setup), so any differing
                # pixel means data and seg were mirrored differently.
                if (batch['data'][ix] != batch['seg'][ix]).any():
                    equivalent = False

        self.assertTrue(equivalent, "2D images and seg were not mirrored in the same way (they should though because "
                                    "seg needs to match the corresponding data")
コード例 #5
0
def main():
    """Entry point for KiTS19 VNet training.

    Parses CLI arguments plus a YAML config, builds the VNet and the
    batchgenerators loading/augmentation pipeline, then runs the epoch
    loop, writing ``train.csv``/``test.csv`` and periodic ``.tar``
    checkpoints into ``args.save``.

    Side effects: an existing ``args.save`` directory is deleted and
    recreated before training starts.
    """
    # --------- Parse arguments ---------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config_0.yaml',
                        help='Path to the configuration file.')
    parser.add_argument('--dice', action='store_true')
    # be aware of this argument!!!
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('-i', '--inference', default='', type=str, metavar='PATH',
                        help='run inference on data set and save results')

    # 1e-8 works well for lung masks but seems to prevent
    # rapid learning for nodule masks
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    args = parser.parse_args()
    # ---------- get the config file(config.yaml) --------------------------------------------
    config = get_config(args.config)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # NOTE(review): any user-supplied --save value is overwritten here.
    args.save = os.path.join('./work', (datestr() + ' ' + config['filename']))
    # nll selects the loss/training functions below (NLL vs Dice variants).
    nll = True
    if config['dice']:
        nll = False

    weight_decay = config['weight_decay']
    num_threads_for_kits19 = config['num_of_threads']
    patch_size = (160, 160, 128)
    num_batch_per_epoch = config['num_batch_per_epoch']
    setproctitle.setproctitle(args.save)
    start_epoch = 1
    # -------- Record best kidney segmentation dice -------------------------------------------
    best_tk = 0.0
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    # Embed attention module
    model = vnet.VNet(elu=False, nll=nll,
                      attention=config['attention'], nclass=3)  # mark
    # Effective batch size scales with the number of GPUs.
    batch_size = config['ngpu'] * config['batchSz']
    save_iter = config['model_save_iter']
    # batch_size = args.ngpu*args.batchSz
    gpu_ids = range(config['ngpu'])
    # print(gpu_ids)
    model.apply(weights_init)
    # ------- Resume training from saved model -----------------------------------------------
    if config['resume']:
        if os.path.isfile(config['resume']):
            print("=> loading checkpoint '{}'".format(config['resume']))
            checkpoint = torch.load(config['resume'])
            # .tar files
            if config['resume'].endswith('.tar'):
                # print(checkpoint, "tar")
                start_epoch = checkpoint['epoch']
                best_tk = checkpoint['best_tk']
                checkpoint_model = checkpoint['model_state_dict']
                # Strip the 'module.' prefix added by DataParallel wrapping.
                model.load_state_dict(
                    {k.replace('module.', ''): v for k, v in checkpoint_model.items()})
            # .pkl files for the whole model
            else:
                # print(checkpoint, "pkl")
                model.load_state_dict(checkpoint.state_dict())
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config['resume']))
            exit(-1)
    else:
        pass
    # ------- Which loss function to use ------------------------------------------------------
    if nll:
        training = train_bg
        validate = test_bg
        # class_balance = True
    else:
        training = train_bg_dice
        validate = test_bg_dice
        # class_balance = False
    # -----------------------------------------------------------------------------------------
    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    # -------- Set on GPU ---------------------------------------------------------------------
    if args.cuda:
        model = model.cuda()

    # NOTE(review): a pre-existing output directory is wiped here.
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    # create the output directory
    os.makedirs(args.save)
    # save the config file to the output folder
    shutil.copy(args.config, os.path.join(args.save, 'config.yaml'))

    # kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    # ------ Load Training and Validation set --------------------------------------------
    preprocessed_folders = "/home/data_share/npy_data/"
    patients = get_list_of_patients(
        preprocessed_data_folder=preprocessed_folders)
    # split num_split cross-validation sets
    # train, val = get_split_deterministic(
    #     patients, fold=0, num_splits=5, random_state=12345)
    # NOTE(review): fixed 147/42 split replaces the deterministic CV split
    # above — confirm this is intentional and not a leftover experiment.
    train, val = patients[0:147], patients[147:189]

    # VALIDATION DATA CANNOT BE LOADED IN CASE DUE TO THE LARGE SHAPE...
    # PRINT VALIDATION CASES FOR LATER TEST USE!!
    print("Validation cases:\n", val)
    # set max shape for validation set
    shapes = [Kits2019DataLoader3D.load_patient(
        i)[0].shape[1:] for i in val]
    max_shape = np.max(shapes, 0)
    max_shape = np.max((max_shape, patch_size), 0)
    # data loading + augmentation
    dataloader_train = Kits2019DataLoader3D(
        train, batch_size, patch_size, num_threads_for_kits19)
    dataloader_validation = Kits2019DataLoader3D(
        val, batch_size * 2, patch_size, num_threads_for_kits19)
    tr_transforms = get_train_transform(patch_size, prob=config['prob'])
    # whether to use single/multiThreadedAugmenter ------------------------------------------
    if num_threads_for_kits19 > 1:
        tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms, 
                                        num_processes=num_threads_for_kits19,
                                        num_cached_per_queue=3,seeds=None, pin_memory=True)
        val_gen = MultiThreadedAugmenter(dataloader_validation, None,
                                         num_processes=max(1, num_threads_for_kits19//2), 
                                         num_cached_per_queue=1, seeds=None, pin_memory=False)

        tr_gen.restart()
        val_gen.restart()
    else:
        tr_gen = SingleThreadedAugmenter(dataloader_train, transform=tr_transforms)
        val_gen = SingleThreadedAugmenter(dataloader_validation, transform=None)
    # ------- Set learning rate scheduler ----------------------------------------------------
    lr_schdl = lr_scheduler.LR_Scheduler(mode=config['lr_policy'], base_lr=config['lr'],
                                         num_epochs=config['nEpochs'], iters_per_epoch=num_batch_per_epoch,
                                         lr_step=config['step_size'], warmup_epochs=config['warmup_epochs'])

    # ------ Choose Optimizer ----------------------------------------------------------------
    if config['opt'] == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=config['lr'],
                              momentum=0.99, weight_decay=weight_decay)
    elif config['opt'] == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    elif config['opt'] == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    lr_plateu = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, verbose=True, threshold=1e-3, patience=5)
    # ------- Apex Mixed Precision Acceleration ----------------------------------------------
    # amp.initialize must run before DataParallel wrapping.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = nn.parallel.DataParallel(model, device_ids=gpu_ids)
    # ------- Save training data -------------------------------------------------------------
    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    trainF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    testF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n ')
    # ------- Training Pipeline --------------------------------------------------------------
    for epoch in range(start_epoch, config['nEpochs'] + start_epoch):
        torch.cuda.empty_cache()
        training(args, epoch, model, tr_gen, optimizer, trainF, config, lr_schdl)
        torch.cuda.empty_cache()
        print('==>lr decay to:', optimizer.param_groups[0]['lr'])
        print('testing validation set...')
        composite_dice = validate(args, epoch, model, val_gen, optimizer, testF, config, lr_plateu)
        torch.cuda.empty_cache()
        # save model with best result and routinely
        if composite_dice > best_tk or epoch % config['model_save_iter'] == 0:
            # model_name = 'vnet_epoch_step1_' + str(epoch) + '.pkl'
            model_name = 'vnet_step1_' + str(epoch) + '.tar'
            # torch.save(model, os.path.join(args.save, model_name))
            # NOTE(review): the checkpoint stores the *previous* best_tk
            # because best_tk is only updated after saving — confirm intended.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_tk': best_tk
            }, os.path.join(args.save, model_name))
            best_tk = composite_dice
    # ----------------------------------------------------------------------------------------
    trainF.close()
    testF.close()
コード例 #6
0
 def run(self, img_data, seg_data):
     """Run the configured batchgenerators augmentations on image/seg data.

     Assembles a transform chain from the instance's configuration flags,
     draws ``self.cycles`` augmented batches from a SingleThreadedAugmenter,
     stacks them, and returns both arrays converted from channel-first back
     to channel-last layout.
     """
     # Parser feeding the batchgenerators pipeline with our data.
     data_generator = DataParser(img_data, seg_data)

     # Collect the enabled transforms (order matters and is preserved:
     # mirror, contrast, brightness, gamma, noise, spatial).
     transform_list = []
     if self.mirror:
         transform_list.append(MirrorTransform(axes=self.config_mirror_axes))
     if self.contrast:
         transform_list.append(ContrastAugmentationTransform(
             self.config_contrast_range,
             preserve_range=True,
             per_channel=True,
             p_per_sample=self.config_p_per_sample))
     if self.brightness:
         transform_list.append(BrightnessMultiplicativeTransform(
             self.config_brightness_range,
             per_channel=True,
             p_per_sample=self.config_p_per_sample))
     if self.gamma:
         transform_list.append(GammaTransform(
             self.config_gamma_range,
             invert_image=False,
             per_channel=True,
             retain_stats=True,
             p_per_sample=self.config_p_per_sample))
     if self.gaussian_noise:
         transform_list.append(GaussianNoiseTransform(
             self.config_gaussian_noise_range,
             p_per_sample=self.config_p_per_sample))
     # Spatial transformations (rotation, scaling, elastic deformation and
     # optional random cropping) share a single SpatialTransform instance.
     if self.rotations or self.scaling or self.elastic_deform or self.cropping:
         # Patch shape: configured crop size, or the full image extent
         # (all axes except the trailing channel axis).
         if self.cropping:
             patch_shape = self.cropping_patch_shape
         else:
             patch_shape = img_data[0].shape[0:-1]
         transform_list.append(SpatialTransform(
             patch_shape, [edge // 2 for edge in patch_shape],
             do_elastic_deform=self.elastic_deform,
             alpha=self.config_elastic_deform_alpha,
             sigma=self.config_elastic_deform_sigma,
             do_rotation=self.rotations,
             angle_x=self.config_rotations_angleX,
             angle_y=self.config_rotations_angleY,
             angle_z=self.config_rotations_angleZ,
             do_scale=self.scaling,
             scale=self.config_scaling_range,
             border_mode_data='constant',
             border_cval_data=0,
             border_mode_seg='constant',
             border_cval_seg=0,
             order_data=3,
             order_seg=0,
             p_el_per_sample=self.config_p_per_sample,
             p_rot_per_sample=self.config_p_per_sample,
             p_scale_per_sample=self.config_p_per_sample,
             random_crop=self.cropping))

     # Wire parser + composed transforms into a single-threaded augmenter.
     augmenter = SingleThreadedAugmenter(data_generator, Compose(transform_list))

     # Draw `cycles` augmented batches; each new batch is prepended, so the
     # most recent batch ends up first in the stacked result.
     collected_img = None
     collected_seg = None
     for _ in range(self.cycles):
         augmented = next(augmenter)
         if collected_img is None:
             collected_img = augmented["data"]
             collected_seg = augmented["seg"]
         else:
             collected_img = np.concatenate(
                 (augmented["data"], collected_img), axis=0)
             collected_seg = np.concatenate(
                 (augmented["seg"], collected_seg), axis=0)

     # channel-first (batch, channel, x, y, z) -> channel-last (batch, x, y, z, channel)
     collected_img = np.moveaxis(collected_img, 1, -1)
     collected_seg = np.moveaxis(collected_seg, 1, -1)
     return collected_img, collected_seg
コード例 #7
0
def get_moreDA_augmentation_ae(dataloader_train,
                               dataloader_val,
                               patch_size,
                               params=default_3D_augmentation_params,
                               border_val_seg=-1,
                               seeds_train=None,
                               seeds_val=None,
                               order_seg=1,
                               order_data=3,
                               deep_supervision_scales=None,
                               soft_ds=False,
                               classes=None,
                               pin_memory=True,
                               anisotropy=False,
                               extra_label_keys=None,
                               val_mode=False,
                               use_conf=False,
                               global_params=None):
    '''
    Work as Dataloader with augmentation
    :return: train_loader, val_loader
        for each iterator, return {'data': (B, D, H, W), 'target': (B, D, H, W)}
    '''
    # NOTE(review): `params` defaults to a module-level dict and is mutated
    # in place below (global_params merge and the dummy-2D overrides), so
    # changes leak into subsequent calls — confirm this is intended.
    if global_params:
        for key, value in global_params.items():
            params[key] = value
        print(global_params)

    if not val_mode:
        assert params.get(
            'mirror'
        ) is None, "old version of params, use new keyword do_mirror"

        tr_transforms = []
        if params.get("selected_data_channels") is not None:
            tr_transforms.append(
                DataChannelSelectionTransform(
                    params.get("selected_data_channels")))
        if params.get("selected_seg_channels") is not None:
            tr_transforms.append(
                SegChannelSelectionTransform(
                    params.get("selected_seg_channels")))

        # anistropic setting
        if anisotropy or params.get("dummy_2D"):
            ignore_axes = (0, )
            tr_transforms.append(
                Convert3DTo2DTransform(extra_label_keys=extra_label_keys))
            patch_size = patch_size[1:]  # 2D patch size

            print('Using dummy2d data augmentation')
            # Dummy-2D mode overrides elastic/rotation params: rotation only
            # around x; y/z rotation ranges collapse to zero.
            params["elastic_deform_alpha"] = (0., 200.)
            params["elastic_deform_sigma"] = (9., 13.)
            params["rotation_x"] = (-180. / 360 * 2. * np.pi,
                                    180. / 360 * 2. * np.pi)
            params["rotation_y"] = (-0. / 360 * 2. * np.pi,
                                    0. / 360 * 2. * np.pi)
            params["rotation_z"] = (-0. / 360 * 2. * np.pi,
                                    0. / 360 * 2. * np.pi)

        else:
            ignore_axes = None

        # 1. Spatial Transform: rotation, scaling
        tr_transforms.append(
            SpatialTransform(
                patch_size,
                patch_center_dist_from_border=None,
                do_elastic_deform=params.get("do_elastic"),
                alpha=params.get("elastic_deform_alpha"),
                sigma=params.get("elastic_deform_sigma"),
                do_rotation=params.get("do_rotation"),
                angle_x=params.get("rotation_x"),
                angle_y=params.get("rotation_y"),
                angle_z=params.get("rotation_z"),
                p_rot_per_axis=params.get("rotation_p_per_axis"),
                do_scale=params.get("do_scaling"),
                scale=params.get("scale_range"),
                border_mode_data=params.get("border_mode_data"),
                border_cval_data=0,
                order_data=order_data,
                border_mode_seg="constant",
                border_cval_seg=border_val_seg,
                order_seg=order_seg,
                random_crop=params.get("random_crop"),
                p_el_per_sample=params.get("p_eldef"),
                p_scale_per_sample=params.get("p_scale"),
                p_rot_per_sample=params.get("p_rot"),
                independent_scale_for_each_axis=params.get(
                    "independent_scale_factor_for_each_axis"),
                extra_label_keys=extra_label_keys,
                do_translate=params.get("do_translate"),
                p_trans=params.get("p_trans"),
                trans_max_shifts=params.get("trans_max_shifts"),
                trans_const_channel=params.get("trans_const_channel"),
            ))

        if anisotropy or params.get("dummy_2D"):
            tr_transforms.append(
                Convert2DTo3DTransform(extra_label_keys=extra_label_keys))
        '''
        # 2. Noise Augmentation: gaussian noise, gaussian blur
        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
        tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
                                                   p_per_channel=0.5))

        # 3. Color Augmentation: brightness, constrast, low resolution, gamma_transform
        tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))
        if params.get("do_additive_brightness"):
            tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
                                                     params.get("additive_brightness_sigma"),
                                                     True, p_per_sample=params.get("additive_brightness_p_per_sample"),
                                                     p_per_channel=params.get("additive_brightness_p_per_channel")))
        tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
        tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                            p_per_channel=0.5,
                                                            order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                            ignore_axes=ignore_axes))
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=0.1))  # inverted gamma
        if params.get("do_gamma"):
            tr_transforms.append(
                GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                               p_per_sample=params["p_gamma"]))
        '''

        # 4. Mirror Transform
        if params.get("do_mirror") or params.get("mirror"):
            tr_transforms.append(
                MirrorTransform(params.get("mirror_axes"),
                                extra_label_keys=extra_label_keys))

        # if params.get("mask_was_used_for_normalization") is not None:
        #     mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        #     tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

        tr_transforms.append(
            RemoveLabelTransform(-1, 0, extra_label_keys=extra_label_keys))
        # Rename keys so downstream code sees 'image'/'gt' instead of 'data'/'seg'.
        tr_transforms.append(RenameTransform('data', 'image', True))
        tr_transforms.append(RenameTransform('seg', 'gt', True))

        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                tr_transforms.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales, 'gt',
                                                 'gt', classes))
            else:
                tr_transforms.append(
                    DownsampleSegForDSTransform2(
                        deep_supervision_scales,
                        0,
                        0,
                        input_key='gt',
                        output_key='gt',
                        extra_label_keys=extra_label_keys))
        toTensorKeys = [
            'image', 'gt'
        ] + extra_label_keys if extra_label_keys is not None else [
            'image', 'gt'
        ]
        tr_transforms.append(NumpyToTensor(toTensorKeys, 'float'))
        tr_transforms = Compose(tr_transforms)

        if seeds_train is not None:
            seeds_train = [seeds_train] * params.get('num_threads')
        # use_conf forces a minimal single-worker, single-cache setup.
        if use_conf:
            num_threads = 1
            num_cached_per_thread = 1
        else:
            num_threads, num_cached_per_thread = params.get(
                'num_threads'), params.get("num_cached_per_thread")
        batchgenerator_train = MultiThreadedAugmenter(dataloader_train,
                                                      tr_transforms,
                                                      num_threads,
                                                      num_cached_per_thread,
                                                      seeds=seeds_train,
                                                      pin_memory=pin_memory)

        # Validation transforms: no augmentation, only channel selection,
        # label cleanup, renaming, optional deep-supervision downsampling
        # and tensor conversion.
        val_transforms = []
        val_transforms.append(
            RemoveLabelTransform(-1, 0, extra_label_keys=extra_label_keys))
        if params.get("selected_data_channels") is not None:
            val_transforms.append(
                DataChannelSelectionTransform(
                    params.get("selected_data_channels")))
        if params.get("selected_seg_channels") is not None:
            val_transforms.append(
                SegChannelSelectionTransform(
                    params.get("selected_seg_channels")))
        val_transforms.append(RenameTransform('data', 'image', True))
        val_transforms.append(RenameTransform('seg', 'gt', True))

        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                val_transforms.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales, 'gt',
                                                 'gt', classes))
            else:
                val_transforms.append(
                    DownsampleSegForDSTransform2(
                        deep_supervision_scales,
                        0,
                        0,
                        input_key='gt',
                        output_key='gt',
                        extra_label_keys=extra_label_keys))

        val_transforms.append(NumpyToTensor(toTensorKeys, 'float'))
        val_transforms = Compose(val_transforms)

        # batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
        #                                             params.get("num_cached_per_thread"),
        #                                             seeds=seeds_val, pin_memory=pin_memory)
        if seeds_val is not None:
            seeds_val = [seeds_val] * 1
        # batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, 1,
        #                                             params.get("num_cached_per_thread"),
        #                                             seeds=seeds_val, pin_memory=False)
        batchgenerator_val = SingleThreadedAugmenter(dataloader_val,
                                                     val_transforms)

    else:
        # val_mode: build augmentation-free pipelines for both loaders.
        val_transforms = []
        val_transforms.append(
            RemoveLabelTransform(-1, 0, extra_label_keys=extra_label_keys))
        if params.get("selected_data_channels") is not None:
            val_transforms.append(
                DataChannelSelectionTransform(
                    params.get("selected_data_channels")))
        if params.get("selected_seg_channels") is not None:
            val_transforms.append(
                SegChannelSelectionTransform(
                    params.get("selected_seg_channels")))
        val_transforms.append(RenameTransform('data', 'image', True))
        val_transforms.append(RenameTransform('seg', 'gt', True))

        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                val_transforms.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales, 'gt',
                                                 'gt', classes))
            else:
                val_transforms.append(
                    DownsampleSegForDSTransform2(
                        deep_supervision_scales,
                        0,
                        0,
                        input_key='gt',
                        output_key='gt',
                        extra_label_keys=extra_label_keys))

        toTensorKeys = [
            'image', 'gt'
        ] + extra_label_keys if extra_label_keys is not None else [
            'image', 'gt'
        ]
        val_transforms.append(NumpyToTensor(toTensorKeys, 'float'))
        val_transforms = Compose(val_transforms)
        batchgenerator_val = SingleThreadedAugmenter(dataloader_val,
                                                     val_transforms)
        # NOTE: in val_mode the *training* generator is also built with the
        # validation transforms (no augmentation), single-threaded.
        batchgenerator_train = SingleThreadedAugmenter(dataloader_train,
                                                       val_transforms)

    return batchgenerator_train, batchgenerator_val
コード例 #8
0
ファイル: data_manager.py プロジェクト: muizzk/delira
    def __init__(self, data_loader: BaseDataLoader, transforms,
                 n_process_augmentation, sampler, sampler_queues: list,
                 num_cached_per_queue=2, seeds=None, **kwargs):
        """
        Wrap ``data_loader`` in a single- or multi-process augmenter.

        Parameters
        ----------
        data_loader : :class:`BaseDataLoader`
            the dataloader providing the actual data
        transforms : Callable or None
            the transforms to use. Can be single callable or None
        n_process_augmentation : int
            the number of processes to use for augmentation (only necessary if
            not in debug mode)
        sampler : :class:`AbstractSampler`
            the sampler to use; must be used here instead of inside the
            dataloader to avoid duplications and oversampling due to
            multiprocessing
        sampler_queues : list of :class:`multiprocessing.Queue`
            queues to pass the sample indices to the actual dataloader
        num_cached_per_queue : int
            the number of samples to cache per queue (only necessary if not in
            debug mode)
        seeds : int or list
            the seeds for each process (only necessary if not in debug mode)
        **kwargs :
            additional keyword arguments
        """

        self._batchsize = data_loader.batch_size

        if get_current_debug_mode():
            # Debug mode runs everything in-process, so no multiprocessing.
            augmenter = SingleThreadedAugmenter(data_loader, transforms)
        else:
            assert isinstance(n_process_augmentation, int)

            # Fall back to the default seed of 1 when none was supplied.
            if seeds is None:
                seeds = 1
            # A single int seed is replicated once per worker process.
            if isinstance(seeds, int):
                seeds = [seeds] * n_process_augmentation
            # De-duplicate (in place) so workers don't share RNG streams.
            if any(seeds[0] == candidate for candidate in seeds[1:]):
                for position in range(len(seeds)):
                    seeds[position] = seeds[position] + position

            augmenter = MultiThreadedAugmenter(
                data_loader, transforms,
                num_processes=n_process_augmentation,
                num_cached_per_queue=num_cached_per_queue,
                seeds=seeds,
                **kwargs)

        self._augmenter = augmenter
        self._sampler = sampler
        self._sampler_queues = sampler_queues
        self._queue_id = 0
コード例 #9
0
def get_moreDA_augmentation(dataloader_train,
                            dataloader_val,
                            patch_size,
                            params=default_3D_augmentation_params,
                            border_val_seg=-1,
                            seeds_train=None,
                            seeds_val=None,
                            order_seg=1,
                            order_data=3,
                            deep_supervision_scales=None,
                            soft_ds=False,
                            classes=None,
                            pin_memory=True):
    """Build the nnU-Net-style training and validation augmentation pipelines.

    Wraps ``dataloader_train`` / ``dataloader_val`` in (Single|Multi)ThreadedAugmenter
    instances carrying the full "moreDA" transform chain: spatial deformation,
    intensity augmentations, mirroring, optional cascade augmentations, and the
    final seg->target rename + tensor conversion.

    Parameters
    ----------
    dataloader_train, dataloader_val : batchgenerators data loaders
        Sources of raw {'data', 'seg'} batches.
    patch_size : tuple of int
        Output spatial size produced by SpatialTransform.
    params : dict
        Augmentation hyperparameters (new-style; must use ``do_mirror``,
        not the legacy ``mirror`` key).
    border_val_seg : int
        Constant fill value for segmentation borders during spatial transform.
    seeds_train, seeds_val : list of int or None
        Per-process seeds for the multithreaded augmenters (unused in debug mode).
    order_seg, order_data : int
        Interpolation orders for seg and data resampling.
    deep_supervision_scales : list or None
        If given, downsampled targets are produced for deep supervision.
    soft_ds : bool
        Use soft (one-hot, DownsampleSegForDSTransform3) deep-supervision
        targets; requires ``classes``.
    classes : list or None
        Class labels, required when ``soft_ds`` is True.
    pin_memory : bool
        Forwarded to MultiThreadedAugmenter for faster host->GPU copies.

    Returns
    -------
    (batchgenerator_train, batchgenerator_val) : tuple of augmenters
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         p_rot_per_axis=params.get("rotation_p_per_axis"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    # restore the 3D layout before any channel-wise (color) augmentation
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                          p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.1))  # inverted gamma

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # NOTE: the assert above guarantees the legacy 'mirror' key is absent,
    # so only the new-style 'do_mirror' flag needs to be checked
    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        # cascade setup: the previous stage's segmentation travels as extra
        # one-hot data channels and gets its own corruption augmentations
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get(
                "cascade_do_cascade_augmentations") is not None and params.get(
                    "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size"),
                        p_per_label=params.get(
                            "cascade_random_binary_transform_p_per_label")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # BUGFIX: the two params keys below were previously swapped
                # (fill_with_other_class_p received the max-size threshold and
                # vice versa); each key now feeds the keyword of the same name
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold")
                    ))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    if config.DEBUG_MODE:  # zhuc
        batchgenerator_train = SingleThreadedAugmenter(dataloader_train,
                                                       tr_transforms)
    else:
        batchgenerator_train = MultiThreadedAugmenter(
            dataloader_train,
            tr_transforms,
            params.get('num_threads'),
            params.get("num_cached_per_thread"),
            seeds=seeds_train,
            pin_memory=pin_memory)

    # validation pipeline: no stochastic augmentation, only channel selection,
    # label cleanup, the cascade channel move, and the target preparation steps
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    if config.DEBUG_MODE:  # zhuc
        batchgenerator_val = SingleThreadedAugmenter(dataloader_val,
                                                     val_transforms)
    else:
        # validation uses half the worker count (at least one process)
        batchgenerator_val = MultiThreadedAugmenter(
            dataloader_val,
            val_transforms,
            max(params.get('num_threads') // 2, 1),
            params.get("num_cached_per_thread"),
            seeds=seeds_val,
            pin_memory=pin_memory)

    return batchgenerator_train, batchgenerator_val