Example #1
def get_transforms(mode="train", target_size=128):
    transform_list = []

    if mode == "train":
        transform_list = [
            # CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
            MirrorTransform(axes=(1,)),
            SpatialTransform(patch_size=(target_size, target_size), random_crop=False,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                             do_rotation=True, p_rot_per_sample=0.5,
                             angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                             scale=(0.5, 1.9), p_scale_per_sample=0.5,
                             border_mode_data="nearest", border_mode_seg="nearest"),
        ]


    elif mode == "val":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    elif mode == "test":
        transform_list = [CenterCropTransform(crop_size=target_size),
                          ResizeTransform(target_size=target_size, order=1),
                          ]

    transform_list.append(NumpyToTensor())

    return Compose(transform_list)
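
A minimal usage sketch for the function above (the transform classes are assumed to come from batchgenerators, whose imports the gallery omits; shapes and values are illustrative):

import numpy as np

# the "val" branch only center-crops, resizes and converts the batch to torch tensors
val_transforms = get_transforms(mode="val", target_size=128)
batch = {"data": np.random.randn(4, 1, 160, 160).astype(np.float32),
         "seg": np.zeros((4, 1, 160, 160), dtype=np.float32)}
out = val_transforms(**batch)  # Compose forwards the dict through each transform in order
print(out["data"].shape)       # expected torch.Size([4, 1, 128, 128]) under these assumptions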
Example #2
def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs):
    """
    create multi-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """

    # create instance of batch generator as first element in pipeline.
    data_gen = BatchGenerator(cf, patient_data, **kwargs)

    my_transforms = []
    if do_aug:
        if cf.da_kwargs["mirror"]:
            mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
            my_transforms.append(mirror_transform)

        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])

        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
    all_transforms = Compose(my_transforms)
    # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    return multithreaded_generator
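
A hedged consumption sketch for the pipeline above; cf and patient_data stand in for the project's config object and loaded dataset and are not defined here:

train_gen = create_data_gen_pipeline(cf, patient_data, do_aug=True)
batch = next(train_gen)  # MultiThreadedAugmenter is consumed via next()/iteration and yields
                         # dicts carrying "data", "seg" and the box targets produced by
                         # ConvertSegToBoundingBoxCoordinates

For debugging, the commented-out SingleThreadedAugmenter(data_gen, all_transforms) is a drop-in replacement that runs the transforms in the main process.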
Example #3
def crop_transform_center():
    '''
    Crop a patch from the exact center of the image.
    :return:
    '''
    crop_size = (128, 128)
    batchgen = my_data_loader.DataLoader(camera(), 4)
    centerCrop = CenterCropTransform(crop_size=crop_size)
    multithreaded_generator = MultiThreadedAugmenter(batchgen,
                                                     Compose([centerCrop]), 4,
                                                     2)
    my_data_loader.plot_batch(next(multithreaded_generator))
Example #4
def create_data_gen_pipeline(cf, cities=None, data_split='train', do_aug=True, random=True, n_batches=None):
    """
    create multi-threaded train/val/test batch generation and augmentation pipeline.
    :param cities: list of strings or None
    :param data_split: which subset to draw batches from ('train', 'val' or 'test')
    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :param random: bool, whether to draw random batches or go through data linearly
    :param n_batches: (optional) fixed number of batches to generate
    :return: multithreaded_generator
    """
    data_gen = BatchGenerator(cities=cities, batch_size=cf.batch_size, data_dir=cf.data_dir,
                              label_density=cf.label_density, data_split=data_split, resolution=cf.resolution,
                              gt_instances=cf.gt_instances, n_batches=n_batches, random=random)
    my_transforms = []
    if do_aug:
        mirror_transform = MirrorTransform(axes=(3,))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[-2:],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'],
                                             border_mode_data=cf.da_kwargs['border_mode_data'],
                                             border_mode_seg=cf.da_kwargs['border_mode_seg'],
                                             border_cval_seg=cf.da_kwargs['border_cval_seg'])
        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[-2:]))

    my_transforms.append(GammaTransform(cf.da_kwargs['gamma_range'], invert_image=False, per_channel=True,
                                        retain_stats=cf.da_kwargs['gamma_retain_stats'],
                                        p_per_sample=cf.da_kwargs['p_gamma']))
    my_transforms.append(AddLossMask(cf.ignore_label))
    if cf.label_switches is not None:
        my_transforms.append(StochasticLabelSwitches(cf.name2trainId, cf.label_switches))
    all_transforms = Compose(my_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers,
                                                     seeds=range(cf.n_workers))
    return multithreaded_generator
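
A hedged usage sketch; cf is a placeholder for the project's config, and the n_batches value is illustrative:

train_gen = create_data_gen_pipeline(cf, data_split='train', do_aug=True, random=True)
# with do_aug=False the pipeline falls back to a deterministic CenterCropTransform
val_gen = create_data_gen_pipeline(cf, data_split='val', do_aug=False, random=False, n_batches=50)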
Example #5
def create_data_gen_pipeline(patient_data, cf, test_pids=None, do_aug=True):
    """
    create multi-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset
    :param test_pids: (optional) list of test patient ids, calls the test generator.
    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """
    if test_pids is None:
        data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size,
                                 pre_crop_size=cf.pre_crop_size, dim=cf.dim)
    else:
        data_gen = TestGenerator(patient_data, batch_size=cf.batch_size, n_batches=None,
                                 pre_crop_size=cf.pre_crop_size, test_pids=test_pids, dim=cf.dim)
        cf.n_workers = 1

    my_transforms = []
    if do_aug:
        mirror_transform = Mirror(axes=(2, 3))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])

        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    my_transforms.append(ConvertSegToOnehotTransform(classes=(0, 1, 2)))
    my_transforms.append(TransposeChannels())
    all_transforms = Compose(my_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    return multithreaded_generator
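
A hedged sketch of the two entry points above; train_data, test_data and cf are placeholders, and keying the patient dictionaries by patient id is an assumption based on the docstring:

train_gen = create_data_gen_pipeline(train_data, cf, do_aug=True)
# passing test_pids switches to TestGenerator and forces single-process augmentation
test_gen = create_data_gen_pipeline(test_data, cf, test_pids=list(test_data.keys()), do_aug=False)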
Example #6
def get_transforms(mode="train",
                   n_channels=1,
                   target_size=128,
                   add_resize=False,
                   add_noise=False,
                   mask_type="",
                   batch_size=16,
                   rotate=True,
                   elastic_deform=True,
                   rnd_crop=False,
                   color_augment=True):
    transform_list = []
    noise_list = []

    if mode == "train":

        transform_list = [
            FillupPadTransform(min_size=(n_channels, target_size + 5,
                                         target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1,
                            concatenate_list=True),

            # RandomCropTransform(crop_size=(target_size + 5, target_size + 5)),
            MirrorTransform(axes=(2, )),
            ReshapeTransform(new_shape=(1, -1, "h", "w")),
            SpatialTransform(patch_size=(target_size, target_size),
                             random_crop=rnd_crop,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=elastic_deform,
                             alpha=(0., 100.),
                             sigma=(10., 13.),
                             do_rotation=rotate,
                             angle_x=(-0.1, 0.1),
                             angle_y=(0, 1e-8),
                             angle_z=(0, 1e-8),
                             scale=(0.9, 1.2),
                             border_mode_data="nearest",
                             border_mode_seg="nearest"),
            ReshapeTransform(new_shape=(batch_size, -1, "h", "w"))
        ]
        if color_augment:
            transform_list += [  # BrightnessTransform(mu=0, sigma=0.2),
                BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1))
            ]

        transform_list += [
            GaussianNoiseTransform(noise_variance=(0., 0.05)),
            ClipValueRange(min=-1.5, max=1.5),
        ]

        noise_list = []
        if mask_type == "gaussian":
            noise_list += [GaussianNoiseTransform(noise_variance=(0., 0.2))]

    elif mode == "val":
        transform_list = [
            FillupPadTransform(min_size=(n_channels, target_size + 5,
                                         target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1,
                            concatenate_list=True),
            CenterCropTransform(crop_size=(target_size, target_size)),
            ClipValueRange(min=-1.5, max=1.5),
            # BrightnessTransform(mu=0, sigma=0.2),
            # BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1)),
            CopyTransform({"data": "data_clean"}, copy=True)
        ]

        noise_list += []

    if add_noise:
        transform_list = transform_list + noise_list

    transform_list.append(NumpyToTensor())

    return Compose(transform_list)
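
A hedged construction sketch for the two modes above (the flag values are illustrative; FillupPadTransform, ReshapeTransform and ClipValueRange are assumed to come from the project's own transform collection rather than stock batchgenerators):

# training pipeline with Gaussian corruption appended (add_noise=True, mask_type="gaussian")
train_transforms = get_transforms(mode="train", n_channels=1, target_size=128,
                                  add_noise=True, mask_type="gaussian",
                                  batch_size=16, rnd_crop=True)

# validation pipeline: deterministic crop/resize plus a clean copy of "data" under "data_clean"
val_transforms = get_transforms(mode="val", n_channels=1, target_size=128)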
Example #7
def main():

    # assign global args
    global args
    args = parser.parse_args()

    # make a folder for the experiment
    general_folder_name = args.output_path
    try:
        os.mkdir(general_folder_name)
    except OSError:
        pass

    # create the train/test split and return the indices; patients in the test split won't be seen during training
    train_idx, val_idx, test_idx = CreateTrainValTestSplit(
        HistoFile_path=args.input_file,
        num_splits=args.num_splits,
        num_test_folds=args.num_test_folds,
        num_val_folds=args.num_val_folds,
        seed=args.seed)

    IDs = train_idx + val_idx

    print('size of training set {}'.format(len(train_idx)))
    print('size of validation set {}'.format(len(val_idx)))
    print('size of test set {}'.format(len(test_idx)))

    # data loading
    Data = ProstataData(args.input_file)  # for details on this class, see the README

    # train and validate
    for cv in range(args.cv_start, args.cv_end):
        best_epoch = 0
        train_loss = []
        val_loss = []

        # define patients for training and validation
        train_idx, val_idx = split_training(IDs,
                                            len_val=62,
                                            cv=cv,
                                            cv_runs=args.cv_number)

        oversampling_factor, Slices_total, Natural_probability_tu_slice, Natural_probability_PRO_slice = get_oversampling(
            Data,
            train_idx=sorted(train_idx),
            Batch_Size=args.b,
            patch_size=args.patch_size)

        training_batches = Slices_total // args.b  # integer number of training batches per epoch

        lr = args.lr
        base_lr = args.lr
        args.seed += 1

        print('train_idx', train_idx, len(train_idx))
        print('val_idx', val_idx, len(val_idx))

        # get class frequencies
        print('calculating class frequencies')

        (Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC,
         Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2,
         ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std) = \
            get_class_frequencies(Data, train_idx, patch_size=args.patch_size)

        print(ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std)

        print('ADC', Tumor_frequencie_ADC, Prostate_frequencie_ADC,
              Background_frequencie_ADC)
        print('T2', Tumor_frequencie_T2, Prostate_frequencie_T2,
              Background_frequencie_T2)

        all_ADC = float(Background_frequencie_ADC +
                        Prostate_frequencie_ADC + Tumor_frequencie_ADC)
        all_T2 = float(Background_frequencie_T2 + Prostate_frequencie_T2 +
                       Tumor_frequencie_T2)

        print(all_ADC)
        print(all_T2)

        W1_ADC = 1 / (Background_frequencie_ADC / all_ADC)**0.25
        W2_ADC = 1 / (Prostate_frequencie_ADC / all_ADC)**0.25
        W3_ADC = 1 / (Tumor_frequencie_ADC / all_ADC)**0.25

        Wa_ADC = W1_ADC / (W1_ADC + W2_ADC + W3_ADC)
        Wb_ADC = W2_ADC / (W1_ADC + W2_ADC + W3_ADC)
        Wc_ADC = W3_ADC / (W1_ADC + W2_ADC + W3_ADC)

        print('Weights ADC', Wa_ADC, Wb_ADC, Wc_ADC)

        weight_ADC = (Wa_ADC, Wb_ADC, Wc_ADC)

        W1_T2 = 1 / (Background_frequencie_T2 / all_T2)**0.25
        W2_T2 = 1 / (Prostate_frequencie_T2 / all_T2)**0.25
        W3_T2 = 1 / (Tumor_frequencie_T2 / all_T2)**0.25

        Wa_T2 = W1_T2 / (W1_T2 + W2_T2 + W3_T2)
        Wb_T2 = W2_T2 / (W1_T2 + W2_T2 + W3_T2)
        Wc_T2 = W3_T2 / (W1_T2 + W2_T2 + W3_T2)

        print('Weights T2', Wa_T2, Wb_T2, Wc_T2)

        weight_T2 = (Wa_T2, Wb_T2, Wc_T2)

        # define model
        Net = UNetPytorch(in_shape=(3, args.patch_size[0], args.patch_size[1]))
        Net_Name = 'UNetPytorch'
        model = Net.cuda()

        # optimizer and loss functions
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr,
                                     weight_decay=args.weight_decay)
        criterion_ADC = CrossEntropyLoss2d(
            weight=torch.FloatTensor(weight_ADC)).cuda()
        criterion_T2 = CrossEntropyLoss2d(
            weight=torch.FloatTensor(weight_T2)).cuda()

        # new folder name for cv
        folder_name = general_folder_name + '/CV_{}'.format(cv)
        try:
            os.mkdir(folder_name)
        except OSError:
            pass

        checkpoint_file = folder_name + '/checkpoint_' + '{}.pth.tar'.format(
            Net_Name)

        # training loop; the augmentation pipeline is rebuilt at epochs 0, 30 and 50
        for epoch in range(args.epochs):
            torch.manual_seed(args.seed + epoch + cv)
            np.random.seed(epoch + cv)
            np.random.shuffle(train_idx)

            if epoch == 0:
                my_transforms = []
                spatial_transform = SpatialTransform(
                    args.patch_size,
                    np.array(args.patch_size) // 2,
                    do_elastic_deform=True,
                    alpha=(100., 450.),
                    sigma=(13., 17.),
                    do_rotation=True,
                    angle_z=(-np.pi / 2., np.pi / 2.),
                    do_scale=True,
                    scale=(0.75, 1.25),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    random_crop=True)
                resample_transform = ResampleTransform(zoom_range=(0.7, 1.3))
                brightness_transform = BrightnessTransform(0.0, 0.1, True)
                my_transforms.append(resample_transform)
                my_transforms.append(
                    ContrastAugmentationTransform((0.75, 1.25), True))
                my_transforms.append(brightness_transform)
                my_transforms.append(Mirror((2, 3)))
                all_transforms = Compose(my_transforms)
                sometimes_spatial_transforms = RndTransform(
                    spatial_transform,
                    prob=args.p,
                    alternative_transform=CenterCropTransform(args.patch_size))
                sometimes_other_transforms = RndTransform(all_transforms,
                                                          prob=1.0)
                final_transform = Compose(
                    [sometimes_spatial_transforms, sometimes_other_transforms])
                Center_Crop = CenterCropTransform(args.patch_size)

            if epoch == 30:
                my_transforms = []
                spatial_transform = SpatialTransform(
                    args.patch_size,
                    np.array(args.patch_size) // 2,
                    do_elastic_deform=True,
                    alpha=(0., 250.),
                    sigma=(11., 14.),
                    do_rotation=True,
                    angle_z=(-np.pi / 2., np.pi / 2.),
                    do_scale=True,
                    scale=(0.85, 1.15),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    random_crop=True)
                resample_transform = ResampleTransform(zoom_range=(0.8, 1.2))
                brightness_transform = BrightnessTransform(0.0, 0.1, True)
                my_transforms.append(resample_transform)
                my_transforms.append(
                    ContrastAugmentationTransform((0.85, 1.15), True))
                my_transforms.append(brightness_transform)
                all_transforms = Compose(my_transforms)
                sometimes_spatial_transforms = RndTransform(
                    spatial_transform,
                    prob=args.p,
                    alternative_transform=CenterCropTransform(args.patch_size))
                sometimes_other_transforms = RndTransform(all_transforms,
                                                          prob=1.0)
                final_transform = Compose(
                    [sometimes_spatial_transforms, sometimes_other_transforms])
                Center_Crop = CenterCropTransform(args.patch_size)

            if epoch == 50:
                my_transforms = []
                spatial_transform = SpatialTransform(
                    args.patch_size,
                    np.array(args.patch_size) // 2,
                    do_elastic_deform=True,
                    alpha=(0., 150.),
                    sigma=(10., 12.),
                    do_rotation=True,
                    angle_z=(-np.pi / 2., np.pi / 2.),
                    do_scale=True,
                    scale=(0.85, 1.15),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    random_crop=False)
                resample_transform = ResampleTransform(zoom_range=(0.9, 1.1))
                brightness_transform = BrightnessTransform(0.0, 0.1, True)
                my_transforms.append(resample_transform)
                my_transforms.append(
                    ContrastAugmentationTransform((0.95, 1.05), True))
                my_transforms.append(brightness_transform)
                all_transforms = Compose(my_transforms)
                sometimes_spatial_transforms = RndTransform(
                    spatial_transform,
                    prob=args.p,
                    alternative_transform=CenterCropTransform(args.patch_size))
                sometimes_other_transforms = RndTransform(all_transforms,
                                                          prob=1.0)
                final_transform = Compose(
                    [sometimes_spatial_transforms, sometimes_other_transforms])
                Center_Crop = CenterCropTransform(args.patch_size)

            train_loader = BatchGenerator(
                Data,
                BATCH_SIZE=args.b,
                split_idx=train_idx,
                seed=args.seed,
                ProbabilityTumorSlices=oversampling_factor,
                epoch=epoch,
                ADC_mean=ADC_mean,
                ADC_std=ADC_std,
                BVAL_mean=BVAL_mean,
                BVAL_std=BVAL_std,
                T2_mean=T2_mean,
                T2_std=T2_std)

            val_loader = BatchGenerator(Data,
                                        BATCH_SIZE=0,
                                        split_idx=val_idx,
                                        seed=args.seed,
                                        ProbabilityTumorSlices=None,
                                        epoch=epoch,
                                        test=True,
                                        ADC_mean=ADC_mean,
                                        ADC_std=ADC_std,
                                        BVAL_mean=BVAL_mean,
                                        BVAL_std=BVAL_std,
                                        T2_mean=T2_mean,
                                        T2_std=T2_std)

            # train on training set
            train_losses = train(train_loader=train_loader,
                                 model=model,
                                 optimizer=optimizer,
                                 criterion_ADC=criterion_ADC,
                                 criterion_T2=criterion_T2,
                                 final_transform=final_transform,
                                 workers=args.workers,
                                 seed=args.seed,
                                 training_batches=training_batches)
            train_loss.append(train_losses)

            # evaluate on validation set
            val_losses = validate(val_loader=val_loader,
                                  model=model,
                                  folder_name=folder_name,
                                  criterion_ADC=criterion_ADC,
                                  criterion_T2=criterion_T2,
                                  split_ixs=val_idx,
                                  epoch=epoch,
                                  workers=1,
                                  Center_Crop=Center_Crop,
                                  seed=args.seed)
            val_loss.append(val_losses)

            # write TrainingsCSV to folder name
            TrainingsCSV = pd.DataFrame({
                'train_loss': train_loss,
                'val_loss': val_loss
            })
            TrainingsCSV.to_csv(folder_name + '/TrainingsCSV.csv')

            if val_losses <= min(val_loss):
                best_epoch = epoch
                print('best epoch', epoch)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    filename=checkpoint_file)

            optimizer, lr = adjust_learning_rate(optimizer, base_lr, epoch)

        # delete all output except best epoch
        clear_image_data(folder_name, best_epoch, epoch)
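
The three epoch-gated blocks in main() rebuild the augmentation with progressively milder settings (the epoch-0 block additionally appends Mirror). A hedged refactoring sketch, with hypothetical names and the values copied from those blocks, keeps the schedule in one place:

# sketch only: the transform construction itself would stay as written in main()
AUG_SCHEDULE = {
    0:  {"alpha": (100., 450.), "sigma": (13., 17.), "scale": (0.75, 1.25),
         "zoom_range": (0.7, 1.3), "contrast_range": (0.75, 1.25), "random_crop": True},
    30: {"alpha": (0., 250.), "sigma": (11., 14.), "scale": (0.85, 1.15),
         "zoom_range": (0.8, 1.2), "contrast_range": (0.85, 1.15), "random_crop": True},
    50: {"alpha": (0., 150.), "sigma": (10., 12.), "scale": (0.85, 1.15),
         "zoom_range": (0.9, 1.1), "contrast_range": (0.95, 1.05), "random_crop": False},
}

def aug_params_for_epoch(epoch):
    """Return the parameters of the latest schedule step reached by this epoch."""
    step = max(k for k in AUG_SCHEDULE if k <= epoch)
    return AUG_SCHEDULE[step]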