Code example #1
File: train.py  Project: theroggy/orthoseg
def train(config_path: Path):
    """
    Run a training session for the config specified.

    Args:
        config_path (Path): Path to the config file to use.
    """
    ##### Init #####
    # Load the config and save it in a bunch of global variables so it
    # is accessible everywhere
    conf.read_orthoseg_config(config_path)

    # Init logging
    log_util.clean_log_dir(
        log_dir=conf.dirs.getpath('log_dir'),
        nb_logfiles_tokeep=conf.logging.getint('nb_logfiles_tokeep'))
    global logger
    logger = log_util.main_log_init(conf.dirs.getpath('log_dir'), __name__)

    # Log start + send email
    message = f"Start train for config {config_path.stem}"
    logger.info(message)
    logger.debug(f"Config used: \n{conf.pformat_config()}")
    email_helper.sendmail(message)

    try:

        # First check if the segment_subject has a valid name
        segment_subject = conf.general['segment_subject']
        if segment_subject == 'MUST_OVERRIDE':
            raise Exception(
                "The segment_subject parameter needs to be overridden in the subject specific config file!!!"
            )
        elif '_' in segment_subject:
            raise Exception(
                f"The segment_subject parameter should not contain '_', so this is invalid: {segment_subject}!!!"
            )

        # Create the output dirs if they don't exist yet...
        for dir in [
                conf.dirs.getpath('project_dir'),
                conf.dirs.getpath('training_dir')
        ]:
            if dir and not dir.exists():
                dir.mkdir()

        ##### If the training data doesn't exist yet, create it #####
        # Get the label input info
        label_files_dict = conf.train.getdict('label_datasources', None)
        label_infos = []
        if label_files_dict is not None:
            for label_file_key in label_files_dict:
                label_file = label_files_dict[label_file_key]
                # Add as LabelInfo objects to list
                label_infos.append(
                    prep.LabelInfo(
                        locations_path=Path(label_file['locations_path']),
                        polygons_path=Path(label_file['data_path']),
                        image_layer=label_file['image_layer']))
            if len(label_infos) == 0:
                raise Exception(
                    "Parameter label_datasources is defined in config but doesn't contain valid label info!"
                )
        else:
            # Search for the files based on the file name patterns...
            labelpolygons_pattern = conf.train.getpath('labelpolygons_pattern')
            labellocations_pattern = conf.train.getpath(
                'labellocations_pattern')
            label_infos = _search_label_files(labelpolygons_pattern,
                                              labellocations_pattern)
            if label_infos is None or len(label_infos) == 0:
                raise Exception(
                    f"No label files found with patterns {labellocations_pattern} and {labelpolygons_pattern}"
                )

        # Determine the projection of (the first) train layer... it will be used for all!!!
        train_image_layer = label_infos[0].image_layer
        train_projection = conf.image_layers[train_image_layer]['projection']

        # Determine classes
        try:
            classes = conf.train.getdict('classes')

            # If the burn_value property isn't supplied for the classes, add them
            for class_id, classname in enumerate(classes):
                if 'burn_value' not in classes[classname]:
                    classes[classname]['burn_value'] = class_id
        except Exception as ex:
            raise Exception(
                f"Error reading classes: {conf.train.get('classes')}") from ex

        # Now create the train datasets (train, validation, test)
        force_model_traindata_id = conf.train.getint(
            'force_model_traindata_id')
        if force_model_traindata_id > -1:
            training_dir = conf.dirs.getpath(
                'training_dir') / f"{force_model_traindata_id:02d}"
            traindata_id = force_model_traindata_id
        else:
            logger.info("Prepare train, validation and test data")
            training_dir, traindata_id = prep.prepare_traindatasets(
                label_infos=label_infos,
                classes=classes,
                image_layers=conf.image_layers,
                training_dir=conf.dirs.getpath('training_dir'),
                labelname_column=conf.train.get('labelname_column'),
                image_pixel_x_size=conf.train.getfloat('image_pixel_x_size'),
                image_pixel_y_size=conf.train.getfloat('image_pixel_y_size'),
                image_pixel_width=conf.train.getint('image_pixel_width'),
                image_pixel_height=conf.train.getint('image_pixel_height'),
                ssl_verify=conf.general['ssl_verify'])

        logger.info(
            f"Traindata dir to use is {training_dir}, with traindata_id: {traindata_id}"
        )
        traindata_dir = training_dir / 'train'
        validationdata_dir = training_dir / 'validation'
        testdata_dir = training_dir / 'test'

        ##### Check if training is needed #####
        # Get hyper parameters from the config
        # TODO: activation_function should probably not be specified!!!!!!
        architectureparams = mh.ArchitectureParams(
            architecture=conf.model['architecture'],
            classes=[classname for classname in classes],
            nb_channels=conf.model.getint('nb_channels'),
            architecture_id=conf.model.getint('architecture_id'),
            activation_function='softmax')
        trainparams = mh.TrainParams(
            trainparams_id=conf.train.getint('trainparams_id'),
            image_augmentations=conf.train.getdict('image_augmentations'),
            mask_augmentations=conf.train.getdict('mask_augmentations'),
            class_weights=[
                classes[classname]['weight'] for classname in classes
            ],
            batch_size=conf.train.getint('batch_size_fit'),
            optimizer=conf.train.get('optimizer'),
            optimizer_params=conf.train.getdict('optimizer_params'),
            loss_function=conf.train.get('loss_function'),
            monitor_metric=conf.train.get('monitor_metric'),
            monitor_metric_mode=conf.train.get('monitor_metric_mode'),
            save_format=conf.train.get('save_format'),
            save_best_only=conf.train.getboolean('save_best_only'),
            save_min_accuracy=conf.train.getfloat('save_min_accuracy'),
            nb_epoch=conf.train.getint('max_epoch'),
            nb_epoch_with_freeze=conf.train.getint('nb_epoch_with_freeze'),
            earlystop_patience=conf.train.getint('earlystop_patience'),
            earlystop_monitor_metric=conf.train.get(
                'earlystop_monitor_metric'),
            earlystop_monitor_metric_mode=conf.train.get(
                'earlystop_monitor_metric_mode'),
            log_tensorboard=conf.train.getboolean('log_tensorboard'),
            log_csv=conf.train.getboolean('log_csv'))

        # Check if there exists already a model for this train dataset + hyperparameters
        model_dir = conf.dirs.getpath('model_dir')
        segment_subject = conf.general['segment_subject']
        best_model_curr_train_version = mh.get_best_model(
            model_dir=model_dir,
            segment_subject=segment_subject,
            traindata_id=traindata_id,
            architecture_id=architectureparams.architecture_id,
            trainparams_id=trainparams.trainparams_id)

        # Determine if training is needed...
        resume_train = conf.train.getboolean('resume_train')
        if resume_train is False:
            # If no (best) model found, training needed!
            if best_model_curr_train_version is None:
                train_needed = True
            elif conf.train.getboolean('force_train') is True:
                train_needed = True
            else:
                logger.info(
                    "JUST PREDICT, without training: resume_train is false and model found"
                )
                train_needed = False
        else:
            # We want to preload an existing model and models were found
            if best_model_curr_train_version is not None:
                logger.info(
                    f"PRELOAD model and continue TRAINING it: {best_model_curr_train_version['filename']}"
                )
                train_needed = True
            else:
                message = "STOP: resume_train is true but no model was found!"
                logger.error(message)
                raise Exception(message)

        ##### Train!!! #####
        if train_needed is True:

            # If a model already exists, use it to predict (possibly new) training and
            # validation dataset. This way it is possible to have a quick check on errors
            # in (new) added labels in the datasets.

            # Get the current best model that already exists for this subject
            best_recent_model = mh.get_best_model(model_dir=model_dir)
            if best_recent_model is not None:
                try:
                    # TODO: move the hyperparams filename formatting to get_models...
                    logger.info(
                        f"Load model + weights from {best_recent_model['filepath']}"
                    )
                    best_model = mf.load_model(best_recent_model['filepath'],
                                               compile=False)
                    best_hyperparams_path = best_recent_model[
                        'filepath'].parent / f"{best_recent_model['basefilename']}_hyperparams.json"
                    best_hyperparams = mh.HyperParams(
                        path=best_hyperparams_path)
                    logger.info("Loaded model, weights and params")

                    # Prepare output subdir to be used for predictions
                    predict_out_subdir, _ = os.path.splitext(
                        best_recent_model['filename'])

                    # Predict training dataset
                    predicter.predict_dir(
                        model=best_model,
                        input_image_dir=traindata_dir / 'image',
                        output_image_dir=traindata_dir / predict_out_subdir,
                        output_vector_path=None,
                        projection_if_missing=train_projection,
                        input_mask_dir=traindata_dir / 'mask',
                        batch_size=conf.train.getint('batch_size_predict'),
                        evaluate_mode=True,
                        classes=best_hyperparams.architecture.classes,
                        cancel_filepath=conf.files.getpath('cancel_filepath'))

                    # Predict validation dataset
                    predicter.predict_dir(
                        model=best_model,
                        input_image_dir=validationdata_dir / 'image',
                        output_image_dir=validationdata_dir /
                        predict_out_subdir,
                        output_vector_path=None,
                        projection_if_missing=train_projection,
                        input_mask_dir=validationdata_dir / 'mask',
                        batch_size=conf.train.getint('batch_size_predict'),
                        evaluate_mode=True,
                        classes=best_hyperparams.architecture.classes,
                        cancel_filepath=conf.files.getpath('cancel_filepath'))
                    del best_model
                except Exception as ex:
                    logger.warning(
                        f"Exception trying to predict with old model: {ex}")

            # Now we can really start training
            logger.info('Start training')
            model_preload_filepath = None
            if best_model_curr_train_version is not None:
                model_preload_filepath = best_model_curr_train_version[
                    'filepath']
            elif conf.train.getboolean('preload_with_previous_traindata'):
                best_model_for_architecture = mh.get_best_model(
                    model_dir=model_dir, segment_subject=segment_subject)
                if best_model_for_architecture is not None:
                    model_preload_filepath = best_model_for_architecture[
                        'filepath']

            # Combine all hyperparameters in hyperparams object
            hyperparams = mh.HyperParams(architecture=architectureparams,
                                         train=trainparams)

            trainer.train(
                traindata_dir=traindata_dir,
                validationdata_dir=validationdata_dir,
                model_save_dir=model_dir,
                segment_subject=segment_subject,
                traindata_id=traindata_id,
                hyperparams=hyperparams,
                model_preload_filepath=model_preload_filepath,
                image_width=conf.train.getint('image_pixel_width'),
                image_height=conf.train.getint('image_pixel_height'),
                save_augmented_subdir=conf.train.get('save_augmented_subdir'))

            # Now get the best model found during training
            best_model_curr_train_version = mh.get_best_model(
                model_dir=model_dir,
                segment_subject=segment_subject,
                traindata_id=traindata_id)

        # Assert to avoid typing warnings
        assert best_model_curr_train_version is not None

        # Now predict on the train, validation and test data
        logger.info(
            f"PREDICT test data with best model: {best_model_curr_train_version['filename']}"
        )

        # Load prediction model...
        logger.info(
            f"Load model + weights from {best_model_curr_train_version['filepath']}"
        )
        model = mf.load_model(best_model_curr_train_version['filepath'],
                              compile=False)
        logger.info("Loaded model + weights")

        # Prepare output subdir to be used for predictions
        predict_out_subdir, _ = os.path.splitext(
            best_model_curr_train_version['filename'])

        # Predict training dataset
        predicter.predict_dir(
            model=model,
            input_image_dir=traindata_dir / 'image',
            output_image_dir=traindata_dir / predict_out_subdir,
            output_vector_path=None,
            projection_if_missing=train_projection,
            input_mask_dir=traindata_dir / 'mask',
            batch_size=conf.train.getint('batch_size_predict'),
            evaluate_mode=True,
            classes=classes,
            cancel_filepath=conf.files.getpath('cancel_filepath'))

        # Predict validation dataset
        predicter.predict_dir(
            model=model,
            input_image_dir=validationdata_dir / 'image',
            output_image_dir=validationdata_dir / predict_out_subdir,
            output_vector_path=None,
            projection_if_missing=train_projection,
            input_mask_dir=validationdata_dir / 'mask',
            batch_size=conf.train.getint('batch_size_predict'),
            evaluate_mode=True,
            classes=classes,
            cancel_filepath=conf.files.getpath('cancel_filepath'))

        # Predict test dataset, if it exists
        if testdata_dir is not None and testdata_dir.exists():
            predicter.predict_dir(
                model=model,
                input_image_dir=testdata_dir / 'image',
                output_image_dir=testdata_dir / predict_out_subdir,
                output_vector_path=None,
                projection_if_missing=train_projection,
                input_mask_dir=testdata_dir / 'mask',
                batch_size=conf.train.getint('batch_size_predict'),
                evaluate_mode=True,
                classes=classes,
                cancel_filepath=conf.files.getpath('cancel_filepath'))

        # Predict extra test dataset with random images in the roi, to add to
        # train and/or validation dataset if inaccuracies are found
        # -> this is very useful to find false positives to improve the datasets
        if conf.dirs.getpath('predictsample_image_input_dir').exists():
            predicter.predict_dir(
                model=model,
                input_image_dir=conf.dirs.getpath(
                    'predictsample_image_input_dir'),
                output_image_dir=conf.dirs.getpath(
                    'predictsample_image_output_basedir') / predict_out_subdir,
                output_vector_path=None,
                projection_if_missing=train_projection,
                batch_size=conf.train.getint('batch_size_predict'),
                evaluate_mode=True,
                classes=classes,
                cancel_filepath=conf.files.getpath('cancel_filepath'))

        # Free resources...
        logger.debug("Free resources")
        if model is not None:
            del model
        kr.backend.clear_session()
        gc.collect()

        # Log and send mail
        message = f"Completed train for config {config_path.stem}"
        logger.info(message)
        email_helper.sendmail(message)
    except Exception as ex:
        message = f"ERROR while running train for task {config_path.stem}"
        logger.exception(message)
        email_helper.sendmail(
            subject=message,
            body=f"Exception: {ex}\n\n {traceback.format_exc()}")
        raise Exception(message) from ex
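A minimal usage sketch for train(): the project config path below is a hypothetical example (not taken from the source), and it assumes the function above is importable as defined.

from pathlib import Path

# Hypothetical project config file; adjust to your own project layout.
config_path = Path("projects/roofs/roofs_train.ini")
train(config_path)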
Code example #2
def predict(config_path: Path):
    """
    Run a prediction for the config specified.

    Args:
        config_path (Path): Path to the config file to use.
    """
    ##### Init #####
    # Load the config and save it in a bunch of global variables so it
    # is accessible everywhere
    conf.read_orthoseg_config(config_path)
    
    # Init logging
    log_util.clean_log_dir(
            log_dir=conf.dirs.getpath('log_dir'),
            nb_logfiles_tokeep=conf.logging.getint('nb_logfiles_tokeep'))     
    global logger
    logger = log_util.main_log_init(conf.dirs.getpath('log_dir'), __name__)      
    
    # Log start + send email
    message = f"Start predict for config {config_path.stem}"
    logger.info(message)
    logger.debug(f"Config used: \n{conf.pformat_config()}")    
    email_helper.sendmail(message)
    
    try:
        # Read some config, and check if values are ok
        image_layer = conf.image_layers[conf.predict['image_layer']]
        if image_layer is None:
            raise Exception(f"STOP: image_layer to predict is not configured in image_layers: {conf.predict['image_layer']}")
        input_image_dir = conf.dirs.getpath('predict_image_input_dir')
        if not input_image_dir.exists():
            raise Exception(f"STOP: input image dir doesn't exist: {input_image_dir}")

        # TODO: add something to delete old data, predictions???

        # Create base filename of model to use
        # TODO: is force data version the most logical, or rather implement 
        #       force weights file or ?
        traindata_id = None
        force_model_traindata_id = conf.train.getint('force_model_traindata_id')
        if force_model_traindata_id is not None and force_model_traindata_id > -1:
            traindata_id = force_model_traindata_id 
        
        # Get the best model that already exists for this train dataset
        trainparams_id = conf.train.getint('trainparams_id')
        best_model = mh.get_best_model(
                model_dir=conf.dirs.getpath('model_dir'), 
                segment_subject=conf.general['segment_subject'],
                traindata_id=traindata_id,
                trainparams_id=trainparams_id)
        
        # Check if a model was found
        if best_model is None:
            message = f"No model found in model_dir: {conf.dirs.getpath('model_dir')} for traindata_id: {traindata_id}"
            logger.critical(message)
            raise Exception(message)
        else:    
            model_weights_filepath = best_model['filepath']
            logger.info(f"Best model found: {model_weights_filepath}")
        
        # Load the hyperparams of the model
        # TODO: move the hyperparams filename formatting to get_models...
        hyperparams_path = best_model['filepath'].parent / f"{best_model['basefilename']}_hyperparams.json"
        hyperparams = mh.HyperParams(path=hyperparams_path)           
        
        # Prepare output subdir to be used for predictions
        predict_out_subdir = f"{best_model['basefilename']}"
        if trainparams_id > 0:
            predict_out_subdir += f"_{trainparams_id}"
        predict_out_subdir += f"_{best_model['epoch']}"
        
        # Try optimizing model with tensorrt. Not supported on Windows
        model = None
        if os.name != 'nt':
            try:
                # Try import
                from tensorflow.python.compiler.tensorrt import trt_convert as trt

                # Import didn't fail, so optimize model
                logger.info('Tensorrt is available, so use optimized model')
                savedmodel_optim_dir = best_model['filepath'].parent / (best_model['filepath'].stem + "_optim")
                if not savedmodel_optim_dir.exists():
                    # If base model not yet in savedmodel format
                    savedmodel_dir = best_model['filepath'].parent / best_model['filepath'].stem
                    if not savedmodel_dir.exists():
                        logger.info(f"SavedModel format not yet available, so load model + weights from {best_model['filepath']}")
                        model = mf.load_model(best_model['filepath'], compile=False)
                        logger.info(f"Now save again as savedmodel to {savedmodel_dir}")
                        tf.saved_model.save(model, str(savedmodel_dir))
                        del model

                    # Now optimize model
                    logger.info(f"Optimize + save model to {savedmodel_optim_dir}")
                    converter = trt.TrtGraphConverterV2(
                            input_saved_model_dir=str(savedmodel_dir),
                            is_dynamic_op=True,
                            precision_mode='FP16')
                    converter.convert()
                    converter.save(str(savedmodel_optim_dir))
                
                logger.info(f"Load optimized model + weights from {savedmodel_optim_dir}")
                model = tf.keras.models.load_model(savedmodel_optim_dir)

            except ImportError:
                logger.info('Tensorrt is not available, so load unoptimized model')
        
        # If model isn't loaded yet... load!
        if model is None:
            model = mf.load_model(best_model['filepath'], compile=False)

        # Prepare the model for predicting
        nb_gpu = len(tf.config.experimental.list_physical_devices('GPU'))
        batch_size = conf.predict.getint('batch_size')
        if nb_gpu <= 1:
            model_for_predict = model
            logger.info(f"Predict using single GPU or CPU, with nb_gpu: {nb_gpu}")
        else:
            # If multiple GPUs available, create multi_gpu_model
            try:
                model_for_predict = model
                logger.warning("Predict using multiple GPUs NOT IMPLEMENTED AT THE MOMENT")
                
                #logger.info(f"Predict using multiple GPUs: {nb_gpu}, batch size becomes: {batch_size*nb_gpu}")
                #batch_size *= nb_gpu
            except ValueError:
                logger.info("Predict using single GPU or CPU")
                model_for_predict = model

        # Prepare some parameters for the postprocessing
        simplify_algorithm = conf.predict.get('simplify_algorithm')
        if simplify_algorithm is not None:
            simplify_algorithm = geofileops.SimplifyAlgorithm[simplify_algorithm]
        prediction_cleanup_params = {
                    "simplify_algorithm": simplify_algorithm,
                    "simplify_tolerance": conf.predict.geteval('simplify_tolerance'),
                    "simplify_lookahead": conf.predict.getint('simplify_lookahead'),
                }

        # Prepare the output dirs/paths
        predict_output_dir = Path(f"{str(conf.dirs.getpath('predict_image_output_basedir'))}_{predict_out_subdir}")
        output_vector_dir = conf.dirs.getpath('output_vector_dir')
        output_vector_name = f"{best_model['basefilename']}_{best_model['epoch']}_{conf.predict['image_layer']}"
        output_vector_path = output_vector_dir / f"{output_vector_name}.gpkg"
        
        # Predict for entire dataset
        nb_parallel = conf.general.getint('nb_parallel')
        predicter.predict_dir(
                model=model_for_predict, # type: ignore
                input_image_dir=input_image_dir,
                output_image_dir=predict_output_dir,
                output_vector_path=output_vector_path,
                classes=hyperparams.architecture.classes,
                prediction_cleanup_params=prediction_cleanup_params,
                border_pixels_to_ignore=conf.predict.getint('image_pixels_overlap'),
                projection_if_missing=image_layer['projection'],
                input_mask_dir=None,
                batch_size=batch_size,
                evaluate_mode=False,
                cancel_filepath=conf.files.getpath('cancel_filepath'),
                nb_parallel_postprocess=nb_parallel)
        
        # Log and send mail
        message = f"Completed predict for config {config_path.stem}"
        logger.info(message)
        email_helper.sendmail(message)
    except Exception as ex:
        message = f"ERROR while running predict for task {config_path.stem}"
        logger.exception(message)
        email_helper.sendmail(subject=message, body=f"Exception: {ex}\n\n {traceback.format_exc()}")
        raise Exception(message) from ex
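Prediction follows the same calling convention; a minimal sketch, assuming a trained model already exists in model_dir for the subject and that the config path is again hypothetical.

from pathlib import Path

# Requires a best model in model_dir; writes prediction images and a .gpkg vector output.
predict(Path("projects/roofs/roofs.ini"))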
Code example #3
def postprocess(config_path: Path):
    """
    Postprocess the output of a prediction for the config specified.

    Args:
        config_path (Path): Path to the config file.
    """

    ##### Init #####
    # Load the config and save it in a bunch of global variables so it
    # is accessible everywhere
    conf.read_orthoseg_config(config_path)

    # Init logging
    log_util.clean_log_dir(
        log_dir=conf.dirs.getpath('log_dir'),
        nb_logfiles_tokeep=conf.logging.getint('nb_logfiles_tokeep'))
    global logger
    logger = log_util.main_log_init(conf.dirs.getpath('log_dir'), __name__)

    # Log start + send email
    message = f"Start postprocess for config {config_path.stem}"
    logger.info(message)
    logger.debug(f"Config used: \n{conf.pformat_config()}")
    email_helper.sendmail(message)

    try:

        # Create base filename of model to use
        # TODO: is force data version the most logical, or rather implement
        #       force weights file or ?
        traindata_id = None
        force_model_traindata_id = conf.train.getint(
            'force_model_traindata_id')
        if force_model_traindata_id is not None and force_model_traindata_id > -1:
            traindata_id = force_model_traindata_id

        # Get the best model that already exists for this train dataset
        trainparams_id = conf.train.getint('trainparams_id')
        best_model = mh.get_best_model(
            model_dir=conf.dirs.getpath('model_dir'),
            segment_subject=conf.general['segment_subject'],
            traindata_id=traindata_id,
            trainparams_id=trainparams_id)
        if best_model is None:
            raise Exception(
                f"No best model found in {conf.dirs.getpath('model_dir')}")

        # Input file: the "most recent" prediction result for this subject
        output_vector_dir = conf.dirs.getpath('output_vector_dir')
        output_vector_name = f"{best_model['basefilename']}_{best_model['epoch']}_{conf.predict['image_layer']}"
        output_vector_path = output_vector_dir / f"{output_vector_name}.gpkg"

        # Prepare some parameters for the postprocessing
        nb_parallel = conf.general.getint('nb_parallel')

        dissolve = conf.postprocess.getboolean('dissolve')
        dissolve_tiles_path = conf.postprocess.getpath('dissolve_tiles_path')
        simplify_algorithm = conf.postprocess.get('simplify_algorithm')
        if simplify_algorithm is not None:
            simplify_algorithm = geofileops.SimplifyAlgorithm[
                simplify_algorithm]
        simplify_tolerance = conf.postprocess.geteval('simplify_tolerance')
        simplify_lookahead = conf.postprocess.get('simplify_lookahead')
        if simplify_lookahead is not None:
            simplify_lookahead = int(simplify_lookahead)

        ##### Go! #####
        postp.postprocess_predictions(input_path=output_vector_path,
                                      output_path=output_vector_path,
                                      dissolve=dissolve,
                                      dissolve_tiles_path=dissolve_tiles_path,
                                      simplify_algorithm=simplify_algorithm,
                                      simplify_tolerance=simplify_tolerance,
                                      simplify_lookahead=simplify_lookahead,
                                      nb_parallel=nb_parallel,
                                      force=False)

        # Log and send mail
        message = f"Completed postprocess for config {config_path.stem}"
        logger.info(message)
        email_helper.sendmail(message)
    except Exception as ex:
        message = f"ERROR while running postprocess for task {config_path.stem}"
        logger.exception(message)
        email_helper.sendmail(
            subject=message,
            body=f"Exception: {ex}\n\n {traceback.format_exc()}")
        raise Exception(message) from ex
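A short sketch chaining prediction and postprocessing for one hypothetical config, assuming both functions above are importable together; postprocess() reuses the .gpkg written by predict().

from pathlib import Path

config_path = Path("projects/roofs/roofs.ini")  # hypothetical path
predict(config_path)      # writes <basefilename>_<epoch>_<image_layer>.gpkg
postprocess(config_path)  # dissolves/simplifies that output in place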
Code example #4
def main():
    
    ##### Interpret arguments #####
    parser = argparse.ArgumentParser(add_help=False)

    # Optional arguments
    optional = parser.add_argument_group('Optional arguments')
    optional.add_argument('-d', '--script_dir',
            help='Directory containing the scripts to run.')
    optional.add_argument('-w', '--watch', action='store_true', default=False,
            help='Watch the directory forever for files getting in it.')
    optional.add_argument('-c', '--config',
            help='Path to a config file with parameters that need to overrule the defaults.')
    
    # Add back help         
    optional.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS,
            help='Show this help message and exit')
    args = parser.parse_args()
    
    ### Init stuff ###
    script_dir = Path(args.script_dir)
    if not script_dir.exists():
        raise Exception(f"script dir {script_dir} does not exist")

    # Load the scriptrunner config
    conf = load_scriptrunner_config(args.config, script_dir)

    # Init logging
    log_util.clean_log_dir(
            log_dir=conf['dirs'].getpath('log_dir'),
            nb_logfiles_tokeep=conf['logging'].getint('nb_logfiles_tokeep'))
    logger = log_util.init_logging_dictConfig(
            logconfig_dict=conf['logging'].getdict('logconfig'),
            log_basedir=conf['dirs'].getpath('script_dir'),
            loggername=__name__)
    
    # Init working dirs
    done_dir = conf['dirs'].getpath('done_dir')
    error_dir = conf['dirs'].getpath('error_dir')

    # Loop over the scripts to be run
    wait_message_printed = False
    while True:

        # List the scripts in the dir
        script_paths = []
        script_patterns = conf['general'].getlist('script_patterns')
        for script_pattern in script_patterns:
            script_paths.extend(list(script_dir.glob(script_pattern)))

        # If no scripts found, sleep or stop...
        if len(script_paths) == 0:
            if args.watch is False:
                logger.info(f"No scripts found (anymore) in {script_dir}, so stop")
                break
            else:
                if wait_message_printed is False:
                    logger.info(f"No scripts to run in {script_dir}, so watch script dir...")
                    wait_message_printed = True
                time.sleep(10)
                continue

        # Get next script alphabetically
        script_path = sorted(script_paths)[0]
        
        try:
            # Run the script and print output in realtime
            wait_message_printed = False
            logger.info(f"Run script {script_path}")
            cmd = [script_path]

            process = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE,
                    stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, encoding='utf-8')
                    #,creationflags=subprocess.CREATE_NO_WINDOW)
            
            still_running = True
            while still_running:
                if process.stdout is not None:
                    output = process.stdout.readline()
                    if output == '' and process.poll() is not None:
                        still_running = False
                    if output:
                        logger.info(output.strip())
            
            # If the return code != 0, an error occurred
            rc = process.poll()
            if rc != 0:
                # Script gave an error, so move to error dir
                logger.error(f"Script {script_path} gave error return code: {rc}")
                error_dir.mkdir(parents=True, exist_ok=True)
                error_path = error_dir / script_path.name
                if error_path.exists():
                    error_path.unlink()
                script_path.rename(target=error_path)
            else:
                # Move the script to the done dir 
                done_dir.mkdir(parents=True, exist_ok=True)
                done_path = done_dir / script_path.name
                if done_path.exists():
                    done_path.unlink()
                script_path.rename(target=done_path)

        except Exception as ex:
            logger.exception(f"Error running script {script_path}")

            # If the script still exists, move it to error dir
            if script_path.exists():
                error_dir.mkdir(parents=True, exist_ok=True)
                error_path = error_dir / script_path.name
                if error_path.exists():
                    error_path.unlink()
                script_path.rename(target=error_path)
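The scriptrunner above is meant to be started from the command line; a minimal entry-point sketch, assuming this function lives in a file that is executed directly as a script (the example invocation in the comment is an assumption, not from the source).

if __name__ == '__main__':
    # e.g.: python scriptrunner.py --script_dir /path/to/scripts --watch
    main()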
Code example #5
File: load_images.py  Project: theroggy/orthoseg
def load_images(config_path: Path, load_testsample_images: bool = False):
    """
    Load and cache images for a segmentation project.
    
    Args:
        config_path (Path): Path to the project's config file.
        load_testsample_images (bool, optional): True to only load testsample 
            images. Defaults to False.
    """

    ##### Init #####
    # Load the config and save it in a bunch of global variables so it
    # is accessible everywhere
    conf.read_orthoseg_config(config_path)

    # Init logging
    log_util.clean_log_dir(
        log_dir=conf.dirs.getpath('log_dir'),
        nb_logfiles_tokeep=conf.logging.getint('nb_logfiles_tokeep'))
    global logger
    logger = log_util.main_log_init(conf.dirs.getpath('log_dir'), __name__)

    # Log + send email
    message = f"Start load_images for config {config_path.stem}"
    logger.info(message)
    logger.debug(f"Config used: \n{conf.pformat_config()}")
    email_helper.sendmail(message)

    try:
        # Use different settings depending on whether the testsample or all images are loaded
        if load_testsample_images:
            output_image_dir = conf.dirs.getpath(
                'predictsample_image_input_dir')

            # Use the same image size as for the training, that is the most
            # convenient to check the quality
            image_pixel_width = conf.train.getint('image_pixel_width')
            image_pixel_height = conf.train.getint('image_pixel_height')
            image_pixel_x_size = conf.train.getfloat('image_pixel_x_size')
            image_pixel_y_size = conf.train.getfloat('image_pixel_y_size')
            image_pixels_overlap = 0
            image_format = ows_util.FORMAT_JPEG

            # To create the testsample, only fetch one in every nb_images_to_skip images
            column_start = 1
            nb_images_to_skip = 50

        else:
            output_image_dir = conf.dirs.getpath('predict_image_input_dir')

            # Get the image size for the predict
            image_pixel_width = conf.predict.getint('image_pixel_width')
            image_pixel_height = conf.predict.getint('image_pixel_height')
            image_pixel_x_size = conf.predict.getfloat('image_pixel_x_size')
            image_pixel_y_size = conf.predict.getfloat('image_pixel_y_size')
            image_pixels_overlap = conf.predict.getint('image_pixels_overlap')
            image_format = ows_util.FORMAT_JPEG

            # For the real prediction dataset, no skipping obviously...
            column_start = 0
            nb_images_to_skip = 0

        # Get ssl_verify setting
        ssl_verify = conf.general['ssl_verify']
        # Get the download cron schedule
        download_cron_schedule = conf.download['cron_schedule']

        # Get the layer info
        predict_layer = conf.predict['image_layer']
        layersources = conf.image_layers[predict_layer]['layersources']
        nb_concurrent_calls = conf.image_layers[predict_layer][
            'nb_concurrent_calls']
        crs = pyproj.CRS.from_user_input(
            conf.image_layers[predict_layer]['projection'])
        bbox = conf.image_layers[predict_layer]['bbox']
        grid_xmin = conf.image_layers[predict_layer]['grid_xmin']
        grid_ymin = conf.image_layers[predict_layer]['grid_ymin']
        image_pixels_ignore_border = conf.image_layers[predict_layer][
            'image_pixels_ignore_border']
        roi_filepath = conf.image_layers[predict_layer]['roi_filepath']

        # Now we are ready to get the images...
        ows_util.get_images_for_grid(
            layersources=layersources,
            output_image_dir=output_image_dir,
            crs=crs,
            image_gen_bounds=bbox,
            image_gen_roi_filepath=roi_filepath,
            grid_xmin=grid_xmin,
            grid_ymin=grid_ymin,
            image_crs_pixel_x_size=image_pixel_x_size,
            image_crs_pixel_y_size=image_pixel_y_size,
            image_pixel_width=image_pixel_width,
            image_pixel_height=image_pixel_height,
            image_pixels_ignore_border=image_pixels_ignore_border,
            nb_concurrent_calls=nb_concurrent_calls,
            cron_schedule=download_cron_schedule,
            image_format=image_format,
            pixels_overlap=image_pixels_overlap,
            column_start=column_start,
            nb_images_to_skip=nb_images_to_skip,
            ssl_verify=ssl_verify)

        # Log and send mail
        message = f"Completed load_images for config {config_path.stem}"
        logger.info(message)
        email_helper.sendmail(message)
    except Exception as ex:
        message = f"ERROR while running load_images for task {config_path.stem}"
        logger.exception(message)
        email_helper.sendmail(
            subject=message,
            body=f"Exception: {ex}\n\n {traceback.format_exc()}")
        raise Exception(message) from ex
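Taken together, the examples suggest a typical end-to-end flow for one project; a sketch under that assumption, with a hypothetical config path and assuming all four functions are importable from one place.

from pathlib import Path

config_path = Path("projects/roofs/roofs.ini")  # hypothetical
load_images(config_path)   # download/cache the imagery to predict on
train(config_path)         # train a model, or reuse the best existing one
predict(config_path)       # predict the full image dataset to a .gpkg
postprocess(config_path)   # dissolve/simplify the resulting vectors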