Example #1
import h5py

from keras.callbacks import EarlyStopping, ModelCheckpoint

# loader, HistoryLog, Unet, DEFAULT_UNET_SIZE and dice_coef_loss come from the
# surrounding project and are imported from its own modules.


def train_nodule_segmentation_no_augmentation_normalization_dice(
    dataset_file,
    output_weights_file,
    batch_size=5,
    num_epochs=10,
    last_epoch=0,
    initial_weights=None,
):
    """Train the network from scratch or from a preexisting set of weights on the dataset"""

    # Loaders
    dataset = h5py.File(dataset_file, "r")
    df = loader.dataset_metadata_as_dataframe(dataset,
                                              key='nodule_masks_spherical')
    # Train on subsets 0-7 and validate on subset 8, keeping only slices that have a nodule mask.
    df_training = df[df.subset.isin([0, 1, 2, 3, 4, 5, 6, 7]) & df.has_mask]
    dataset.close()
    training_loader = loader.NoduleSegmentationSequence(dataset_file,
                                                        batch_size,
                                                        dataframe=df_training,
                                                        epoch_frac=1.0,
                                                        epoch_shuffle=False)
    df_validation = df[df.subset.isin([8]) & df.has_mask]
    validation_loader = loader.NoduleSegmentationSequence(
        dataset_file,
        batch_size,
        dataframe=df_validation,
        epoch_frac=1.0,
        epoch_shuffle=False)

    # Callbacks
    model_checkpoint = ModelCheckpoint(output_weights_file,
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10)
    history_log = HistoryLog(output_weights_file + ".history")

    # Setup network
    network_size = [*DEFAULT_UNET_SIZE, 1, dice_coef_loss]
    model = Unet(*network_size)

    if initial_weights:
        model.load_weights(initial_weights)

    # Train
    model.fit_generator(
        generator=training_loader,
        epochs=num_epochs,
        initial_epoch=last_epoch,
        verbose=1,
        validation_data=validation_loader,
        use_multiprocessing=True,
        workers=4,
        max_queue_size=20,
        shuffle=True,
        callbacks=[model_checkpoint, early_stopping, history_log])
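The dice_coef_loss passed to Unet above is not shown in the snippet. A common formulation built on Keras backend ops looks like the following sketch; the smoothing term and the negative-Dice convention are assumptions, not taken from the source.

from keras import backend as K


def dice_coef(y_true, y_pred, smooth=1.0):
    # Flatten prediction and ground-truth masks and measure their soft overlap.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (
        K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    # Minimizing the negative Dice coefficient maximizes mask overlap.
    return -dice_coef(y_true, y_pred)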
Example #2
from keras.callbacks import EarlyStopping, ModelCheckpoint

# loader, HistoryLog, Unet and DEFAULT_UNET_SIZE come from the surrounding
# project and are imported from its own modules.


def train_lung_segmentation(
    dataset_file,
    output_weights_file,
    batch_size=5,
    num_epochs=10,
    last_epoch=0,
    initial_weights=None,
):
    """Train the network from scratch or from a preexisting set of weights on the dataset"""

    # Loaders
    training_loader = loader.LungSegmentationSequence(dataset_file,
                                                      batch_size,
                                                      epoch_frac=0.1)
    validation_loader = loader.LungSegmentationSequence(dataset_file,
                                                        batch_size,
                                                        subsets={8},
                                                        epoch_frac=0.3,
                                                        epoch_shuffle=False)

    # Callbacks
    model_checkpoint = ModelCheckpoint(output_weights_file,
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10)
    history_log = HistoryLog(output_weights_file + ".history")

    # Setup network
    network_size = [*DEFAULT_UNET_SIZE, 1]
    model = Unet(*network_size)

    if initial_weights:
        model.load_weights(initial_weights)

    # Train
    model.fit_generator(
        generator=training_loader,
        epochs=num_epochs,
        initial_epoch=last_epoch,
        verbose=1,
        validation_data=validation_loader,
        use_multiprocessing=True,
        workers=4,
        max_queue_size=20,
        shuffle=True,
        callbacks=[model_checkpoint, early_stopping, history_log])
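Both segmentation examples hand fit_generator loader objects such as LungSegmentationSequence and NoduleSegmentationSequence, which are presumably keras.utils.Sequence subclasses. The minimal, hypothetical stand-in below only illustrates the interface Keras expects (__len__, __getitem__, on_epoch_end); the real loaders additionally read slices from the HDF5 file and honour the dataframe, subsets and epoch_frac arguments.

import numpy as np
from keras.utils import Sequence


class MinimalSliceSequence(Sequence):
    """Hypothetical stand-in showing the batch interface fit_generator needs."""

    def __init__(self, images, masks, batch_size, epoch_shuffle=True):
        self.images = images              # e.g. (N, H, W, 1) array of slices
        self.masks = masks                # matching (N, H, W, 1) binary masks
        self.batch_size = batch_size
        self.epoch_shuffle = epoch_shuffle
        self.indices = np.arange(len(images))

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, idx):
        batch = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.images[batch], self.masks[batch]

    def on_epoch_end(self):
        # Called by Keras between epochs; reshuffle sample order if requested.
        if self.epoch_shuffle:
            np.random.shuffle(self.indices)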
Example #3
import h5py
import pandas as pd

from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam

# loader, preprocessing and HistoryLog come from the surrounding project and
# are imported from its own modules.


def train_fp_reduction_resnet(
    model_builder,
    dataset_file,
    candidates_file,
    output_weights_file,
    batch_size=64,
    num_epochs=30,
    last_epoch=0,
    initial_weights=None,
):
    """Train the false-positive-reduction classifier from scratch or from a preexisting set of weights on the candidate cubes"""

    # Drop subset 9 and keep candidates from a random 35% of the scans (seriesuids).
    df = pd.read_csv(candidates_file)
    df = df[~df.subset.isin([9])]
    df = df[df.seriesuid.isin(
        pd.Series(
            df.seriesuid.unique()).sample(frac=0.35))].sort_values("seriesuid")
    with h5py.File(dataset_file, 'r') as dataset:
        # Load the candidate cubes for the remaining rows from the HDF5 dataset.
        cubes = preprocessing.load_cubes(df, dataset)
        df = pd.DataFrame({
            "cube": cubes,
            "class": df["class"].tolist(),
            "subset": df["subset"].tolist()
        })
    # Rebalance to two non-nodule candidates per nodule, sampling the nodule class with replacement
    no_nodule_len = len(df[df["class"] == 0])
    df = pd.concat([
        df[df["class"] == 0], df[df["class"] == 1].sample(n=no_nodule_len // 2,
                                                          replace=True)
    ],
                   ignore_index=True)
    df_training = df[df.subset.isin([0, 1, 2, 3, 4, 5, 6, 7])]
    df_validation = df[df.subset.isin([8])]

    training_loader = loader.NoduleClassificationSequence(
        batch_size, dataframe=df_training, do_augmentation=True)
    validation_loader = loader.NoduleClassificationSequence(
        batch_size, dataframe=df_validation, do_augmentation=False)

    # Callbacks
    model_checkpoint = ModelCheckpoint(output_weights_file,
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10)
    history_log = HistoryLog(output_weights_file + ".history")

    # Setup network
    model = model_builder((32, 32, 32, 1), 1)
    model.compile(
        optimizer=Adam(lr=1e-3),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    if initial_weights:
        model.load_weights(initial_weights)

    # Train
    model.fit_generator(
        generator=training_loader,
        epochs=num_epochs,
        initial_epoch=last_epoch,
        verbose=1,
        validation_data=validation_loader,
        use_multiprocessing=False,
        workers=8,
        max_queue_size=20,
        shuffle=True,
        callbacks=[model_checkpoint, early_stopping, history_log])
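train_fp_reduction_resnet only assumes that model_builder returns an uncompiled Keras model mapping a (32, 32, 32, 1) cube to a single sigmoid output, since the function compiles it with binary cross-entropy itself. A minimal hypothetical builder with one 3-D residual block is sketched below; the layer widths and depth are illustrative and not taken from the source.

from keras.layers import (Activation, Add, BatchNormalization, Conv3D, Dense,
                          GlobalAveragePooling3D, Input)
from keras.models import Model


def simple_resnet3d_builder(input_shape, num_outputs):
    """Hypothetical builder matching the (input_shape, num_outputs) signature."""
    inputs = Input(shape=input_shape)
    x = Conv3D(16, 3, padding='same', activation='relu')(inputs)

    # One residual block: two 3x3x3 convolutions plus an identity skip.
    shortcut = x
    y = Conv3D(16, 3, padding='same')(x)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv3D(16, 3, padding='same')(y)
    y = BatchNormalization()(y)
    x = Activation('relu')(Add()([shortcut, y]))

    x = GlobalAveragePooling3D()(x)
    outputs = Dense(num_outputs, activation='sigmoid')(x)
    return Model(inputs, outputs)

Such a builder would be passed in as the first argument, e.g. train_fp_reduction_resnet(simple_resnet3d_builder, dataset_file, candidates_file, output_weights_file).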