Example #1

import argparse
import os
from pathlib import Path

import h5py
import pandas as pd
from tqdm import tqdm
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Project-specific helpers are assumed to be importable from the surrounding
# lucanode package: loader, Unet, UnetSansBN, DEFAULT_UNET_SIZE,
# dice_coef_loss, HistoryLog, nodule_candidates, predict, evaluate_candidates.

def train_nodule_segmentation_no_augmentation_normalization_dice(
    dataset_file,
    output_weights_file,
    batch_size=5,
    num_epochs=10,
    last_epoch=0,
    initial_weights=None,
):
    """Train the network from scratch or from a preexisting set of weights on the dataset"""

    # Loaders
    dataset = h5py.File(dataset_file, "r")
    df = loader.dataset_metadata_as_dataframe(dataset,
                                              key='nodule_masks_spherical')
    df_training = df[df.subset.isin([0, 1, 2, 3, 4, 5, 6, 7]) & df.has_mask]
    dataset.close()
    training_loader = loader.NoduleSegmentationSequence(dataset_file,
                                                        batch_size,
                                                        dataframe=df_training,
                                                        epoch_frac=1.0,
                                                        epoch_shuffle=False)
    df_validation = df[df.subset.isin([8]) & df.has_mask]
    validation_loader = loader.NoduleSegmentationSequence(
        dataset_file,
        batch_size,
        dataframe=df_validation,
        epoch_frac=1.0,
        epoch_shuffle=False)

    # Callbacks
    model_checkpoint = ModelCheckpoint(output_weights_file,
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10)
    history_log = HistoryLog(output_weights_file + ".history")

    # Setup network
    network_size = [*DEFAULT_UNET_SIZE, 1, dice_coef_loss]
    model = Unet(*network_size)

    if initial_weights:
        model.load_weights(initial_weights)

    # Train
    model.fit_generator(
        generator=training_loader,
        epochs=num_epochs,
        initial_epoch=last_epoch,
        verbose=1,
        validation_data=validation_loader,
        use_multiprocessing=True,
        workers=4,
        max_queue_size=20,
        shuffle=True,
        callbacks=[model_checkpoint, early_stopping, history_log])
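

# Example usage of the training function above (file names are hypothetical;
# this is a hedged sketch, not part of the original example):
#
#   train_nodule_segmentation_no_augmentation_normalization_dice(
#       dataset_file="lidc_equalized.h5",
#       output_weights_file="unet_nodule_dice.hdf5",
#       batch_size=5,
#       num_epochs=10)
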
def main():
    parser = argparse.ArgumentParser(
        description='Evaluate CT nodule scan segmentation for a subset')
    parser.add_argument('dataset',
                        type=str,
                        help="Path to the hdf5 with the equalized spaced data")
    parser.add_argument('csv_annotations',
                        type=str,
                        help="CSV with real annotations")
    parser.add_argument('model_weights',
                        type=str,
                        help="path where the model weights are stored")
    parser.add_argument('output',
                        type=str,
                        help="path where to store the detailed output")
    parser.add_argument(
        'subsets',
        type=int,
        nargs='+',
        help="subset for which you want evaluate the segmentation")
    parser.add_argument('--batch-size',
                        dest='batch_size',
                        type=int,
                        default=5,
                        action="store",
                        help="evaluation batch size")
    parser.add_argument('--no-normalization',
                        dest='batch_normalization',
                        action='store_false')
    parser.add_argument('--loss-binary-crossentropy',
                        dest='loss_binary_crossentropy',
                        action='store_true')
    parser.add_argument('--laplacian',
                        dest='use_laplacian',
                        action='store_true')
    parser.add_argument('--mask-type',
                        dest='mask_type',
                        type=str,
                        default="nodule_masks_spherical",
                        action='store')
    parser.add_argument('--ch3', dest='ch3', action='store_true')
    args = parser.parse_args()

    print("""
    
############################################
######### lucanode scan evaluation #########
############################################
""")
    # Create directory for exports if it doesn't exist
    os.makedirs(args.output, exist_ok=True)

    if args.ch3:
        num_channels = 3
    else:
        num_channels = 1
    if args.loss_binary_crossentropy:
        network_shape = [
            *DEFAULT_UNET_SIZE, num_channels, 'binary_crossentropy'
        ]
    else:
        network_shape = [*DEFAULT_UNET_SIZE, num_channels]
    if args.batch_normalization:
        model = Unet(*network_shape)
    else:
        model = UnetSansBN(*network_shape)
    model.load_weights(args.model_weights, by_name=True)

    for subset in tqdm(args.subsets, desc="eval subsets"):
        ann_df = pd.read_csv(args.csv_annotations)
        candidates = []

        with h5py.File(args.dataset, "r") as dataset:
            df = loader.dataset_metadata_as_dataframe(dataset, key='ct_scans')
        df = df[df.subset == subset]
        scan_ids = set(df.seriesuid)
        metrics = []
        for seriesuid in tqdm(scan_ids, desc="eval scans"):
            # Prepare data loader
            df_view = df[df.seriesuid == seriesuid]
            if args.ch3:
                loader_class = loader.NoduleSegmentation3CHSequence
            else:
                loader_class = loader.NoduleSegmentationSequence
            dataset_gen = loader_class(
                args.dataset,
                batch_size=args.batch_size,
                dataframe=df_view,
                epoch_frac=1.0,
                epoch_shuffle=False,
                laplacian=args.use_laplacian,
            )

            # Predict mask
            scan_dice, scan_mask = predict(seriesuid, model, dataset_gen,
                                           args.dataset, args.mask_type)

            # Retrieve candidates
            with h5py.File(args.dataset, "r") as dataset:
                pred_df = nodule_candidates.retrieve_candidates_dataset(
                    seriesuid, dict(dataset["ct_scans"][seriesuid].attrs),
                    scan_mask)
            candidates.append(pred_df)

            # Evaluate candidates
            pred_df = pred_df.reset_index()
            ann_df_view = ann_df[ann_df.seriesuid == seriesuid].reset_index()
            sensitivity, precision, TP, FP, P = evaluate_candidates(
                pred_df, ann_df_view)

            # Save mask
            dataset_filename = Path(args.output) / ("masks_subset%d.h5" %
                                                    (subset, ))
            mode = 'r+' if dataset_filename.exists() else 'w'
            with h5py.File(dataset_filename, mode) as export_ds:
                if seriesuid in export_ds.keys():
                    del export_ds[seriesuid]
                export_ds.create_dataset(seriesuid,
                                         compression="gzip",
                                         data=(scan_mask > 0.5))

            # Save metrics
            scan_metrics = {
                "seriesuid": seriesuid,
                "dice": scan_dice,
                "sensitivity": sensitivity,
                "precision": precision,
                "FP": FP,
                "TP": TP,
                "P": P
            }
            metrics.append(scan_metrics)

        # Export metrics
        columns = [
            "seriesuid", "dice", "sensitivity", "precision", "FP", "TP", "P"
        ]
        metrics_df = pd.DataFrame(metrics, columns=columns)
        metrics_df.to_csv(
            Path(args.output) / ("evaluation_subset%d.csv" % (subset, )))
        pd.concat(candidates, ignore_index=True).to_csv(
            Path(args.output) / ("candidates_subset%d.csv" % (subset, )))

        metrics = "Weights: %s\nMetrics mean for subset%d:\n%s\n\nMetrics variance for subset%d:\n%s" % (
            Path(args.model_weights).name, subset, repr(
                metrics_df.mean()), subset, repr(metrics_df.var()))
        with open(
                Path(args.output) / ("metrics_subset%d.txt" % (subset, )),
                "w") as fd:
            fd.write(metrics)
        print(metrics)
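
# A hypothetical command-line invocation of this evaluation script (the script
# name, paths and subset number are placeholders, not taken from the original):
#   python evaluate_nodule_segmentation.py lidc_equalized.h5 annotations.csv \
#       unet_nodule_dice.hdf5 results/ 8 --batch-size 5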
Example #3

import argparse

import h5py
import pandas as pd
from tqdm import tqdm

# Project-specific helpers assumed from the lucanode package:
# loader, Unet, DEFAULT_UNET_SIZE.


def main():
    # Parser reconstructed from how `args` is used below (dataset,
    # model_weights, subset, batch_size); the original excerpt began at the
    # --batch-size help string.
    parser = argparse.ArgumentParser(
        description='Evaluate lung scan segmentation for a subset')
    parser.add_argument('dataset',
                        type=str,
                        help="Path to the hdf5 with the equalized spaced data")
    parser.add_argument('model_weights',
                        type=str,
                        help="path where the model weights are stored")
    parser.add_argument('subset',
                        type=int,
                        help="subset for which you want to evaluate the segmentation")
    parser.add_argument('--batch-size',
                        dest='batch_size',
                        type=int,
                        default=5,
                        action="store",
                        help="evaluation batch size")
    args = parser.parse_args()

    print("""
    
############################################
######### lucanode scan evaluation #########
############################################
""")

    network_shape = [*DEFAULT_UNET_SIZE, 1]
    model = Unet(*network_shape)
    model.load_weights(args.model_weights, by_name=True)

    with h5py.File(args.dataset, "r") as dataset:
        df = loader.dataset_metadata_as_dataframe(dataset)
        df = df[df.subset == args.subset]
        scan_ids = set(df.seriesuid)
        metrics = []
        for seriesuid in tqdm(scan_ids):
            df_view = df[df.seriesuid == seriesuid]
            dataset_gen = loader.LungSegmentationSequence(
                dataset,
                batch_size=args.batch_size,
                dataframe=df_view,
                epoch_frac=1.0,
                epoch_shuffle=False)
            scan_metrics = [seriesuid, *model.evaluate_generator(dataset_gen)]
            metrics.append(scan_metrics)
        metrics_df = pd.DataFrame(metrics,
                                  columns=["seriesuid", *model.metrics_names])