示例#1
0
def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes,
                                  num_workers=5, num_cached_per_worker=2,
                                  do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                                  do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                  do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Build a multi-threaded 2D training generator with mirror, random
    spatial augmentation (applied with probability 0.67, otherwise a plain
    random crop) and one-hot segmentation encoding.
    """
    # Normalise `seeds` to one entry per worker.
    if seeds is None:
        seeds = [None for _ in range(num_workers)]
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers

    patch_size = (352, 352)
    data_gen_train = BatchGenerator_2D(patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
                                       PATCH_SIZE=patch_size)

    # Spatial augmentation centred on the patch middle; no random crop here
    # because RndTransform falls back to RandomCropTransform instead.
    spatial = SpatialTransform(patch_size, list(np.array(patch_size)//2),
                               do_elastic_transform, alpha, sigma,
                               do_rotation, a_x, a_y, a_z,
                               do_scale, scale_range, 'constant', 0, 3, 'constant',
                               0, 0, random_crop=False)

    transforms = [
        Mirror((2, 3)),
        RndTransform(spatial, prob=0.67,
                     alternative_transform=RandomCropTransform(patch_size)),
        ConvertSegToOnehotTransform(range(num_classes), seg_channel=0, output_key='seg_onehot'),
    ]

    mt_gen = MultiThreadedAugmenter(data_gen_train, Compose(transforms), num_workers,
                                    num_cached_per_worker, seeds)
    mt_gen.restart()
    return mt_gen
class MedImageDataSet(object):
    """Iterable dataset wrapper pairing a MedImageDataLoader with a
    MultiThreadedAugmenter transform pipeline.
    """

    def __init__(self,
                 base_dir,
                 mode="train",
                 batch_size=16,
                 num_batches=10000000,
                 seed=None,
                 num_processes=8,
                 num_cached_per_queue=8 * 4,
                 target_size=128,
                 file_pattern='*.png',
                 do_reshuffle=True,
                 keys=None):
        # Underlying loader that yields raw, un-augmented batches.
        self.data_loader = MedImageDataLoader(base_dir=base_dir,
                                              mode=mode,
                                              batch_size=batch_size,
                                              num_batches=num_batches,
                                              seed=seed,
                                              file_pattern=file_pattern,
                                              keys=keys)
        self.batch_size = batch_size
        self.number_of_slices = 1

        self.transforms = get_transforms(mode=mode, target_size=target_size)
        # NOTE(review): a single `seed` value is forwarded as `seeds` —
        # confirm the augmenter accepts a scalar here rather than a list.
        self.augmenter = MultiThreadedAugmenter(
            self.data_loader,
            self.transforms,
            num_processes=num_processes,
            num_cached_per_queue=num_cached_per_queue,
            seeds=seed,
            shuffle=do_reshuffle)
        self.augmenter.restart()

    def __len__(self):
        """Number of batches reported by the underlying loader."""
        return len(self.data_loader)

    def __iter__(self):
        """Restart the augmenter queues and expose it as the iterator."""
        self.augmenter.renew()
        return self.augmenter

    def __next__(self):
        """Fetch the next augmented batch."""
        return next(self.augmenter)
def get_no_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation
    :param dataloader_train:
    :param dataloader_val:
    :param patch_size:
    :param params:
    :param border_val_seg:
    :return:
    """
    sel_data = params.get("selected_data_channels")
    sel_seg = params.get("selected_seg_channels")

    # Training pipeline: optional channel selection, then label cleanup and
    # tensor conversion — no actual augmentation.
    tr_transforms = []
    if sel_data is not None:
        tr_transforms.append(DataChannelSelectionTransform(sel_data))
    if sel_seg is not None:
        tr_transforms.append(SegChannelSelectionTransform(sel_seg))
    tr_transforms += [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]

    n_threads = params.get('num_threads')
    n_cached = params.get("num_cached_per_thread")
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, Compose(tr_transforms),
                                                  n_threads, n_cached,
                                                  seeds=range(n_threads), pin_memory=True)
    batchgenerator_train.restart()

    # Validation pipeline: same steps, label cleanup first.
    val_transforms = [RemoveLabelTransform(-1, 0)]
    if sel_data is not None:
        val_transforms.append(DataChannelSelectionTransform(sel_data))
    if sel_seg is not None:
        val_transforms.append(SegChannelSelectionTransform(sel_seg))
    val_transforms += [
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]

    # Validation needs fewer workers since no transformations are applied.
    n_val_threads = max(n_threads // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, Compose(val_transforms),
                                                n_val_threads, n_cached,
                                                seeds=range(n_val_threads), pin_memory=True)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
示例#4
0
def create_data_gen_train(patient_data_train, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE, contrast_range=(0.75, 1.5),
                          gamma_range = (0.6, 2),
                                  num_workers=5, num_cached_per_worker=3,
                                  do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                                  do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                  do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Build a multi-threaded 3D training generator for 4-channel brain data:
    brain-mask generation, mirroring, spatial augmentation, intensity
    augmentations restricted to the brain mask, and one-hot seg encoding.
    """
    # Normalise `seeds` to one entry per worker.
    if seeds == 'range':
        seeds = range(num_workers)
    elif seeds is None:
        seeds = [None] * num_workers
    else:
        assert len(seeds) == num_workers

    data_gen_train = BatchGenerator3D_random_sampling(patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
                                                      patch_size=(160, 192, 160), convert_labels=True)

    spatial = SpatialTransform(INPUT_PATCH_SIZE, list(np.array(INPUT_PATCH_SIZE)//2.),
                               do_elastic_deform=do_elastic_transform, alpha=alpha, sigma=sigma,
                               do_rotation=do_rotation, angle_x=a_x, angle_y=a_y, angle_z=a_z,
                               do_scale=do_scale, scale=scale_range, border_mode_data='nearest',
                               border_cval_data=0, order_data=3, border_mode_seg='constant', border_cval_seg=0,
                               order_seg=0, random_crop=True)

    transforms = [
        DataChannelSelectionTransform([0, 1, 2, 3]),
        GenerateBrainMaskTransform(),
        MirrorTransform(),
        spatial,
        BrainMaskAwareStretchZeroOneTransform((-5, 5), True),
        ContrastAugmentationTransform(contrast_range, True),
        GammaTransform(gamma_range, False),
        BrainMaskAwareStretchZeroOneTransform(per_channel=True),
        BrightnessTransform(0.0, 0.1, True),
        SegChannelSelectionTransform([0]),
        ConvertSegToOnehotTransform(range(num_classes), 0, "seg_onehot"),
    ]

    gen_train = MultiThreadedAugmenter(data_gen_train, Compose(transforms), num_workers,
                                       num_cached_per_worker, seeds)
    gen_train.restart()
    return gen_train
示例#5
0
def main():
    """Entry point: parse CLI args and YAML config, build a VNet, wire up the
    KiTS19 data pipeline (batchgenerators augmenters), then run the
    training/validation loop, checkpointing the best kidney/tumor dice.
    """
    # --------- Parse arguments ---------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config_0.yaml',
                        help='Path to the configuration file.')
    parser.add_argument('--dice', action='store_true')
    # be aware of this argument!!!
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('-i', '--inference', default='', type=str, metavar='PATH',
                        help='run inference on data set and save results')

    # 1e-8 works well for lung masks but seems to prevent
    # rapid learning for nodule masks
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    args = parser.parse_args()
    # ---------- get the config file(config.yaml) --------------------------------------------
    # NOTE(review): --dice/--resume are also parsed above but the config file
    # values are what is actually used below.
    config = get_config(args.config)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = os.path.join('./work', (datestr() + ' ' + config['filename']))
    # nll=True selects the NLL-loss train/test routines; config['dice']
    # switches to the dice-loss variants instead.
    nll = True
    if config['dice']:
        nll = False

    weight_decay = config['weight_decay']
    num_threads_for_kits19 = config['num_of_threads']
    patch_size = (160, 160, 128)
    num_batch_per_epoch = config['num_batch_per_epoch']
    setproctitle.setproctitle(args.save)
    start_epoch = 1
    # -------- Record best kidney segmentation dice -------------------------------------------
    best_tk = 0.0
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    # Embed attention module
    model = vnet.VNet(elu=False, nll=nll,
                      attention=config['attention'], nclass=3)  # mark
    batch_size = config['ngpu'] * config['batchSz']
    # batch_size = args.ngpu*args.batchSz
    gpu_ids = range(config['ngpu'])
    # print(gpu_ids)
    model.apply(weights_init)
    # ------- Resume training from saved model -----------------------------------------------
    if config['resume']:
        if os.path.isfile(config['resume']):
            print("=> loading checkpoint '{}'".format(config['resume']))
            checkpoint = torch.load(config['resume'])
            # .tar files carry a dict with epoch/best_tk/state_dict
            if config['resume'].endswith('.tar'):
                start_epoch = checkpoint['epoch']
                best_tk = checkpoint['best_tk']
                checkpoint_model = checkpoint['model_state_dict']
                # strip DataParallel's 'module.' prefix so the bare model loads
                model.load_state_dict(
                    {k.replace('module.', ''): v for k, v in checkpoint_model.items()})
            # .pkl files hold the whole pickled model
            else:
                model.load_state_dict(checkpoint.state_dict())
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config['resume']))
            exit(-1)
    # ------- Which loss function to use ------------------------------------------------------
    if nll:
        training = train_bg
        validate = test_bg
    else:
        training = train_bg_dice
        validate = test_bg_dice
    # -----------------------------------------------------------------------------------------
    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    # -------- Set on GPU ---------------------------------------------------------------------
    if args.cuda:
        model = model.cuda()

    # wipe any previous run with the same name, then recreate the output dir
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save)
    # save the config file to the output folder
    shutil.copy(args.config, os.path.join(args.save, 'config.yaml'))

    # ------ Load Training and Validation set --------------------------------------------
    preprocessed_folders = "/home/data_share/npy_data/"
    patients = get_list_of_patients(
        preprocessed_data_folder=preprocessed_folders)
    # fixed deterministic split (147 train / 42 val) instead of
    # get_split_deterministic cross-validation
    train, val = patients[0:147], patients[147:189]

    # VALIDATION DATA CANNOT BE LOADED IN CASE DUE TO THE LARGE SHAPE...
    # PRINT VALIDATION CASES FOR LATER TEST USE!!
    print("Validation cases:\n", val)
    # set max shape for validation set
    shapes = [Kits2019DataLoader3D.load_patient(
        i)[0].shape[1:] for i in val]
    max_shape = np.max(shapes, 0)
    max_shape = np.max((max_shape, patch_size), 0)
    # data loading + augmentation
    dataloader_train = Kits2019DataLoader3D(
        train, batch_size, patch_size, num_threads_for_kits19)
    dataloader_validation = Kits2019DataLoader3D(
        val, batch_size * 2, patch_size, num_threads_for_kits19)
    tr_transforms = get_train_transform(patch_size, prob=config['prob'])
    # whether to use single/multiThreadedAugmenter ------------------------------------------
    if num_threads_for_kits19 > 1:
        tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms,
                                        num_processes=num_threads_for_kits19,
                                        num_cached_per_queue=3, seeds=None, pin_memory=True)
        val_gen = MultiThreadedAugmenter(dataloader_validation, None,
                                         num_processes=max(1, num_threads_for_kits19//2),
                                         num_cached_per_queue=1, seeds=None, pin_memory=False)

        tr_gen.restart()
        val_gen.restart()
    else:
        tr_gen = SingleThreadedAugmenter(dataloader_train, transform=tr_transforms)
        val_gen = SingleThreadedAugmenter(dataloader_validation, transform=None)
    # ------- Set learning rate scheduler ----------------------------------------------------
    lr_schdl = lr_scheduler.LR_Scheduler(mode=config['lr_policy'], base_lr=config['lr'],
                                         num_epochs=config['nEpochs'], iters_per_epoch=num_batch_per_epoch,
                                         lr_step=config['step_size'], warmup_epochs=config['warmup_epochs'])

    # ------ Choose Optimizer ----------------------------------------------------------------
    if config['opt'] == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=config['lr'],
                              momentum=0.99, weight_decay=weight_decay)
    elif config['opt'] == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    elif config['opt'] == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    lr_plateu = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, verbose=True, threshold=1e-3, patience=5)
    # ------- Apex Mixed Precision Acceleration ----------------------------------------------
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = nn.parallel.DataParallel(model, device_ids=gpu_ids)
    # ------- Save training data -------------------------------------------------------------
    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    trainF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    # fix: header previously ended with '\n ' which prepended a stray space
    # to the first data row of test.csv
    testF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n')
    # ------- Training Pipeline --------------------------------------------------------------
    for epoch in range(start_epoch, config['nEpochs'] + start_epoch):
        torch.cuda.empty_cache()
        training(args, epoch, model, tr_gen, optimizer, trainF, config, lr_schdl)
        torch.cuda.empty_cache()
        print('==>lr decay to:', optimizer.param_groups[0]['lr'])
        print('testing validation set...')
        composite_dice = validate(args, epoch, model, val_gen, optimizer, testF, config, lr_plateu)
        torch.cuda.empty_cache()
        # save the model when it improves on the best dice, and routinely
        # every model_save_iter epochs
        if composite_dice > best_tk or epoch % config['model_save_iter'] == 0:
            # fix: promote best_tk only on genuine improvement (previously a
            # routine save of a worse epoch overwrote the best score), and do
            # it before torch.save so the checkpoint records the up-to-date
            # best value rather than the stale one.
            if composite_dice > best_tk:
                best_tk = composite_dice
            model_name = 'vnet_step1_' + str(epoch) + '.tar'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_tk': best_tk
            }, os.path.join(args.save, model_name))
    # ----------------------------------------------------------------------------------------
    trainF.close()
    testF.close()
def get_no_augmentation(dataloader_train,
                        dataloader_val,
                        params=default_3D_augmentation_params,
                        deep_supervision_scales=None,
                        soft_ds=False,
                        classes=None,
                        pin_memory=True,
                        regions=None):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation
    """
    sel_data = params.get("selected_data_channels")
    sel_seg = params.get("selected_seg_channels")

    def _append_tail(transforms):
        # Steps shared by the training and validation pipelines: rename seg,
        # optional region conversion, optional deep-supervision downsampling,
        # then tensor conversion.
        transforms.append(RenameTransform('seg', 'target', True))
        if regions is not None:
            transforms.append(
                ConvertSegmentationToRegionsTransform(regions, 'target',
                                                      'target'))
        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                transforms.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales,
                                                 'target', 'target', classes))
            else:
                transforms.append(
                    DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                                 input_key='target',
                                                 output_key='target'))
        transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return transforms

    tr_transforms = []
    if sel_data is not None:
        tr_transforms.append(DataChannelSelectionTransform(sel_data))
    if sel_seg is not None:
        tr_transforms.append(SegChannelSelectionTransform(sel_seg))
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    n_train_workers = params.get('num_threads')
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        Compose(_append_tail(tr_transforms)),
        n_train_workers,
        params.get("num_cached_per_thread"),
        seeds=range(n_train_workers),
        pin_memory=pin_memory)
    batchgenerator_train.restart()

    val_transforms = [RemoveLabelTransform(-1, 0)]
    if sel_data is not None:
        val_transforms.append(DataChannelSelectionTransform(sel_data))
    if sel_seg is not None:
        val_transforms.append(SegChannelSelectionTransform(sel_seg))

    # validation uses half the workers (no augmentation work to do)
    n_val_workers = max(params.get('num_threads') // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        Compose(_append_tail(val_transforms)),
        n_val_workers,
        params.get("num_cached_per_thread"),
        seeds=range(n_val_workers),
        pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
示例#7
0
        num_cached_per_queue=3,
        seeds=None,
        pin_memory=False)
    # we need less processes for vlaidation because we dont apply transformations
    val_gen = MultiThreadedAugmenter(dataloader_validation,
                                     None,
                                     num_processes=max(
                                         1,
                                         num_threads_for_brats_example // 2),
                                     num_cached_per_queue=1,
                                     seeds=None,
                                     pin_memory=False)

    # lets start the MultiThreadedAugmenter. This is not necessary but allows them to start generating training
    # batches while other things run in the main thread
    tr_gen.restart()
    val_gen.restart()

    # now if this was a network training you would run epochs like this (remember tr_gen and val_gen generate
    # inifinite examples! Don't do "for batch in tr_gen:"!!!):
    num_batches_per_epoch = 10
    num_validation_batches_per_epoch = 3
    num_epochs = 5
    # let's run this to get a time on how long it takes
    time_per_epoch = []
    start = time()
    for epoch in range(num_epochs):
        start_epoch = time()
        for b in range(num_batches_per_epoch):
            batch = next(tr_gen)
            # do network training here with this batch
示例#8
0
    val=get_file_list(brats_preprocessed_folder,valid_ids_path)

    shapes = [brats_dataloader.load_patient(i)[0].shape[1:] for i in train]
    max_shape = np.max(shapes, 0)
    max_shape = list(np.max((max_shape, patch_size), 0))

    dataloader_train = brats_dataloader(train, batch_size, max_shape, num_threads,return_incomplete=True)
    dataloader_validation = brats_dataloader(val,batch_size, None,1,infinite=False,shuffle=False,return_incomplete=True)

    tr_transforms = get_train_transform(patch_size)

    tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms, num_processes=num_threads,
                                    num_cached_per_queue=3,
                                    seeds=None, pin_memory=False)

    tr_gen.restart()



    log=open(log_path,'w')
    log.write('epoch,loss,valid loss\n')
    min_loss=1000

    num_batches_per_epoch=int(math.ceil(len(train)/batch_size))
    num_validation_batches_per_epoch=int(math.ceil(len(val)/batch_size))

    current_lr=lr

    for epoch in range(max_epoch):
        raw_loss=0
        with trange(num_batches_per_epoch) as t:
示例#9
0
def get_arteries_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True,
                              regions=None,
                              use_nondetMultiThreadedAugmenter: bool = False):
    """Build train/val augmentation pipelines for the arteries task.

    Only scaling (per `params`) and mirroring are active; elastic deformation
    and rotation are hard-disabled in the SpatialTransform. Deep-supervision
    downsampling and the cascade augmentations are intentionally switched off
    (`deep_supervision_scales`, `soft_ds`, `classes` are currently unused).
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    sel_data = params.get("selected_data_channels")
    sel_seg = params.get("selected_seg_channels")

    # ---- training pipeline ----
    tr_transforms = []
    if sel_data is not None:
        tr_transforms.append(DataChannelSelectionTransform(sel_data))
    if sel_seg is not None:
        tr_transforms.append(SegChannelSelectionTransform(sel_seg))

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=False,
                         do_rotation=False,
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=False,
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    tr_transforms.append(RemoveLabelTransform(-1, 0))
    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        Compose(tr_transforms),
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)
    batchgenerator_train.restart()

    # ---- validation pipeline (no augmentation) ----
    val_transforms = [RemoveLabelTransform(-1, 0)]
    if sel_data is not None:
        val_transforms.append(DataChannelSelectionTransform(sel_data))
    if sel_seg is not None:
        val_transforms.append(SegChannelSelectionTransform(sel_seg))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        Compose(val_transforms),
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    batchgenerator_val.restart()

    return batchgenerator_train, batchgenerator_val
def create_data_gen_train(patient_data_train,
                          BATCH_SIZE,
                          num_classes,
                          patch_size,
                          num_workers=5,
                          num_cached_per_worker=2,
                          do_elastic_transform=False,
                          alpha=(0., 1300.),
                          sigma=(10., 13.),
                          do_rotation=False,
                          a_x=(0., 2 * np.pi),
                          a_y=(0., 2 * np.pi),
                          a_z=(0., 2 * np.pi),
                          do_scale=True,
                          scale_range=(0.75, 1.25),
                          seeds=None):
    """Build a multithreaded, augmented training-batch generator.

    Args:
        patient_data_train: dataset handed to ``BatchGenerator``.
        BATCH_SIZE: number of samples per batch.
        num_classes: number of segmentation labels (one-hot channels).
        patch_size: 3D patch shape ``(z, y, x)``; the in-plane part
            ``patch_size[1:]`` drives the 2D spatial transform.
        num_workers: number of augmentation worker processes.
        num_cached_per_worker: prefetch queue depth per worker.
        do_elastic_transform / alpha / sigma: elastic deformation config.
        do_rotation / a_x / a_y / a_z: rotation config (radians).
        do_scale / scale_range: scaling config.
        seeds: ``None`` (unseeded), ``'range'`` (0..num_workers-1), or an
            explicit per-worker seed sequence.

    Returns:
        A started ``MultiThreadedAugmenter`` yielding augmented batches.
    """
    # Resolve per-worker RNG seeds.
    if seeds is None:
        seeds = [None] * num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers

    # BUGFIX: the base generator previously hard-coded
    # PATCH_SIZE=(10, 352, 352) and ignored the `patch_size` argument, so a
    # caller passing a different patch size got mismatched crops vs. the
    # spatial transforms below. Use the argument instead.
    data_gen_train = BatchGenerator(patient_data_train,
                                    BATCH_SIZE,
                                    num_batches=None,
                                    seed=False,
                                    PATCH_SIZE=tuple(patch_size))

    # train transforms
    tr_transforms = []
    tr_transforms.append(MotionAugmentationTransform(0.1, 0, 20))
    # Mirror along the two in-plane axes of (b, c, z, y, x) data.
    tr_transforms.append(MirrorTransform((3, 4)))
    # Fold z into the channel axis so the spatial transform runs in 2D.
    tr_transforms.append(Convert3DTo2DTransform())
    tr_transforms.append(
        RndTransform(SpatialTransform(patch_size[1:],
                                      # NOTE(review): 112 is a hard-coded
                                      # patch-center distance from border —
                                      # confirm it suits other patch sizes.
                                      112,
                                      do_elastic_transform,
                                      alpha,
                                      sigma,
                                      do_rotation,
                                      a_x,
                                      a_y,
                                      a_z,
                                      do_scale,
                                      scale_range,
                                      'constant',
                                      0,
                                      3,
                                      'constant',
                                      0,
                                      0,
                                      random_crop=False),
                     prob=0.67,
                     # When the spatial transform is skipped, still crop to
                     # the target in-plane size.
                     alternative_transform=RandomCropTransform(
                         patch_size[1:])))
    tr_transforms.append(Convert2DTo3DTransform(patch_size))
    # Gamma augmentation on the raw and on the inverted intensities.
    tr_transforms.append(
        RndTransform(GammaTransform((0.85, 1.3), False), prob=0.5))
    tr_transforms.append(
        RndTransform(GammaTransform((0.85, 1.3), True), prob=0.5))
    # Clip intensity outliers, then normalize to zero mean / unit variance.
    tr_transforms.append(CutOffOutliersTransform(0.3, 99.7, True))
    tr_transforms.append(ZeroMeanUnitVarianceTransform(True))
    tr_transforms.append(
        ConvertSegToOnehotTransform(range(num_classes), 0, 'seg_onehot'))

    tr_composed = Compose(tr_transforms)
    tr_mt_gen = MultiThreadedAugmenter(data_gen_train, tr_composed,
                                       num_workers, num_cached_per_worker,
                                       seeds)
    tr_mt_gen.restart()
    return tr_mt_gen
示例#11
0
def get_default_augmentation(dataloader_train, dataloader_val=None, params=None,
                             patch_size=None, border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Create multithreaded train (and optionally validation) augmenters.

    Args:
        dataloader_train: base batch loader for training data.
        dataloader_val: base batch loader for validation data; if ``None``,
            only the training augmenter is built.
        params: augmentation parameter dict (must not be ``None``).
        patch_size: target spatial patch size for ``SpatialTransform``.
        border_val_seg: constant fill value for seg borders after warping.
        pin_memory: forwarded to ``MultiThreadedAugmenter``.
        seeds_train / seeds_val: per-worker RNG seeds.
        regions: if given, labels are merged into region channels.

    Returns:
        ``(batchgenerator_train, batchgenerator_val)``;
        ``batchgenerator_val`` is ``None`` when ``dataloader_val`` is ``None``.
    """
    # BUGFIX: check params for None BEFORE dereferencing it — the original
    # called params.get('mirror') first, so a None params raised an opaque
    # AttributeError instead of this assertion message.
    assert params is not None, "augmentation params expect to be not None"
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    # Map the ignore label (-1) back to background before training.
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)
    batchgenerator_train.restart()

    if dataloader_val is None:
        return batchgenerator_train, None

    # Validation pipeline: channel selection + relabel only, no spatial
    # or intensity augmentation.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training worker count (at least one).
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    batchgenerator_val.restart()

    return batchgenerator_train, batchgenerator_val
示例#12
0
def main():
    """Adversarial domain-adaptation training entry point.

    Alternates between (a) training a 3D segmentation network on labelled
    source data plus an adversarial term that pushes its target-domain
    entropy maps to fool a discriminator, and (b) training the
    discriminator to separate source vs. target entropy maps.

    NOTE(review): relies on module-level dicts ``train_params``,
    ``file_paths`` and ``model_params`` and helpers (``print_log``,
    ``load_train_dataset``, ``prob_2_entropy``, ``validation``,
    ``save_checkpoint``, ...) defined elsewhere in this file — confirm
    their keys/signatures against this function.
    """
    # init log
    now = time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time()))
    log_path = train_params['log_path']
    if not os.path.isdir(log_path):
        os.makedirs(log_path)
    # NOTE(review): bare open(); the handle leaks if an exception occurs
    # before log.close() at the end — a with-block would be safer.
    log = open(os.path.join(log_path, 'log_{}.txt'.format(now)), 'w')
    print_log('save path : {}'.format(log_path), log)

    # prepare dataset
    # Source domain: labelled training list; target domain: unlabelled list.
    dataset = load_train_dataset(phase='train',
                                 data_list_path=file_paths['train_list'])
    da_dataset = load_train_dataset_d(phase='train',
                                      data_list_path=file_paths['test_list'])
    # augmentation
    # The same transform pipeline is shared by both domain generators.
    aug_transforms = get_train_transform(model_params['patch_size'])
    # source domain
    src_data_gen = ISeg2019DataLoader3D(dataset,
                                        model_params['batch_size'],
                                        model_params['patch_size'],
                                        nb_modalities=2,
                                        num_threads_in_multithreaded=4)
    src_aug_gen = MultiThreadedAugmenter(src_data_gen,
                                         aug_transforms,
                                         num_processes=4,
                                         num_cached_per_queue=4)
    src_aug_gen.restart()
    # target domain
    tgt_data_gen = ISeg2019DataLoader3D_Unlabel(da_dataset,
                                                model_params['batch_size'],
                                                model_params['patch_size'],
                                                nb_modalities=2,
                                                num_threads_in_multithreaded=4)
    tgt_aug_gen = MultiThreadedAugmenter(tgt_data_gen,
                                         aug_transforms,
                                         num_processes=4,
                                         num_cached_per_queue=4)
    tgt_aug_gen.restart()

    # define network
    # net: segmentation network (generator); net_d: domain discriminator.
    net = DenseNet_3D(ver=model_params['model_ver'])
    net_d = FCDiscriminator3D(num_classes=model_params['nb_classes'], ndf=32)

    # define loss
    seg_loss = torch.nn.CrossEntropyLoss()
    bce_loss = torch.nn.BCEWithLogitsLoss()

    # define optimizer
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=train_params['lr_rate'],
                                 weight_decay=train_params['weight_decay'],
                                 betas=(train_params['momentum'], 0.999))
    optimizer.zero_grad()
    optimizer_d = torch.optim.Adam(net_d.parameters(),
                                   lr=train_params['lr_rate_d'],
                                   weight_decay=train_params['weight_decay'],
                                   betas=(0.9, 0.99))
    optimizer_d.zero_grad()

    start_step = 0
    best_dice = 0.
    # Optionally resume the generator weights (discriminator is not resumed).
    if train_params['resume_path'] is not None:
        print_log("=======> loading checkpoint '{}'".format(
            train_params['resume_path']),
                  log=log)
        checkpoint = torch.load(train_params['resume_path'])
        net.load_state_dict(checkpoint['model_state_dict'])
        print_log("=======> loaded checkpoint '{}'".format(
            train_params['resume_path']),
                  log=log)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                train_params['lr_step_size'],
                                                train_params['lr_gamma'])
    scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_d,
                                                  train_params['lr_step_size'],
                                                  train_params['lr_gamma'])

    # start training
    net.cuda()
    net_d.cuda()
    seg_loss.cuda()
    bce_loss.cuda()
    net.train()
    net_d.train()
    # Domain labels fed to the discriminator's BCE loss.
    source_label = 1
    target_label = 0
    for step in range(start_step, train_params['nb_iters']):
        # Running (un-normalized) loss accumulators for logging.
        loss_seg_value = 0.
        loss_adv_target_value = 0.
        loss_D_src_value = 0.
        loss_D_tgt_value = 0.

        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # Gradient accumulation: each sub-iteration backprops loss divided
        # by nb_accu_iters; optimizer.step() runs once per outer step.
        for sub_iter in range(train_params['nb_accu_iters']):
            # train G
            # Freeze the discriminator so only the generator gets gradients
            # from the adversarial term.
            for param in net_d.parameters():
                param.requires_grad = False

            # train with source
            src_batch = next(src_aug_gen)
            src_input_img = torch.from_numpy(src_batch['data']).cuda()
            # Drop the singleton channel axis and cast labels for CE loss.
            src_input_label = torch.from_numpy(
                np.squeeze(src_batch['seg'], axis=1).astype(np.int64)).cuda()

            src_seg_out = net(src_input_img)
            loss = seg_loss(src_seg_out, src_input_label)
            loss_seg = loss / train_params['nb_accu_iters']
            loss_seg.backward()

            loss_seg_value += loss.data.cpu().numpy()

            # train with target
            tgt_batch = next(tgt_aug_gen)
            tgt_input_img = torch.from_numpy(tgt_batch['data']).cuda()

            tgt_seg_out = net(tgt_input_img)
            tgt_d_out = net_d(prob_2_entropy(F.softmax(tgt_seg_out, dim=1)))
            # Adversarial term: label target outputs as "source" so the
            # generator learns to fool the (frozen) discriminator.
            loss = bce_loss(
                tgt_d_out,
                Variable(
                    torch.FloatTensor(
                        tgt_d_out.data.size()).fill_(source_label)).cuda())
            loss_adv_tgt = train_params[
                'lambda_adv_target'] * loss / train_params['nb_accu_iters']
            loss_adv_tgt.backward()

            loss_adv_target_value += loss.data.cpu().numpy()

            # train D
            # Unfreeze the discriminator for its own update.
            for param in net_d.parameters():
                param.requires_grad = True

            # train with source
            # detach() blocks gradients from flowing back into the generator.
            src_seg_out = src_seg_out.detach()
            src_d_out = net_d(prob_2_entropy(F.softmax(src_seg_out, dim=1)))
            loss = bce_loss(
                src_d_out,
                Variable(
                    torch.FloatTensor(
                        src_d_out.data.size()).fill_(source_label)).cuda())
            loss_d_src = loss / train_params['nb_accu_iters']
            loss_d_src.backward()

            loss_D_src_value += loss.data.cpu().numpy()

            # train with target
            tgt_seg_out = tgt_seg_out.detach()
            tgt_d_out = net_d(prob_2_entropy(F.softmax(tgt_seg_out, dim=1)))
            loss = bce_loss(
                tgt_d_out,
                Variable(
                    torch.FloatTensor(
                        tgt_d_out.data.size()).fill_(target_label)).cuda())
            loss_d_tgt = loss / train_params['nb_accu_iters']
            loss_d_tgt.backward()

            loss_D_tgt_value += loss.data.cpu().numpy()

        # Apply the accumulated gradients, then step both LR schedules.
        optimizer.step()
        scheduler.step()
        optimizer_d.step()
        scheduler_d.step()

        log_str = 'step {}: lr:{:.8f}, lr_d:{:.8f}, loss_seg:{:.6f}, loss_adv:{:.6f}, loss_D_src:{:.6f}, loss_D_tgt:{:.6f}'\
            .format(step, scheduler.get_lr()[0], scheduler_d.get_lr()[0], loss_seg_value, loss_adv_target_value, loss_D_src_value, loss_D_tgt_value)
        print_log(log_str, log)

        # val and save per N iterations
        if (step + 1) % train_params['snapshot_step_size'] == 0:
            net.eval()
            # NOTE(review): validation runs on the source training dataset,
            # not a held-out split — confirm this is intended.
            val_avg_dice1, val_avg_dice2, val_avg_dice3, val_avg_dice = validation(
                net, dataset)
            val_log_str = 'val step: val_avg_dice:{}, val_avg_dice1:{}, val_avg_dice2:{}, val_avg_dice3:{}' \
                .format(val_avg_dice, val_avg_dice1, val_avg_dice2, val_avg_dice3)
            print_log(val_log_str, log)

            is_best = False
            if val_avg_dice > best_dice:
                best_dice = val_avg_dice
                is_best = True

            # Snapshot generator and discriminator separately; 'best' copies
            # are keyed on the generator's validation dice.
            save_checkpoint({
                'model_state_dict': net.state_dict(),
            }, is_best, train_params['model_snapshot_path'],
                            'checkpoint-{}.pth'.format(step + 1),
                            'model_best.pth')
            save_checkpoint({
                'model_d': net_d.state_dict(),
            }, is_best, train_params['model_snapshot_path'],
                            'checkpoint-d-{}.pth'.format(step + 1),
                            'model_d_best.pth')

            net.train()
    log.close()