def main(args):

    # First thing, set session in the selected device(s): CPU or GPU
    set_session_in_selected_device(use_GPU_device=True,
                                   type_GPU_installed=args.typeGPUinstalled)

    # ---------- SETTINGS ----------
    nameModelsRelPath = args.modelsdir

    # Get the file list:
    nameImagesFiles = 'images*' + getFileExtension(FORMATTRAINDATA)
    nameGroundTruthFiles = 'grndtru*' + getFileExtension(FORMATTRAINDATA)
    # ---------- SETTINGS ----------

    workDirsManager = WorkDirsManager(args.basedir)
    TrainingDataPath = workDirsManager.getNameExistPath(workDirsManager.getNameTrainingDataPath())

    if args.use_restartModel:
        ModelsPath = workDirsManager.getNameExistPath(args.basedir, nameModelsRelPath)
    else:
        ModelsPath = workDirsManager.getNameUpdatePath(args.basedir, nameModelsRelPath)

    listTrainImagesFiles = findFilesDir(TrainingDataPath, nameImagesFiles)
    listTrainGroundTruthFiles = findFilesDir(TrainingDataPath, nameGroundTruthFiles)

    if args.useValidationData:
        ValidationDataPath = workDirsManager.getNameExistPath(workDirsManager.getNameValidationDataPath())
        listValidImagesFiles = findFilesDir(ValidationDataPath, nameImagesFiles)
        listValidGroundTruthFiles = findFilesDir(ValidationDataPath, nameGroundTruthFiles)

        if not listValidImagesFiles or not listValidGroundTruthFiles:
            use_validation_data = False
            message = "No validation data used for training the model..."
            CatchWarningException(message)
        else:
            use_validation_data = True
    else:
        use_validation_data = False

    # BUILDING MODEL
    # ----------------------------------------------
    print("_" * 30)
    print("Building model...")
    print("_" * 30)

    if args.use_restartModel:
        initial_epoch = args.epoch_restart
        args.num_epochs += initial_epoch
    else:
        initial_epoch = 0

    if TYPE_DNNLIBRARY_USED == 'Keras':
        if (not args.use_restartModel) or (args.use_restartModel and args.restart_only_weights):
            model_constructor = DICTAVAILMODELS3D(
                IMAGES_DIMS_Z_X_Y,
                tailored_build_model=args.tailored_build_model,
                num_layers=args.num_layers,
                num_featmaps_base=args.num_featmaps_base,
                type_network=args.type_network,
                type_activate_hidden=args.type_activate_hidden,
                type_activate_output=args.type_activate_output,
                type_padding_convol=args.type_padding_convol,
                is_disable_convol_pooling_lastlayer=args.disable_convol_pooling_lastlayer,
                isuse_dropout=args.isUse_dropout,
                isuse_batchnormalize=args.isUse_batchnormalize)
            optimizer = DICTAVAILOPTIMIZERS(args.optimizer, lr=args.learn_rate)
            loss_fun = DICTAVAILLOSSFUNS(args.lossfun,
                                         is_masks_exclude=args.masksToRegionInterest).loss
            metrics = [DICTAVAILMETRICFUNS(imetrics,
                                           is_masks_exclude=args.masksToRegionInterest).get_renamed_compute()
                       for imetrics in args.listmetrics]
            model = model_constructor.get_model()

            # compile model
            model.compile(optimizer=optimizer, loss=loss_fun, metrics=metrics)
            # output model summary
            model.summary()

            if args.use_restartModel:
                print("Loading saved weights and restarting...")
                modelSavedPath = joinpathnames(ModelsPath, 'model_' + args.restart_modelFile + '.hdf5')
                print("Restarting from file: '%s'..." % (modelSavedPath))
                model.load_weights(modelSavedPath)
        else:  # args.use_restartModel and not args.restart_only_weights
            print("Loading full model: weights, optimizer, loss, metrics ... and restarting...")
            modelSavedPath = joinpathnames(ModelsPath, 'model_' + args.restart_modelFile + '.hdf5')
            print("Restarting from file: '%s'..." % (modelSavedPath))
            loss_fun = DICTAVAILLOSSFUNS(args.lossfun,
                                         is_masks_exclude=args.masksToRegionInterest).loss
            metrics = [DICTAVAILMETRICFUNS(imetrics,
                                           is_masks_exclude=args.masksToRegionInterest).get_renamed_compute()
                       for imetrics in args.listmetrics]
            custom_objects = dict(map(lambda fun: (fun.__name__, fun), [loss_fun] + metrics))
            # load and compile model
            model = NeuralNetwork.get_load_saved_model(modelSavedPath,
                                                       custom_objects=custom_objects)

        # Callbacks:
        callbacks_list = []
        callbacks_list.append(RecordLossHistory(ModelsPath,
                                                [DICTAVAILMETRICFUNS(imetrics,
                                                                     is_masks_exclude=args.masksToRegionInterest).get_renamed_compute()
                                                 for imetrics in args.listmetrics]))
        filename = joinpathnames(ModelsPath, 'model_{epoch:02d}_{loss:.5f}_{val_loss:.5f}.hdf5')
        callbacks_list.append(callbacks.ModelCheckpoint(filename, monitor='loss', verbose=0))
        # callbacks_list.append(callbacks.EarlyStopping(monitor='val_loss', patience=10, mode='max'))

        # output model summary
        model.summary()

    elif TYPE_DNNLIBRARY_USED == 'Pytorch':
        if (not args.use_restartModel) or (args.use_restartModel and args.restart_only_weights):
            model_net = DICTAVAILMODELS3D(IMAGES_DIMS_Z_X_Y)
            optimizer = DICTAVAILOPTIMIZERS(args.optimizer, model_net.parameters(), lr=args.learn_rate)
            loss_fun = DICTAVAILLOSSFUNS(args.lossfun, is_masks_exclude=args.masksToRegionInterest)
            trainer = Trainer(model_net, optimizer, loss_fun)

            if args.use_restartModel:
                print("Loading saved weights and restarting...")
                modelSavedPath = joinpathnames(ModelsPath, 'model_' + args.restart_modelFile + '.pt')
                print("Restarting from file: '%s'..." % (modelSavedPath))
                trainer.load_model_only_weights(modelSavedPath)
        else:  # args.use_restartModel and not args.restart_only_weights
            print("Loading full model: weights, optimizer, loss, metrics ... and restarting...")
            modelSavedPath = joinpathnames(ModelsPath, 'model_' + args.restart_modelFile + '.pt')
            print("Restarting from file: '%s'..." % (modelSavedPath))
            trainer = Trainer.load_model_full(modelSavedPath)

        trainer.setup_losshistory_filepath(ModelsPath, isexists_lossfile=args.use_restartModel)
        trainer.setup_validate_model(freq_validate_model=FREQVALIDATEMODEL)
        trainer.setup_savemodel_filepath(ModelsPath,
                                         type_save_models='full_model',
                                         freq_save_intermodels=FREQSAVEINTERMODELS)
        # output model summary
        #trainer.get_summary_model()
    # ----------------------------------------------

    # LOADING DATA
    # ----------------------------------------------
    print("-" * 30)
    print("Loading data...")
    print("-" * 30)

    print("Load Training data...")
    if (args.slidingWindowImages or args.transformationImages or args.elasticDeformationImages):
        print("Generate Training images with Batch Generator of Training data...")

        (train_xData, train_yData) = LoadDataManager.loadData_ListFiles(listTrainImagesFiles,
                                                                        listTrainGroundTruthFiles)
        train_images_generator = getImagesDataGenerator3D(args.slidingWindowImages,
                                                          args.prop_overlap_Z_X_Y,
                                                          args.transformationImages,
                                                          args.elasticDeformationImages)
        train_batch_data_generator = TrainingBatchDataGenerator(IMAGES_DIMS_Z_X_Y,
                                                                train_xData,
                                                                train_yData,
                                                                train_images_generator,
                                                                batch_size=args.batch_size,
                                                                shuffle=SHUFFLETRAINDATA)
        print("Number volumes: %s. Total Data batches generated: %s..."
              % (len(listTrainImagesFiles), len(train_batch_data_generator)))
    else:
        (train_xData, train_yData) = LoadDataManagerInBatches(IMAGES_DIMS_Z_X_Y).loadData_ListFiles(
            listTrainImagesFiles, listTrainGroundTruthFiles)
        print("Number volumes: %s. Total Data batches generated: %s..."
              % (len(listTrainImagesFiles), len(train_xData)))
    if use_validation_data:
        print("Load Validation data...")
        if (args.slidingWindowImages or args.transformationImages or args.elasticDeformationImages):
            print("Generate Validation images with Batch Generator of Validation data...")
            args.transformationImages = args.transformationImages and args.useTransformOnValidationData
            args.elasticDeformationImages = args.elasticDeformationImages and args.useTransformOnValidationData

            (valid_xData, valid_yData) = LoadDataManager.loadData_ListFiles(listValidImagesFiles,
                                                                            listValidGroundTruthFiles)
            valid_images_generator = getImagesDataGenerator3D(args.slidingWindowImages,
                                                              args.prop_overlap_Z_X_Y,
                                                              args.transformationImages,
                                                              args.elasticDeformationImages)
            valid_batch_data_generator = TrainingBatchDataGenerator(IMAGES_DIMS_Z_X_Y,
                                                                    valid_xData,
                                                                    valid_yData,
                                                                    valid_images_generator,
                                                                    batch_size=args.batch_size,
                                                                    shuffle=SHUFFLETRAINDATA)
            validation_data = valid_batch_data_generator
            print("Number volumes: %s. Total Data batches generated: %s..."
                  % (len(listValidImagesFiles), len(valid_batch_data_generator)))
        else:
            (valid_xData, valid_yData) = LoadDataManagerInBatches(IMAGES_DIMS_Z_X_Y).loadData_ListFiles(
                listValidImagesFiles, listValidGroundTruthFiles)
            validation_data = (valid_xData, valid_yData)
            print("Number volumes: %s. Total Data batches generated: %s..."
                  % (len(listValidImagesFiles), len(valid_xData)))
    else:
        validation_data = None

    # TRAINING MODEL
    # ----------------------------------------------
    print("-" * 30)
    print("Training model...")
    print("-" * 30)

    if TYPE_DNNLIBRARY_USED == 'Keras':
        # use the batch generator whenever one was built in the data-loading step above
        if (args.slidingWindowImages or args.transformationImages or args.elasticDeformationImages):
            model.fit_generator(train_batch_data_generator,
                                epochs=args.num_epochs,
                                steps_per_epoch=args.max_steps_epoch,
                                verbose=1,
                                callbacks=callbacks_list,
                                validation_data=validation_data,
                                shuffle=SHUFFLETRAINDATA,
                                initial_epoch=initial_epoch)
        else:
            model.fit(train_xData, train_yData,
                      batch_size=args.batch_size,
                      epochs=args.num_epochs,
                      steps_per_epoch=args.max_steps_epoch,
                      verbose=1,
                      callbacks=callbacks_list,
                      validation_data=validation_data,
                      shuffle=SHUFFLETRAINDATA,
                      initial_epoch=initial_epoch)

    elif TYPE_DNNLIBRARY_USED == 'Pytorch':
        trainer.train(train_batch_data_generator,
                      num_epochs=args.num_epochs,
                      max_steps_epoch=args.max_steps_epoch,
                      valid_data_generator=validation_data,
                      initial_epoch=initial_epoch)
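# ----------------------------------------------
# A minimal sketch of a command-line entry point, assuming this script is run
# directly. It only illustrates how main(args) could be wired up: the argument
# names mirror the 'args.*' attributes used above, but every default value here
# is an illustrative assumption, not the repository's configuration.
if __name__ == "__main__":
    import argparse

    def str2bool(v):
        # argparse's type=bool treats any non-empty string as True, so parse explicitly
        return str(v).lower() in ('yes', 'true', 't', '1')

    parser = argparse.ArgumentParser()
    parser.add_argument('basedir', type=str)
    parser.add_argument('--modelsdir', type=str, default='Models')
    parser.add_argument('--typeGPUinstalled', type=str, default='larger_GPU')
    # model construction (assumed defaults)
    parser.add_argument('--tailored_build_model', type=str2bool, default=True)
    parser.add_argument('--num_layers', type=int, default=5)
    parser.add_argument('--num_featmaps_base', type=int, default=16)
    parser.add_argument('--type_network', type=str, default='classification')
    parser.add_argument('--type_activate_hidden', type=str, default='relu')
    parser.add_argument('--type_activate_output', type=str, default='sigmoid')
    parser.add_argument('--type_padding_convol', type=str, default='same')
    parser.add_argument('--disable_convol_pooling_lastlayer', type=str2bool, default=False)
    parser.add_argument('--isUse_dropout', type=str2bool, default=False)
    parser.add_argument('--isUse_batchnormalize', type=str2bool, default=False)
    # optimization (assumed defaults)
    parser.add_argument('--optimizer', type=str, default='Adam')
    parser.add_argument('--lossfun', type=str, default='DiceCoefficient')
    parser.add_argument('--listmetrics', type=str, nargs='*', default=[])
    parser.add_argument('--learn_rate', type=float, default=1.0e-04)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_epochs', type=int, default=1000)
    parser.add_argument('--max_steps_epoch', type=int, default=None)
    parser.add_argument('--masksToRegionInterest', type=str2bool, default=False)
    # restart options (assumed defaults)
    parser.add_argument('--use_restartModel', type=str2bool, default=False)
    parser.add_argument('--restart_modelFile', type=str, default='last')
    parser.add_argument('--restart_only_weights', type=str2bool, default=False)
    parser.add_argument('--epoch_restart', type=int, default=0)
    # data options (assumed defaults)
    parser.add_argument('--useValidationData', type=str2bool, default=True)
    parser.add_argument('--slidingWindowImages', type=str2bool, default=False)
    parser.add_argument('--prop_overlap_Z_X_Y', type=float, nargs=3, default=(0.0, 0.0, 0.0))
    parser.add_argument('--transformationImages', type=str2bool, default=False)
    parser.add_argument('--elasticDeformationImages', type=str2bool, default=False)
    parser.add_argument('--useTransformOnValidationData', type=str2bool, default=True)
    args = parser.parse_args()
    main(args)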
def main(args):

    # ---------- SETTINGS ----------
    nameOrigImagesDataRelPath = 'Images_WorkData'
    nameOrigMasksDataRelPath = 'LumenDistTrans_WorkData'

    nameOriginImagesFiles = 'images*' + getFileExtension(FORMATTRAINDATA)
    nameOriginMasksFiles = 'grndtru*' + getFileExtension(FORMATTRAINDATA)
    # ---------- SETTINGS ----------

    workDirsManager = WorkDirsManager(args.basedir)
    OrigImagesDataPath = workDirsManager.getNameExistPath(
        workDirsManager.getNameBaseDataPath(), nameOrigImagesDataRelPath)
    OrigGroundTruthDataPath = workDirsManager.getNameExistPath(
        workDirsManager.getNameBaseDataPath(), nameOrigMasksDataRelPath)
    TrainingDataPath = workDirsManager.getNameNewPath(workDirsManager.getNameTrainingDataPath())
    ValidationDataPath = workDirsManager.getNameNewPath(workDirsManager.getNameValidationDataPath())
    TestingDataPath = workDirsManager.getNameNewPath(workDirsManager.getNameTestingDataPath())

    listImagesFiles = findFilesDir(OrigImagesDataPath, nameOriginImagesFiles)
    listGroundTruthFiles = findFilesDir(OrigGroundTruthDataPath, nameOriginMasksFiles)

    numImagesFiles = len(listImagesFiles)
    numGroundTruthFiles = len(listGroundTruthFiles)

    if (numImagesFiles != numGroundTruthFiles):
        message = "num image files '%s' not equal to num ground-truth files '%s'..." % (
            numImagesFiles, numGroundTruthFiles)
        CatchErrorException(message)

    if (args.distribute_fixed_names):
        print("Split dataset with Fixed Names...")
        names_repeated = find_element_repeated_two_indexes_names(NAME_IMAGES_TRAINING, NAME_IMAGES_VALIDATION)
        names_repeated += find_element_repeated_two_indexes_names(NAME_IMAGES_TRAINING, NAME_IMAGES_TESTING)
        names_repeated += find_element_repeated_two_indexes_names(NAME_IMAGES_VALIDATION, NAME_IMAGES_TESTING)

        if names_repeated:
            message = "found names repeated in list Training / Validation / Testing names: %s" % (names_repeated)
            CatchErrorException(message)

        indexesTraining = find_indexes_names_images_files(NAME_IMAGES_TRAINING, listImagesFiles)
        indexesValidation = find_indexes_names_images_files(NAME_IMAGES_VALIDATION, listImagesFiles)
        indexesTesting = find_indexes_names_images_files(NAME_IMAGES_TESTING, listImagesFiles)
        print("Training (%s files)/ Validation (%s files)/ Testing (%s files)..."
              % (len(indexesTraining), len(indexesValidation), len(indexesTesting)))
    else:
        numTrainingFiles = int(args.prop_data_training * numImagesFiles)
        numValidationFiles = int(args.prop_data_validation * numImagesFiles)
        numTestingFiles = int(args.prop_data_testing * numImagesFiles)
        print("Training (%s files)/ Validation (%s files)/ Testing (%s files)..."
              % (numTrainingFiles, numValidationFiles, numTestingFiles))
        if (args.distribute_random):
            print("Split dataset Randomly...")
            indexesAllFiles = np.random.choice(range(numImagesFiles), size=numImagesFiles, replace=False)
        else:
            print("Split dataset In Order...")
            indexesAllFiles = range(numImagesFiles)

        indexesTraining = indexesAllFiles[0:numTrainingFiles]
        indexesValidation = indexesAllFiles[numTrainingFiles:numTrainingFiles + numValidationFiles]
        indexesTesting = indexesAllFiles[numTrainingFiles + numValidationFiles:]

    print("Files assigned to Training Data: '%s'"
          % ([basename(listImagesFiles[index]) for index in indexesTraining]))
    print("Files assigned to Validation Data: '%s'"
          % ([basename(listImagesFiles[index]) for index in indexesValidation]))
    print("Files assigned to Testing Data: '%s'"
          % ([basename(listImagesFiles[index]) for index in indexesTesting]))

    # ******************** TRAINING DATA ********************
    for index in indexesTraining:
        makelink(listImagesFiles[index],
                 joinpathnames(TrainingDataPath, basename(listImagesFiles[index])))
        makelink(listGroundTruthFiles[index],
                 joinpathnames(TrainingDataPath, basename(listGroundTruthFiles[index])))
    #endfor
    # ******************** TRAINING DATA ********************

    # ******************** VALIDATION DATA ********************
    for index in indexesValidation:
        makelink(listImagesFiles[index],
                 joinpathnames(ValidationDataPath, basename(listImagesFiles[index])))
        makelink(listGroundTruthFiles[index],
                 joinpathnames(ValidationDataPath, basename(listGroundTruthFiles[index])))
    #endfor
    # ******************** VALIDATION DATA ********************

    # ******************** TESTING DATA ********************
    for index in indexesTesting:
        makelink(listImagesFiles[index],
                 joinpathnames(TestingDataPath, basename(listImagesFiles[index])))
        makelink(listGroundTruthFiles[index],
                 joinpathnames(TestingDataPath, basename(listGroundTruthFiles[index])))
    #endfor
    # ******************** TESTING DATA ********************
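# ----------------------------------------------
# A minimal sketch of a command-line entry point for this data-distribution
# script, assuming it is run directly. The argument names mirror the 'args.*'
# attributes used above; the split proportions and flag defaults are
# illustrative assumptions, not the repository's values.
if __name__ == "__main__":
    import argparse

    def str2bool(v):
        # argparse's type=bool treats any non-empty string as True, so parse explicitly
        return str(v).lower() in ('yes', 'true', 't', '1')

    parser = argparse.ArgumentParser()
    parser.add_argument('basedir', type=str)
    parser.add_argument('--prop_data_training', type=float, default=0.50)
    parser.add_argument('--prop_data_validation', type=float, default=0.25)
    parser.add_argument('--prop_data_testing', type=float, default=0.25)
    parser.add_argument('--distribute_random', type=str2bool, default=False)
    parser.add_argument('--distribute_fixed_names', type=str2bool, default=False)
    args = parser.parse_args()
    main(args)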