def train_nodule_segmentation_no_augmentation_normalization_dice(
        dataset_file,
        output_weights_file,
        batch_size=5,
        num_epochs=10,
        last_epoch=0,
        initial_weights=None,
):
    """Train the nodule segmentation U-Net, from scratch or from a
    preexisting set of weights, with Dice loss and without augmentation
    or normalization."""
    # Loaders: subsets 0-7 for training, subset 8 for validation; only
    # slices that have a spherical nodule mask are used
    dataset = h5py.File(dataset_file, "r")
    df = loader.dataset_metadata_as_dataframe(
        dataset, key='nodule_masks_spherical')
    df_training = df[df.subset.isin([0, 1, 2, 3, 4, 5, 6, 7]) & df.has_mask]
    dataset.close()
    training_loader = loader.NoduleSegmentationSequence(
        dataset_file,
        batch_size,
        dataframe=df_training,
        epoch_frac=1.0,
        epoch_shuffle=False)

    df_validation = df[df.subset.isin([8]) & df.has_mask]
    validation_loader = loader.NoduleSegmentationSequence(
        dataset_file,
        batch_size,
        dataframe=df_validation,
        epoch_frac=1.0,
        epoch_shuffle=False)

    # Callbacks: keep only the best weights by validation loss, stop after
    # 10 epochs without improvement, and log per-epoch metrics
    model_checkpoint = ModelCheckpoint(
        output_weights_file,
        monitor='val_loss',
        verbose=1,
        save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10)
    history_log = HistoryLog(output_weights_file + ".history")

    # Setup network: single-channel output U-Net trained with Dice loss
    network_size = [*DEFAULT_UNET_SIZE, 1, dice_coef_loss]
    model = Unet(*network_size)
    if initial_weights:
        model.load_weights(initial_weights)

    # Train
    model.fit_generator(
        generator=training_loader,
        epochs=num_epochs,
        initial_epoch=last_epoch,
        verbose=1,
        validation_data=validation_loader,
        use_multiprocessing=True,
        workers=4,
        max_queue_size=20,
        shuffle=True,
        callbacks=[model_checkpoint, early_stopping, history_log])
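
# Example invocation (sketch only; the dataset path and output path below are
# placeholders, not files shipped with this repository):
#
#   train_nodule_segmentation_no_augmentation_normalization_dice(
#       dataset_file="luna16.h5",
#       output_weights_file="weights/nodule_segmentation_dice.hdf5",
#       batch_size=5,
#       num_epochs=50,
#   )
#
# Because ModelCheckpoint runs with save_best_only=True, output_weights_file
# always holds the weights with the lowest validation loss seen so far.
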
def train_lung_segmentation(
        dataset_file,
        output_weights_file,
        batch_size=5,
        num_epochs=10,
        last_epoch=0,
        initial_weights=None,
):
    """Train the lung segmentation U-Net, from scratch or from a
    preexisting set of weights, on the dataset."""
    # Loaders: epoch_frac limits the fraction of the data visited per epoch;
    # subset 8 is reserved for validation
    training_loader = loader.LungSegmentationSequence(
        dataset_file, batch_size, epoch_frac=0.1)
    validation_loader = loader.LungSegmentationSequence(
        dataset_file,
        batch_size,
        subsets={8},
        epoch_frac=0.3,
        epoch_shuffle=False)

    # Callbacks
    model_checkpoint = ModelCheckpoint(
        output_weights_file,
        monitor='val_loss',
        verbose=1,
        save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10)
    history_log = HistoryLog(output_weights_file + ".history")

    # Setup network: single-channel output U-Net with the default loss
    network_size = [*DEFAULT_UNET_SIZE, 1]
    model = Unet(*network_size)
    if initial_weights:
        model.load_weights(initial_weights)

    # Train
    model.fit_generator(
        generator=training_loader,
        epochs=num_epochs,
        initial_epoch=last_epoch,
        verbose=1,
        validation_data=validation_loader,
        use_multiprocessing=True,
        workers=4,
        max_queue_size=20,
        shuffle=True,
        callbacks=[model_checkpoint, early_stopping, history_log])
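
# Example: resuming lung-segmentation training from a previous checkpoint
# (sketch only; file names are hypothetical). Passing last_epoch makes Keras
# continue the epoch count, so the run below trains epochs 10 through 29:
#
#   train_lung_segmentation(
#       dataset_file="luna16.h5",
#       output_weights_file="weights/lung_segmentation.hdf5",
#       initial_weights="weights/lung_segmentation.hdf5",
#       last_epoch=10,
#       num_epochs=30,
#   )
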
def train_fp_reduction_resnet(
        model_builder,
        dataset_file,
        candidates_file,
        output_weights_file,
        batch_size=64,
        num_epochs=30,
        last_epoch=0,
        initial_weights=None,
):
    """Train a 3D CNN for false positive reduction on the candidate cubes,
    from scratch or from a preexisting set of weights."""
    # Candidates: exclude subset 9 and randomly keep 35% of the scan series
    df = pd.read_csv(candidates_file)
    df = df[~df.subset.isin([9])]
    df = df[df.seriesuid.isin(
        pd.Series(
            df.seriesuid.unique()).sample(frac=0.35))].sort_values("seriesuid")

    # Load the candidate cubes into memory
    with h5py.File(dataset_file, 'r') as dataset:
        cubes = preprocessing.load_cubes(df, dataset)

    df = pd.DataFrame({
        "cube": cubes,
        "class": df["class"].tolist(),
        "subset": df["subset"].tolist()
    })

    # Rebalance dataset to 2:1 (non-nodule:nodule), oversampling the nodule
    # class with replacement
    no_nodule_len = len(df[df["class"] == 0])
    df = pd.concat([
        df[df["class"] == 0],
        df[df["class"] == 1].sample(n=no_nodule_len // 2, replace=True)
    ], ignore_index=True)

    # Loaders: augmentation is applied to the training set only
    df_training = df[df.subset.isin([0, 1, 2, 3, 4, 5, 6, 7])]
    df_validation = df[df.subset.isin([8])]
    training_loader = loader.NoduleClassificationSequence(
        batch_size, dataframe=df_training, do_augmentation=True)
    validation_loader = loader.NoduleClassificationSequence(
        batch_size, dataframe=df_validation, do_augmentation=False)

    # Callbacks
    model_checkpoint = ModelCheckpoint(
        output_weights_file,
        monitor='val_loss',
        verbose=1,
        save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10)
    history_log = HistoryLog(output_weights_file + ".history")

    # Setup network: model_builder must return an uncompiled Keras model
    # taking 32x32x32 single-channel cubes with one output unit for
    # binary classification
    model = model_builder((32, 32, 32, 1), 1)
    model.compile(
        optimizer=Adam(lr=1e-3),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    if initial_weights:
        model.load_weights(initial_weights)

    # Train
    model.fit_generator(
        generator=training_loader,
        epochs=num_epochs,
        initial_epoch=last_epoch,
        verbose=1,
        validation_data=validation_loader,
        use_multiprocessing=False,
        workers=8,
        max_queue_size=20,
        shuffle=True,
        callbacks=[model_checkpoint, early_stopping, history_log])
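
if __name__ == "__main__":
    # Minimal CLI sketch for the false-positive-reduction trainer. The
    # model_builder argument is an assumption on my part: any callable taking
    # (input_shape, num_outputs) and returning an uncompiled Keras model
    # works; Resnet3DBuilder.build_resnet_34 from the keras-resnet3d package
    # is one such builder. All paths are placeholders supplied by the user.
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a 3D ResNet for false positive reduction")
    parser.add_argument("dataset_file")
    parser.add_argument("candidates_file")
    parser.add_argument("output_weights_file")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=30)
    args = parser.parse_args()

    from resnet3d import Resnet3DBuilder  # assumed external dependency

    train_fp_reduction_resnet(
        Resnet3DBuilder.build_resnet_34,
        args.dataset_file,
        args.candidates_file,
        args.output_weights_file,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )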