Example #1
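
# NOTE: the original snippet omits its import section; the block below is an
# assumption added so the example is closer to self-contained. Project-specific
# helpers (parse_args, get_network_settings, get_config, get_trainer,
# get_datasets, get_train_augmentations, loadExperimentSettings,
# saveExperimentSettings, synthesize_output_dir, create_dataset,
# create_grid_heat_map, config_detector, BatchHandler, Visualizer, datasets,
# device) are assumed to be importable from the surrounding repository, and
# `transforms` is assumed to be torchvision.transforms.
import os

import numpy as np
import torch
from torch.utils.data import RandomSampler
from torchvision import transforms
from tqdm import tqdm
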
def train(args):
    """Train the slice/voxel detection model for one cross-validation fold.

    :param args: parsed command-line arguments (see parse_args)
    :return: None
    """

    torch.manual_seed(5431232439)
    torch.cuda.manual_seed(5431232439)
    torch.backends.cudnn.enabled = True
    np.random.seed(6572345)
    # get fold from segmentation experiment settings. We train detection model per cross-validation fold
    seg_exper_settings = loadExperimentSettings(
        os.path.join(args.src_path_data, 'settings.yaml'))
    args.fold = seg_exper_settings.fold
    print("WARNING - processing fold {}".format(args.fold))
    config_detector.get_architecture(args.network)
    # set number of input channels for initialization of model
    if args.input_channels != "allchannels":
        if args.input_channels == "mronly":
            config_detector.architecture['n_channels_input'] = 1
        else:
            config_detector.architecture['n_channels_input'] = 2
        print("WARNING - Using {} channels as input".format(
            config_detector.architecture['n_channels_input']))
    if args.fn_penalty_weight is not None:
        config_detector.architecture[
            "fn_penalty_weight"] = args.fn_penalty_weight
        print("WARNING - Using args fn_penalty_weight {:.3f}".format(
            config_detector.architecture["fn_penalty_weight"]))
    if args.output_directory is None:
        # synthesize an output directory name from the experiment settings
        output_dir = synthesize_output_dir(args, config_detector.architecture)
        args.output_directory = os.path.join(args.src_path_data, "dt_logs",
                                             output_dir)
    else:
        args.output_directory = os.path.expanduser(args.output_directory)
    os.makedirs(args.output_directory, exist_ok=False)  # fail if the directory already exists
    saveExperimentSettings(
        args, os.path.join(args.output_directory, 'settings.yaml'))
    saveExperimentSettings(
        config_detector.architecture,
        os.path.join(args.output_directory, 'architecture.yaml'))
    print(args)

    # get dataset
    dataset = create_dataset(args.fold,
                             args.src_path_data,
                             mc_dropout=args.mc_dropout,
                             num_of_input_chnls=3,
                             limited_load=args.limited_load,
                             dt_config_id=args.dt_config_id,
                             cardiac_phases=('ES', 'ED'))
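    # NOTE (an inference from the channel indexing further below): the dataset
    # loads three channels per slice (MR image, uncertainty map, seg mask);
    # BatchHandler selects among them via args.input_channels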
    # and finally we initialize the visdom visualizers
    seg_model = os.path.basename(args.src_path_data)
    dt_log_dir = os.path.basename(args.output_directory)
    env = 'Detection{}-{}-{}_{}'.format(args.dataset,
                                        seg_model.replace("_", '-'),
                                        args.input_channels, dt_log_dir)
    vis = Visualizer(env, args.port,
                     'Learning curves of fold {}'.format(args.fold),
                     ['training', 'validation'])
    vis_metrics = Visualizer(
        env, args.port,
        'Grid detection prec/rec metrics fold {}'.format(args.fold),
        ['precision', 'recall', 'pr_auc'])
    vis_detection_rate = Visualizer(
        env, args.port, 'Slice/voxel detection rate fold {}'.format(args.fold),
        ['detection_rate', 'slice_tp_rate', 'slice_tn_rate'])
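    # do_balance presumably makes BatchHandler oversample slices that contain
    # detection targets, so batches are not dominated by negatives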
    do_balance_batch = True
    trainer = get_trainer(args, config_detector.architecture, model_file=None)
    try:
        for _ in tqdm(range(args.max_iters),
                      desc="Train {}".format(args.network)):
            # store model
            if not trainer._train_iter % args.store_model_every:
                trainer.save(args.output_directory)

            # store learning curves
            if not trainer._train_iter % args.store_curves_every:
                trainer.save_losses(args.output_directory)

            # visualize example from validation set
            if not trainer._train_iter % args.update_visualizer_every and trainer._train_iter > 0:
                vis(trainer.current_training_loss,
                    trainer.current_validation_loss)  # plot learning curve

            train_batch = BatchHandler(
                data_set=dataset,
                is_train=True,
                verbose=False,
                keep_bounding_boxes=False,
                input_channels=args.input_channels,
                num_of_max_pool_layers=config_detector.architecture['num_of_max_pool'],
                app_config=config_detector)
            x_input, ref_labels = train_batch(batch_size=args.batch_size,
                                              do_balance=do_balance_batch)
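            # ref_labels appears to be keyed by grid spacing; train on the
            # labels of the coarsest grid the detector predicts on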
            y_labels = ref_labels[config_detector.max_grid_spacing]
            trainer.train(x_input,
                          y_labels,
                          y_labels_seg=train_batch.batch_labels_per_voxel)
            if not trainer._train_iter % args.update_visualizer_every and trainer._train_iter > 0:
                val_batch = BatchHandler(
                    data_set=dataset,
                    is_train=False,
                    verbose=False,
                    keep_bounding_boxes=False,
                    input_channels=args.input_channels,
                    num_of_max_pool_layers=config_detector.architecture['num_of_max_pool'],
                    app_config=config_detector)
                val_set_size = dataset.get_size(is_train=False)
                val_batch.last_test_list_idx = np.random.randint(
                    0, val_set_size - 101, size=1)
                trainer.evaluate(val_batch, keep_batch=True)
                vis_metrics(trainer.validation_metrics['prec'],
                            trainer.validation_metrics['rec'],
                            trainer.validation_metrics['pr_auc'])
                dt_rate = (trainer.validation_metrics['detected_voxel_count'] /
                           trainer.validation_metrics['total_voxel_count'])
                # slice-level sensitivity (TPR) and specificity (TNR)
                dt_slice_tp = (trainer.validation_metrics['tp_slice'] /
                               (trainer.validation_metrics['tp_slice'] +
                                trainer.validation_metrics['fn_slice']))
                dt_slice_tn = (trainer.validation_metrics['tn_slice'] /
                               (trainer.validation_metrics['tn_slice'] +
                                trainer.validation_metrics['fp_slice']))
                vis_detection_rate(dt_rate, dt_slice_tp, dt_slice_tn)
                # visualize a fixed example from the kept validation batch
                idx = 12
                patid = val_batch.batch_patient_slice_id[idx][0]
                val_img = val_batch.keep_batch_images[idx][0][0]
                w, h = val_img.shape
                vis.image((val_img**.5), 'image {}'.format(patid), 11)
                vis.image((val_batch.keep_batch_images[idx][0][1] / 0.9),
                          'uncertainty {}'.format(patid), 12)
                vis.image(val_batch.keep_batch_label_slices[idx] / 1.001,
                          'reference', 13)
                vis.image((val_batch.keep_batch_images[idx][0][2] / 1.001),
                          'seg mask', 16)
                # channel 1: probability map of the positive (detection) class
                p = np.squeeze(val_batch.batch_pred_probs[idx])[1]
                heat_map, grid_map, target_lbl_grid = create_grid_heat_map(
                    p,
                    config_detector.max_grid_spacing,
                    w,
                    h,
                    prob_threshold=0.5)
                vis.image((heat_map**.5), 'grid predictions', 14)
                if args.network == "rsnup":
                    p_mask = np.argmax(np.squeeze(trainer.val_segs[idx]),
                                       axis=0)
                    vis.image(p_mask / 1.001, 'predictions', 15)

                del val_batch

    except KeyboardInterrupt:
        print('interrupted')

    finally:
        trainer.save(args.output_directory)
        trainer.save_losses(args.output_directory)


def main():
    # first we obtain the user arguments, set random seeds, make directories, and store the experiment settings.
    args = parse_args()
    # get_network_settings overwrites some args (e.g. resample is always True for ACDC)
    args = get_network_settings(args)
    args.patch_size = tuple(args.patch_size)
    torch.manual_seed(5431232439)
    torch.cuda.manual_seed(5431232439)
    rs = np.random.RandomState(78346)
    os.makedirs(args.output_directory, exist_ok=True)
    saveExperimentSettings(
        args, os.path.join(args.output_directory, 'settings.yaml'))
    print(args)
    dta_settings = get_config(args.dataset)

    # we create a trainer
    n_classes = len(dta_settings.tissue_structure_labels)
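    # the MR image is the network's single input channel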
    n_channels_input = 1

    trainer, pad = get_trainer(args, n_classes, n_channels_input)

    # we initialize datasets with augmentations.
    training_augmentations = get_train_augmentations(args, rs, pad)
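    # validation pipeline: pad to fit the network, random-crop to the patch
    # size, blur slightly, convert to tensors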
    validation_augmentations = [
        datasets.augmentations.PadInput(pad, args.patch_size),
        datasets.augmentations.RandomCrop(args.patch_size,
                                          input_padding=pad,
                                          rs=rs),
        datasets.augmentations.BlurImage(sigma=0.9),
        datasets.augmentations.ToTensor()
    ]

    training_set, validation_set = get_datasets(
        args, dta_settings, transforms.Compose(training_augmentations),
        transforms.Compose(validation_augmentations))

    # now we create dataloaders
    tra_sampler = RandomSampler(training_set,
                                replacement=True,
                                num_samples=args.batch_size * args.max_iters)
    val_sampler = RandomSampler(validation_set,
                                replacement=True,
                                num_samples=args.batch_size * args.max_iters)
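    # with replacement=True and num_samples = batch_size * max_iters, each
    # sampler yields exactly max_iters batches, so zipping the two loaders
    # below runs the training loop for max_iters iterations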

    data_loader_training = torch.utils.data.DataLoader(
        training_set,
        batch_size=args.batch_size,
        sampler=tra_sampler,
        num_workers=args.number_of_workers,
        collate_fn=None)  # None falls back to torch's default_collate

    data_loader_validation = torch.utils.data.DataLoader(
        validation_set,
        batch_size=args.batch_size,
        sampler=val_sampler,
        num_workers=args.number_of_workers,
        collate_fn=None)

    # and finally we initialize the visdom visualizer
    env_suffix = "f" + str(args.fold) + args.output_directory.split("_")[-1]
    vis = Visualizer(
        'Segmentation{}-{}_{}'.format(args.dataset, args.network, env_suffix),
        args.port, 'Learning curves of fold {}'.format(args.fold),
        ['training', 'validation', 'aleatoric'])
    try:
        for it, (training_batch, validation_batch) in tqdm(
                enumerate(zip(data_loader_training, data_loader_validation)),
                desc='Training',
                total=args.max_iters):

            # store model
            if not trainer._train_iter % args.store_model_every:
                trainer.save(args.output_directory)

            # store learning curves
            if not trainer._train_iter % args.store_curves_every:
                trainer.save_losses(args.output_directory)

            # visualize example from validation set
            if not trainer._train_iter % args.update_visualizer_every and trainer._train_iter > 20:
                image = validation_batch['image'][0][None]
                val_output = trainer.predict(image)
                prediction = val_output['predictions']
                reference = validation_batch['reference'][0]
                val_patient_id = validation_batch['patient_id'][0]

                image = image.detach().numpy()
                prediction = prediction.detach().numpy().astype(float)
                reference = reference.detach().numpy().astype(float)
                if pad > 0:
                    # Note: image has shape [batch, 1, x, y]; we get rid of the extra padding in the last two dimensions
                    vis.image((image[0, 0, pad:-pad, pad:-pad]**.5),
                              'padded image {}'.format(val_patient_id), 12)
                else:
                    vis.image((image[0]**.5),
                              'image {}'.format(val_patient_id), 11)
                vis.image(reference / 3, 'reference', 13)
                vis.image(prediction / 3, 'prediction', 14)  # uses log_softmax values
                if 'aleatoric' in val_output.keys():
                    vis.image(val_output['aleatoric'] / 0.9, 'aleatoric', 15)
                # vis.image((prediction >= 0.5).astype(float), 'binary prediction', 15)
                # visualize learning curve
                vis(trainer.current_training_loss,
                    trainer.current_validation_loss,
                    trainer.current_aleatoric_loss)  # plot learning curve

            # train on training mini-batch
            trainer.train(training_batch['image'].to(device),
                          training_batch['reference'].to(device),
                          ignore_label=training_batch.get('ignore_label'))
            # evaluate on validation mini-batch
            trainer.evaluate(validation_batch['image'].to(device),
                             validation_batch['reference'].to(device),
                             ignore_label=validation_batch.get('ignore_label'))

    except KeyboardInterrupt:
        print('interrupted')

    finally:
        trainer.save(args.output_directory)
        trainer.save_losses(args.output_directory)
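

# Assumed entry point (not part of the original snippet) so the example runs as a script.
if __name__ == '__main__':
    main()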