def create_data_gen_pipeline(patient_data, cf, test_pids=None, do_aug=True):
    """
    create mutli-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset
    :param test_pids: (optional) list of test patient ids, calls the test generator.
    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """
    if test_pids is None:
        data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size,
                                 pre_crop_size=cf.pre_crop_size, dim=cf.dim)
    else:
        data_gen = TestGenerator(patient_data, batch_size=cf.batch_size, n_batches=None,
                                 pre_crop_size=cf.pre_crop_size, test_pids=test_pids, dim=cf.dim)
        # testing runs with a single worker so batches arrive in a fixed order
        cf.n_workers = 1

    my_transforms = []
    if do_aug:
        mirror_transform = Mirror(axes=(2, 3))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])

        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    my_transforms.append(ConvertSegToOnehotTransform(classes=(0, 1, 2)))
    my_transforms.append(TransposeChannels())
    all_transforms = Compose(my_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    return multithreaded_generator
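
A minimal usage sketch for the pipeline above. train_data/val_data/test_data and the cf attributes (batch_size, pre_crop_size, patch_size, dim, da_kwargs, n_workers) are assumed to come from the surrounding experiment config; 'data'/'seg' are the standard batchgenerators batch keys.

# hypothetical usage, not part of the original snippet
train_gen = create_data_gen_pipeline(train_data, cf, do_aug=True)
val_gen = create_data_gen_pipeline(val_data, cf, do_aug=False)
test_gen = create_data_gen_pipeline(test_data, cf, test_pids=test_pids, do_aug=False)

batch = next(train_gen)                     # MultiThreadedAugmenter is an iterator
images, segs = batch['data'], batch['seg']  # standard batchgenerators batch keys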
Example #2
def create_data_gen_pipeline(patient_data, cf, is_training=True):
    """
    create mutli-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
    :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """

    # create instance of batch generator as first element in pipeline.
    data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size, cf=cf)

    # add transformations to pipeline.
    my_transforms = []
    if is_training:
        mirror_transform = Mirror(axes=np.arange(cf.dim))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])

        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=cf.class_specific_seg_flag))
    all_transforms = Compose(my_transforms)
    # for debugging, swap in: SingleThreadedAugmenter(data_gen, all_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
    return multithreaded_generator
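
For reference, a hypothetical cf.da_kwargs with exactly the keys the SpatialTransform calls above read; the values are illustrative placeholders, not the original configuration.

import numpy as np

cf.da_kwargs = {
    'rand_crop_dist': (140., 140.),   # a common choice is ~ pre_crop_size / 2
    'do_elastic_deform': True,
    'alpha': (0., 1500.),
    'sigma': (30., 50.),
    'do_rotation': True,
    'angle_x': (0., 2. * np.pi),
    'angle_y': (0., 0.),
    'angle_z': (0., 0.),
    'do_scale': True,
    'scale': (0.8, 1.1),
    'random_crop': False,
}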
Example #3
def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs):
    """
    create mutli-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
    :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """

    # create instance of batch generator as first element in pipeline.
    data_gen = BatchGenerator(cf, patient_data, **kwargs)

    my_transforms = []
    if do_aug:
        if cf.da_kwargs["mirror"]:
            mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
            my_transforms.append(mirror_transform)

        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])

        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
    all_transforms = Compose(my_transforms)
    # for debugging, swap in: SingleThreadedAugmenter(data_gen, all_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads,
                                                     seeds=range(data_gen.n_filled_threads))
    return multithreaded_generator
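
The SingleThreadedAugmenter lines commented out above are the usual way to debug a pipeline: the identical transforms run in the main process, so breakpoints and stack traces inside them behave normally. A minimal sketch, assuming the batchgenerators import path below:

from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter

debug_gen = SingleThreadedAugmenter(data_gen, all_transforms)  # no worker processes
batch = next(debug_gen)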
Example #4
def main():

    # assign global args
    global args
    args = parser.parse_args()

    # make a folder for the experiment
    general_folder_name = args.output_path
    try:
        os.mkdir(general_folder_name)
    except OSError:
        pass

    # create train/val/test split and return the indices; patients in the test split won't be seen during training
    train_idx, val_idx, test_idx = CreateTrainValTestSplit(
        HistoFile_path=args.input_file,
        num_splits=args.num_splits,
        num_test_folds=args.num_test_folds,
        num_val_folds=args.num_val_folds,
        seed=args.seed)

    IDs = train_idx + val_idx

    print('size of training set {}'.format(len(train_idx)))
    print('size of validation set {}'.format(len(val_idx)))
    print('size of test set {}'.format(len(test_idx)))

    # data loading
    Data = ProstataData(args.input_file)  #For details on this class see README

    # train and validate
    for cv in range(args.cv_start, args.cv_end):
        best_epoch = 0
        train_loss = []
        val_loss = []

        # define patients for training and validation
        train_idx, val_idx = split_training(IDs,
                                            len_val=62,
                                            cv=cv,
                                            cv_runs=args.cv_number)

        oversampling_factor, Slices_total, Natural_probability_tu_slice, Natural_probability_PRO_slice = get_oversampling(
            Data,
            train_idx=sorted(train_idx),
            Batch_Size=args.b,
            patch_size=args.patch_size)

        training_batches = Slices_total // args.b  # integer number of batches per epoch

        lr = args.lr
        base_lr = args.lr
        args.seed += 1

        print('train_idx', train_idx, len(train_idx))
        print('val_idx', val_idx, len(val_idx))

        # get class frequencies
        print('calculating class frequencies')

        (Tumor_frequencie_ADC, Prostate_frequencie_ADC, Background_frequencie_ADC,
         Tumor_frequencie_T2, Prostate_frequencie_T2, Background_frequencie_T2,
         ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std) \
            = get_class_frequencies(Data, train_idx, patch_size=args.patch_size)

        print(ADC_mean, ADC_std, BVAL_mean, BVAL_std, T2_mean, T2_std)

        print('ADC', Tumor_frequencie_ADC, Prostate_frequencie_ADC,
              Background_frequencie_ADC)
        print('T2', Tumor_frequencie_T2, Prostate_frequencie_T2,
              Background_frequencie_T2)

        # np.float was removed in NumPy 1.24; use the builtin float instead
        all_ADC = float(Background_frequencie_ADC +
                        Prostate_frequencie_ADC + Tumor_frequencie_ADC)
        all_T2 = float(Background_frequencie_T2 + Prostate_frequencie_T2 +
                       Tumor_frequencie_T2)

        print(all_ADC)
        print(all_T2)

        # inverse-frequency class weights, tempered with exponent 0.25
        # (see the helper sketch after main() below)
        W1_ADC = 1 / (Background_frequencie_ADC / all_ADC)**0.25
        W2_ADC = 1 / (Prostate_frequencie_ADC / all_ADC)**0.25
        W3_ADC = 1 / (Tumor_frequencie_ADC / all_ADC)**0.25

        Wa_ADC = W1_ADC / (W1_ADC + W2_ADC + W3_ADC)
        Wb_ADC = W2_ADC / (W1_ADC + W2_ADC + W3_ADC)
        Wc_ADC = W3_ADC / (W1_ADC + W2_ADC + W3_ADC)

        print('Weights ADC', Wa_ADC, Wb_ADC, Wc_ADC)

        weight_ADC = (Wa_ADC, Wb_ADC, Wc_ADC)

        W1_T2 = 1 / (Background_frequencie_T2 / all_T2)**0.25
        W2_T2 = 1 / (Prostate_frequencie_T2 / all_T2)**0.25
        W3_T2 = 1 / (Tumor_frequencie_T2 / all_T2)**0.25

        Wa_T2 = W1_T2 / (W1_T2 + W2_T2 + W3_T2)
        Wb_T2 = W2_T2 / (W1_T2 + W2_T2 + W3_T2)
        Wc_T2 = W3_T2 / (W1_T2 + W2_T2 + W3_T2)

        print('Weights T2', Wa_T2, Wb_T2, Wc_T2)

        weight_T2 = (Wa_T2, Wb_T2, Wc_T2)

        # define model
        Net = UNetPytorch(in_shape=(3, args.patch_size[0], args.patch_size[1]))
        Net_Name = 'UNetPytorch'
        model = Net.cuda()

        # model parameter
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr,
                                     weight_decay=args.weight_decay)
        criterion_ADC = CrossEntropyLoss2d(
            weight=torch.FloatTensor(weight_ADC)).cuda()
        criterion_T2 = CrossEntropyLoss2d(
            weight=torch.FloatTensor(weight_T2)).cuda()

        # new folder name for cv
        folder_name = general_folder_name + '/CV_{}'.format(cv)
        try:
            os.mkdir(folder_name)
        except OSError:
            pass

        checkpoint_file = folder_name + '/checkpoint_' + '{}.pth.tar'.format(
            Net_Name)

        # augmentation: three stages (epochs 0, 30, 50) that progressively weaken
        # the transforms; see the schedule-table sketch after main() below
        for epoch in range(args.epochs):
            torch.manual_seed(args.seed + epoch + cv)
            np.random.seed(epoch + cv)
            np.random.shuffle(train_idx)

            if epoch == 0:
                my_transforms = []
                spatial_transform = SpatialTransform(
                    args.patch_size,
                    np.array(args.patch_size) // 2,
                    do_elastic_deform=True,
                    alpha=(100., 450.),
                    sigma=(13., 17.),
                    do_rotation=True,
                    angle_z=(-np.pi / 2., np.pi / 2.),
                    do_scale=True,
                    scale=(0.75, 1.25),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    random_crop=True)
                resample_transform = ResampleTransform(zoom_range=(0.7, 1.3))
                brightness_transform = BrightnessTransform(0.0, 0.1, True)
                my_transforms.append(resample_transform)
                my_transforms.append(
                    ContrastAugmentationTransform((0.75, 1.25), True))
                my_transforms.append(brightness_transform)
                my_transforms.append(Mirror((2, 3)))
                all_transforms = Compose(my_transforms)
                sometimes_spatial_transforms = RndTransform(
                    spatial_transform,
                    prob=args.p,
                    alternative_transform=CenterCropTransform(args.patch_size))
                sometimes_other_transforms = RndTransform(all_transforms,
                                                          prob=1.0)
                final_transform = Compose(
                    [sometimes_spatial_transforms, sometimes_other_transforms])
                Center_Crop = CenterCropTransform(args.patch_size)

            if epoch == 30:
                my_transforms = []
                spatial_transform = SpatialTransform(
                    args.patch_size,
                    np.array(args.patch_size) // 2,
                    do_elastic_deform=True,
                    alpha=(0., 250.),
                    sigma=(11., 14.),
                    do_rotation=True,
                    angle_z=(-np.pi / 2., np.pi / 2.),
                    do_scale=True,
                    scale=(0.85, 1.15),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    random_crop=True)
                resample_transform = ResampleTransform(zoom_range=(0.8, 1.2))
                brightness_transform = BrightnessTransform(0.0, 0.1, True)
                my_transforms.append(resample_transform)
                my_transforms.append(
                    ContrastAugmentationTransform((0.85, 1.15), True))
                my_transforms.append(brightness_transform)
                all_transforms = Compose(my_transforms)
                sometimes_spatial_transforms = RndTransform(
                    spatial_transform,
                    prob=args.p,
                    alternative_transform=CenterCropTransform(args.patch_size))
                sometimes_other_transforms = RndTransform(all_transforms,
                                                          prob=1.0)
                final_transform = Compose(
                    [sometimes_spatial_transforms, sometimes_other_transforms])
                Center_Crop = CenterCropTransform(args.patch_size)

            if epoch == 50:
                my_transforms = []
                spatial_transform = SpatialTransform(
                    args.patch_size,
                    np.array(args.patch_size) // 2,
                    do_elastic_deform=True,
                    alpha=(0., 150.),
                    sigma=(10., 12.),
                    do_rotation=True,
                    angle_z=(-np.pi / 2., np.pi / 2.),
                    do_scale=True,
                    scale=(0.85, 1.15),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    random_crop=False)
                resample_transform = ResampleTransform(zoom_range=(0.9, 1.1))
                brightness_transform = BrightnessTransform(0.0, 0.1, True)
                my_transforms.append(resample_transform)
                my_transforms.append(
                    ContrastAugmentationTransform((0.95, 1.05), True))
                my_transforms.append(brightness_transform)
                all_transforms = Compose(my_transforms)
                sometimes_spatial_transforms = RndTransform(
                    spatial_transform,
                    prob=args.p,
                    alternative_transform=CenterCropTransform(args.patch_size))
                sometimes_other_transforms = RndTransform(all_transforms,
                                                          prob=1.0)
                final_transform = Compose(
                    [sometimes_spatial_transforms, sometimes_other_transforms])
                Center_Crop = CenterCropTransform(args.patch_size)

            train_loader = BatchGenerator(
                Data,
                BATCH_SIZE=args.b,
                split_idx=train_idx,
                seed=args.seed,
                ProbabilityTumorSlices=oversampling_factor,
                epoch=epoch,
                ADC_mean=ADC_mean,
                ADC_std=ADC_std,
                BVAL_mean=BVAL_mean,
                BVAL_std=BVAL_std,
                T2_mean=T2_mean,
                T2_std=T2_std)

            val_loader = BatchGenerator(Data,
                                        BATCH_SIZE=0,
                                        split_idx=val_idx,
                                        seed=args.seed,
                                        ProbabilityTumorSlices=None,
                                        epoch=epoch,
                                        test=True,
                                        ADC_mean=ADC_mean,
                                        ADC_std=ADC_std,
                                        BVAL_mean=BVAL_mean,
                                        BVAL_std=BVAL_std,
                                        T2_mean=T2_mean,
                                        T2_std=T2_std)

            #train on training set
            train_losses = train(train_loader=train_loader,
                                 model=model,
                                 optimizer=optimizer,
                                 criterion_ADC=criterion_ADC,
                                 criterion_T2=criterion_T2,
                                 final_transform=final_transform,
                                 workers=args.workers,
                                 seed=args.seed,
                                 training_batches=training_batches)
            train_loss.append(train_losses)

            # evaluate on validation set
            val_losses = validate(val_loader=val_loader,
                                  model=model,
                                  folder_name=folder_name,
                                  criterion_ADC=criterion_ADC,
                                  criterion_T2=criterion_T2,
                                  split_ixs=val_idx,
                                  epoch=epoch,
                                  workers=1,
                                  Center_Crop=Center_Crop,
                                  seed=args.seed)
            val_loss.append(val_losses)

            # write the running train/val losses to a CSV in folder_name
            TrainingsCSV = pd.DataFrame({
                'train_loss': train_loss,
                'val_loss': val_loss
            })
            TrainingsCSV.to_csv(folder_name + '/TrainingsCSV.csv')

            if val_losses <= min(val_loss):
                best_epoch = epoch
                print('best epoch', epoch)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    filename=checkpoint_file)

            optimizer, lr = adjust_learning_rate(optimizer, base_lr, epoch)

        # delete all output except best epoch
        clear_image_data(folder_name, best_epoch, epoch)
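
The tempered inverse-frequency weighting in main() can be factored into a single helper. A sketch of the same formula (W_c = 1 / (f_c / sum(f))**0.25, then normalized to sum to 1); the function name is hypothetical, not part of the original script:

def inverse_frequency_weights(frequencies, exponent=0.25):
    """Normalized, tempered inverse-frequency class weights.

    Reproduces the ADC/T2 computation in main():
    W_c = 1 / (f_c / total)**exponent, rescaled so the weights sum to 1.
    """
    total = float(sum(frequencies))
    raw = [1.0 / (f / total) ** exponent for f in frequencies]
    norm = sum(raw)
    return tuple(w / norm for w in raw)

# e.g. weight_ADC = inverse_frequency_weights(
#     (Background_frequencie_ADC, Prostate_frequencie_ADC, Tumor_frequencie_ADC))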
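
The three near-identical augmentation blocks in main() differ only in a handful of parameters. A hedged refactoring sketch that rebuilds the same final_transform from a schedule table; the names AUG_SCHEDULE and build_final_transform are hypothetical:

AUG_SCHEDULE = {  # epoch threshold -> the parameters that actually change
    0:  dict(alpha=(100., 450.), sigma=(13., 17.), scale=(0.75, 1.25),
             zoom=(0.7, 1.3), contrast=(0.75, 1.25), random_crop=True, mirror=True),
    30: dict(alpha=(0., 250.), sigma=(11., 14.), scale=(0.85, 1.15),
             zoom=(0.8, 1.2), contrast=(0.85, 1.15), random_crop=True, mirror=False),
    50: dict(alpha=(0., 150.), sigma=(10., 12.), scale=(0.85, 1.15),
             zoom=(0.9, 1.1), contrast=(0.95, 1.05), random_crop=False, mirror=False),
}

def build_final_transform(patch_size, p_spatial, params):
    """Rebuild one stage of the per-epoch pipeline from main()."""
    spatial = SpatialTransform(
        patch_size, np.array(patch_size) // 2,
        do_elastic_deform=True, alpha=params['alpha'], sigma=params['sigma'],
        do_rotation=True, angle_z=(-np.pi / 2., np.pi / 2.),
        do_scale=True, scale=params['scale'],
        border_mode_data='constant', border_cval_data=0, order_data=3,
        random_crop=params['random_crop'])
    others = [ResampleTransform(zoom_range=params['zoom']),
              ContrastAugmentationTransform(params['contrast'], True),
              BrightnessTransform(0.0, 0.1, True)]
    if params['mirror']:
        others.append(Mirror((2, 3)))
    return Compose([
        RndTransform(spatial, prob=p_spatial,
                     alternative_transform=CenterCropTransform(patch_size)),
        RndTransform(Compose(others), prob=1.0),
    ])

# inside the epoch loop:
# if epoch in AUG_SCHEDULE:
#     final_transform = build_final_transform(args.patch_size, args.p, AUG_SCHEDULE[epoch])
#     Center_Crop = CenterCropTransform(args.patch_size)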