def main(args):
    """Build (or restart) a 3D network and train it on the prepared data.

    Workflow:
      1. Select the compute device and locate the training / validation
         files ('images*' and 'grndtru*' with the FORMATTRAINDATA extension).
      2. Build a new model, restart from saved weights only, or restart from
         a full saved model, for the library in TYPE_DNNLIBRARY_USED
         ('Keras' or 'Pytorch').
      3. Load the data, either wrapped in a TrainingBatchDataGenerator (when
         sliding-window, transformation or elastic-deformation is enabled)
         or pre-batched with LoadDataManagerInBatches.
      4. Train the model.

    Args:
        args: parsed command-line arguments. When restarting,
            ``args.num_epochs`` is increased in place by ``args.epoch_restart``.

    Raises:
        ValueError: if the Pytorch backend is selected without a batch
            generator (the Pytorch training call below only receives one).
    """

    def _get_metrics():
        # Fresh metric functions, one per name listed in args.listmetrics.
        return [
            DICTAVAILMETRICFUNS(imetrics,
                                is_masks_exclude=args.masksToRegionInterest
                                ).get_renamed_compute()
            for imetrics in args.listmetrics
        ]

    def _get_batch_generator(xData, yData, use_transformation,
                             use_elastic_deformation):
        # Wrap loaded volumes in a batch generator with the requested
        # sliding-window / augmentation settings.
        images_generator = getImagesDataGenerator3D(
            args.slidingWindowImages, args.prop_overlap_Z_X_Y,
            use_transformation, use_elastic_deformation)
        return TrainingBatchDataGenerator(IMAGES_DIMS_Z_X_Y,
                                          xData,
                                          yData,
                                          images_generator,
                                          batch_size=args.batch_size,
                                          shuffle=SHUFFLETRAINDATA)

    # First thing, set the session in the selected device(s): CPU or GPU
    set_session_in_selected_device(use_GPU_device=True,
                                   type_GPU_installed=args.typeGPUinstalled)

    # ---------- SETTINGS ----------
    nameModelsRelPath = args.modelsdir

    # Get the file list:
    nameImagesFiles = 'images*' + getFileExtension(FORMATTRAINDATA)
    nameGroundTruthFiles = 'grndtru*' + getFileExtension(FORMATTRAINDATA)
    # ---------- SETTINGS ----------

    workDirsManager = WorkDirsManager(args.basedir)
    TrainingDataPath = workDirsManager.getNameExistPath(
        workDirsManager.getNameTrainingDataPath())
    if args.use_restartModel:
        ModelsPath = workDirsManager.getNameExistPath(args.basedir,
                                                      nameModelsRelPath)
    else:
        ModelsPath = workDirsManager.getNameUpdatePath(args.basedir,
                                                       nameModelsRelPath)

    listTrainImagesFiles = findFilesDir(TrainingDataPath, nameImagesFiles)
    listTrainGroundTruthFiles = findFilesDir(TrainingDataPath,
                                             nameGroundTruthFiles)

    use_validation_data = False
    if args.useValidationData:
        ValidationDataPath = workDirsManager.getNameExistPath(
            workDirsManager.getNameValidationDataPath())

        listValidImagesFiles = findFilesDir(ValidationDataPath,
                                            nameImagesFiles)
        listValidGroundTruthFiles = findFilesDir(ValidationDataPath,
                                                 nameGroundTruthFiles)

        if listValidImagesFiles and listValidGroundTruthFiles:
            use_validation_data = True
        else:
            message = "No validation data used for training the model..."
            CatchWarningException(message)

    # Computed once so data loading and the training call always agree on
    # whether data is fed through a batch generator.
    use_batch_generator = bool(args.slidingWindowImages
                               or args.transformationImages
                               or args.elasticDeformationImages)

    if TYPE_DNNLIBRARY_USED == 'Pytorch' and not use_batch_generator:
        # Fail before building the model / loading data: the Pytorch
        # training call only receives the batch generator.
        raise ValueError(
            "Pytorch training requires the batch generator: enable "
            "slidingWindowImages, transformationImages or "
            "elasticDeformationImages")

    # BUILDING MODEL
    # ----------------------------------------------
    print("_" * 30)
    print("Building model...")
    print("_" * 30)

    if args.use_restartModel:
        initial_epoch = args.epoch_restart
        # With 'initial_epoch', Keras treats 'epochs' as the index of the
        # final epoch, so shift it to train num_epochs more epochs.
        args.num_epochs += initial_epoch
    else:
        initial_epoch = 0

    # A fresh model is built unless restarting from a full saved model.
    build_new_model = (not args.use_restartModel) or args.restart_only_weights

    if TYPE_DNNLIBRARY_USED == 'Keras':
        if build_new_model:
            model_constructor = DICTAVAILMODELS3D(
                IMAGES_DIMS_Z_X_Y,
                tailored_build_model=args.tailored_build_model,
                num_layers=args.num_layers,
                num_featmaps_base=args.num_featmaps_base,
                type_network=args.type_network,
                type_activate_hidden=args.type_activate_hidden,
                type_activate_output=args.type_activate_output,
                type_padding_convol=args.type_padding_convol,
                is_disable_convol_pooling_lastlayer=args.
                disable_convol_pooling_lastlayer,
                isuse_dropout=args.isUse_dropout,
                isuse_batchnormalize=args.isUse_batchnormalize)
            optimizer = DICTAVAILOPTIMIZERS(args.optimizer, lr=args.learn_rate)
            loss_fun = DICTAVAILLOSSFUNS(
                args.lossfun, is_masks_exclude=args.masksToRegionInterest).loss
            metrics = _get_metrics()
            model = model_constructor.get_model()
            # compile model
            model.compile(optimizer=optimizer, loss=loss_fun, metrics=metrics)

            if args.use_restartModel:
                print("Loading saved weights and restarting...")
                modelSavedPath = joinpathnames(
                    ModelsPath, 'model_' + args.restart_modelFile + '.hdf5')
                print("Restarting from file: \'%s\'..." % (modelSavedPath))
                model.load_weights(modelSavedPath)

        else:  # args.use_restartModel and not args.restart_only_weights
            print(
                "Loading full model: weights, optimizer, loss, metrics ... and restarting..."
            )
            modelSavedPath = joinpathnames(
                ModelsPath, 'model_' + args.restart_modelFile + '.hdf5')
            print("Restarting from file: \'%s\'..." % (modelSavedPath))

            loss_fun = DICTAVAILLOSSFUNS(
                args.lossfun, is_masks_exclude=args.masksToRegionInterest).loss
            metrics = _get_metrics()
            # The saved model references the custom loss/metrics by name.
            custom_objects = {fun.__name__: fun for fun in [loss_fun] + metrics}
            # load and compile model
            model = NeuralNetwork.get_load_saved_model(
                modelSavedPath, custom_objects=custom_objects)

        # Callbacks:
        callbacks_list = [RecordLossHistory(ModelsPath, _get_metrics())]
        if use_validation_data:
            checkpoint_name = 'model_{epoch:02d}_{loss:.5f}_{val_loss:.5f}.hdf5'
        else:
            # Without validation data 'val_loss' is missing from the epoch
            # logs, and ModelCheckpoint would raise KeyError formatting it.
            checkpoint_name = 'model_{epoch:02d}_{loss:.5f}.hdf5'
        filename = joinpathnames(ModelsPath, checkpoint_name)
        callbacks_list.append(
            callbacks.ModelCheckpoint(filename, monitor='loss', verbose=0))

        # output model summary
        model.summary()

    elif TYPE_DNNLIBRARY_USED == 'Pytorch':
        if build_new_model:
            model_net = DICTAVAILMODELS3D(IMAGES_DIMS_Z_X_Y)
            optimizer = DICTAVAILOPTIMIZERS(args.optimizer,
                                            model_net.parameters(),
                                            lr=args.learn_rate)
            loss_fun = DICTAVAILLOSSFUNS(
                args.lossfun, is_masks_exclude=args.masksToRegionInterest)
            trainer = Trainer(model_net, optimizer, loss_fun)

            if args.use_restartModel:
                print("Loading saved weights and restarting...")
                modelSavedPath = joinpathnames(
                    ModelsPath, 'model_' + args.restart_modelFile + '.pt')
                print("Restarting from file: \'%s\'..." % (modelSavedPath))
                trainer.load_model_only_weights(modelSavedPath)

        else:  # args.use_restartModel and not args.restart_only_weights
            print(
                "Loading full model: weights, optimizer, loss, metrics ... and restarting..."
            )
            modelSavedPath = joinpathnames(
                ModelsPath, 'model_' + args.restart_modelFile + '.pt')
            print("Restarting from file: \'%s\'..." % (modelSavedPath))
            trainer = Trainer.load_model_full(modelSavedPath)

        trainer.setup_losshistory_filepath(
            ModelsPath, isexists_lossfile=args.use_restartModel)
        trainer.setup_validate_model(freq_validate_model=FREQVALIDATEMODEL)
        trainer.setup_savemodel_filepath(
            ModelsPath,
            type_save_models='full_model',
            freq_save_intermodels=FREQSAVEINTERMODELS)
    # ----------------------------------------------

    # LOADING DATA
    # ----------------------------------------------
    print("-" * 30)
    print("Loading data...")
    print("-" * 30)

    print("Load Training data...")
    if use_batch_generator:
        print(
            "Generate Training images with Batch Generator of Training data..."
        )
        (train_xData, train_yData) = LoadDataManager.loadData_ListFiles(
            listTrainImagesFiles, listTrainGroundTruthFiles)
        train_batch_data_generator = _get_batch_generator(
            train_xData, train_yData, args.transformationImages,
            args.elasticDeformationImages)
        print("Number volumes: %s. Total Data batches generated: %s..." %
              (len(listTrainImagesFiles), len(train_batch_data_generator)))
    else:
        (train_xData, train_yData
         ) = LoadDataManagerInBatches(IMAGES_DIMS_Z_X_Y).loadData_ListFiles(
             listTrainImagesFiles, listTrainGroundTruthFiles)
        print("Number volumes: %s. Total Data batches generated: %s..." %
              (len(listTrainImagesFiles), len(train_xData)))

    if use_validation_data:
        print("Load Validation data...")
        if use_batch_generator:
            print(
                "Generate Validation images with Batch Generator of Validation data..."
            )
            # Random transformations are applied to validation data only if
            # requested. Kept local so the training flags in 'args' are not
            # altered (they are not re-read, but must stay consistent).
            valid_use_transformation = (args.transformationImages and
                                        args.useTransformOnValidationData)
            valid_use_elastic_deformation = (
                args.elasticDeformationImages and
                args.useTransformOnValidationData)
            (valid_xData, valid_yData) = LoadDataManager.loadData_ListFiles(
                listValidImagesFiles, listValidGroundTruthFiles)
            valid_batch_data_generator = _get_batch_generator(
                valid_xData, valid_yData, valid_use_transformation,
                valid_use_elastic_deformation)
            validation_data = valid_batch_data_generator
            print("Number volumes: %s. Total Data batches generated: %s..." %
                  (len(listValidImagesFiles), len(valid_batch_data_generator)))
        else:
            (valid_xData, valid_yData) = LoadDataManagerInBatches(
                IMAGES_DIMS_Z_X_Y).loadData_ListFiles(
                    listValidImagesFiles, listValidGroundTruthFiles)
            validation_data = (valid_xData, valid_yData)
            print("Number volumes: %s. Total Data batches generated: %s..." %
                  (len(listValidImagesFiles), len(valid_xData)))
    else:
        validation_data = None

    # TRAINING MODEL
    # ----------------------------------------------
    print("-" * 30)
    print("Training model...")
    print("-" * 30)

    if TYPE_DNNLIBRARY_USED == 'Keras':
        if use_batch_generator:
            model.fit_generator(train_batch_data_generator,
                                epochs=args.num_epochs,
                                steps_per_epoch=args.max_steps_epoch,
                                verbose=1,
                                callbacks=callbacks_list,
                                validation_data=validation_data,
                                shuffle=SHUFFLETRAINDATA,
                                initial_epoch=initial_epoch)
        else:
            model.fit(train_xData,
                      train_yData,
                      batch_size=args.batch_size,
                      epochs=args.num_epochs,
                      steps_per_epoch=args.max_steps_epoch,
                      verbose=1,
                      callbacks=callbacks_list,
                      validation_data=validation_data,
                      shuffle=SHUFFLETRAINDATA,
                      initial_epoch=initial_epoch)

    elif TYPE_DNNLIBRARY_USED == 'Pytorch':
        trainer.train(train_batch_data_generator,
                      num_epochs=args.num_epochs,
                      max_steps_epoch=args.max_steps_epoch,
                      valid_data_generator=validation_data,
                      initial_epoch=initial_epoch)
def main(args):
    """Split the base dataset into Training / Validation / Testing sets.

    Images are searched in the 'Images_WorkData' directory and ground-truth
    files in the 'LumenDistTrans_WorkData' directory, both under the base
    data path. Each selected image / ground-truth pair is linked with
    ``makelink`` into the Training, Validation or Testing data directory.

    The split is either:
      * by fixed names (``args.distribute_fixed_names``), using the lists
        NAME_IMAGES_TRAINING / NAME_IMAGES_VALIDATION / NAME_IMAGES_TESTING,
        which must not share any name; or
      * by proportions of the file count (``args.prop_data_training``,
        ``args.prop_data_validation``), taking files in a random order
        (``args.distribute_random``) or in listing order. The testing set
        receives every file left after training and validation.

    NOTE(review): an image and its ground truth are paired only by their
    position in the two lists returned by ``findFilesDir``. This assumes
    both lists come back in matching order; confirm.
    """
    # ---------- SETTINGS ----------
    nameOrigImagesDataRelPath = 'Images_WorkData'
    nameOrigMasksDataRelPath = 'LumenDistTrans_WorkData'

    nameOriginImagesFiles = 'images*' + getFileExtension(FORMATTRAINDATA)
    nameOriginMasksFiles = 'grndtru*' + getFileExtension(FORMATTRAINDATA)
    # ---------- SETTINGS ----------

    workDirsManager = WorkDirsManager(args.basedir)

    # Source directories (must exist) and destination directories of the split.
    OrigImagesDataPath = workDirsManager.getNameExistPath(
        workDirsManager.getNameBaseDataPath(), nameOrigImagesDataRelPath)
    OrigGroundTruthDataPath = workDirsManager.getNameExistPath(
        workDirsManager.getNameBaseDataPath(), nameOrigMasksDataRelPath)
    TrainingDataPath = workDirsManager.getNameNewPath(
        workDirsManager.getNameTrainingDataPath())
    ValidationDataPath = workDirsManager.getNameNewPath(
        workDirsManager.getNameValidationDataPath())
    TestingDataPath = workDirsManager.getNameNewPath(
        workDirsManager.getNameTestingDataPath())

    listImagesFiles = findFilesDir(OrigImagesDataPath, nameOriginImagesFiles)
    listGroundTruthFiles = findFilesDir(OrigGroundTruthDataPath,
                                        nameOriginMasksFiles)

    numImagesFiles = len(listImagesFiles)
    numGroundTruthFiles = len(listGroundTruthFiles)

    # Every image needs exactly one ground-truth file (paired by index below).
    if (numImagesFiles != numGroundTruthFiles):
        message = "num image files \'%s\' not equal to num ground-truth files \'%s\'..." % (
            numImagesFiles, numGroundTruthFiles)
        CatchErrorException(message)

    if (args.distribute_fixed_names):
        print("Split dataset with Fixed Names...")
        # A name may belong to at most one of the three subsets.
        names_repeated = find_element_repeated_two_indexes_names(
            NAME_IMAGES_TRAINING, NAME_IMAGES_VALIDATION)
        names_repeated += find_element_repeated_two_indexes_names(
            NAME_IMAGES_TRAINING, NAME_IMAGES_TESTING)
        names_repeated += find_element_repeated_two_indexes_names(
            NAME_IMAGES_VALIDATION, NAME_IMAGES_TESTING)

        if names_repeated:
            message = "found names repeated in list Training / Validation / Testing names: %s" % (
                names_repeated)
            CatchErrorException(message)

        # Map each list of names to indexes into listImagesFiles.
        # NOTE(review): files matching none of the lists presumably end up
        # in no subset; confirm against find_indexes_names_images_files.
        indexesTraining = find_indexes_names_images_files(
            NAME_IMAGES_TRAINING, listImagesFiles)
        indexesValidation = find_indexes_names_images_files(
            NAME_IMAGES_VALIDATION, listImagesFiles)
        indexesTesting = find_indexes_names_images_files(
            NAME_IMAGES_TESTING, listImagesFiles)
        print(
            "Training (%s files)/ Validation (%s files)/ Testing (%s files)..."
            % (len(indexesTraining), len(indexesValidation),
               len(indexesTesting)))
    else:
        numTrainingFiles = int(args.prop_data_training * numImagesFiles)
        numValidationFiles = int(args.prop_data_validation * numImagesFiles)
        # Only printed: the testing subset below takes all remaining files,
        # which can differ from this count because int() truncates.
        numTestingFiles = int(args.prop_data_testing * numImagesFiles)
        print(
            "Training (%s files)/ Validation (%s files)/ Testing (%s files)..."
            % (numTrainingFiles, numValidationFiles, numTestingFiles))
        if (args.distribute_random):
            print("Split dataset Randomly...")
            # Random permutation of all file indexes (no repetition).
            indexesAllFiles = np.random.choice(range(numImagesFiles),
                                               size=numImagesFiles,
                                               replace=False)
        else:
            print("Split dataset In Order...")
            indexesAllFiles = range(numImagesFiles)

        # Consecutive slices: training, then validation, then the remainder.
        indexesTraining = indexesAllFiles[0:numTrainingFiles]
        indexesValidation = indexesAllFiles[numTrainingFiles:numTrainingFiles +
                                            numValidationFiles]
        indexesTesting = indexesAllFiles[numTrainingFiles +
                                         numValidationFiles::]

    print("Files assigned to Training Data: \'%s\'" %
          ([basename(listImagesFiles[index]) for index in indexesTraining]))
    print("Files assigned to Validation Data: \'%s\'" %
          ([basename(listImagesFiles[index]) for index in indexesValidation]))
    print("Files assigned to Testing Data: \'%s\'" %
          ([basename(listImagesFiles[index]) for index in indexesTesting]))

    # For each subset, link the image and its ground truth (same index) into
    # the subset directory, keeping the original file names.
    # ******************** TRAINING DATA ********************
    for index in indexesTraining:
        makelink(
            listImagesFiles[index],
            joinpathnames(TrainingDataPath, basename(listImagesFiles[index])))
        makelink(
            listGroundTruthFiles[index],
            joinpathnames(TrainingDataPath,
                          basename(listGroundTruthFiles[index])))
    #endfor
    # ******************** TRAINING DATA ********************

    # ******************** VALIDATION DATA ********************
    for index in indexesValidation:
        makelink(
            listImagesFiles[index],
            joinpathnames(ValidationDataPath,
                          basename(listImagesFiles[index])))
        makelink(
            listGroundTruthFiles[index],
            joinpathnames(ValidationDataPath,
                          basename(listGroundTruthFiles[index])))
    #endfor
    # ******************** VALIDATION DATA ********************

    # ******************** TESTING DATA ********************
    for index in indexesTesting:
        makelink(
            listImagesFiles[index],
            joinpathnames(TestingDataPath, basename(listImagesFiles[index])))
        makelink(
            listGroundTruthFiles[index],
            joinpathnames(TestingDataPath,
                          basename(listGroundTruthFiles[index])))