Example #1
def train(user_conf):
    """
    Parameters
    ----------
    user_conf : dict
        Dict whose values are JSON-encoded strings (created with json.dumps) holding the user's
        configuration parameters that will replace the defaults.
        Each value is decoded with json.loads().
        For example:
            user_conf = {'num_classes': 'null', 'lr_step_decay': '0.1', 'lr_step_schedule': '[0.7, 0.9]', 'use_early_stopping': 'false'}
    """
    CONF = config.CONF

    # Update the conf with the user input
    for group, val in sorted(CONF.items()):
        for g_key, g_val in sorted(val.items()):
            if g_key in user_conf:  # keep the default for parameters the user did not supply
                g_val['value'] = json.loads(user_conf[g_key])

    # Check the configuration
    try:
        config.check_conf(conf=CONF)
    except Exception as e:
        raise BadRequest(e)

    CONF = config.conf_dict(conf=CONF)
    timestamp = datetime.now().strftime('%Y-%m-%d_%H%M%S')

    config.print_conf_table(CONF)
    K.clear_session()  # remove the model loaded for prediction
    train_fn(TIMESTAMP=timestamp, CONF=CONF)

    # Sync with NextCloud folders (if NextCloud is available)
    try:
        mount_nextcloud(paths.get_models_dir(), 'ncplants:/models')
    except Exception as e:
        print(e)
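
A minimal usage sketch for the function above: since each value in user_conf must be a JSON-encoded string (the loop decodes every entry with json.loads), a caller would typically build the dict by applying json.dumps to plain-Python values. The keys below are taken from the docstring example; other configuration keys are assumed to follow the same pattern.

import json

# Plain-Python overrides; keys taken from the docstring example above
overrides = {'num_classes': None,
             'lr_step_decay': 0.1,
             'lr_step_schedule': [0.7, 0.9],
             'use_early_stopping': False}

# JSON-encode every value, as the docstring requires
user_conf = {k: json.dumps(v) for k, v in overrides.items()}
# -> {'num_classes': 'null', 'lr_step_decay': '0.1', 'lr_step_schedule': '[0.7, 0.9]', 'use_early_stopping': 'false'}

train(user_conf)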
Example #2
def train(**args):
    """
    Train an image classifier
    """
    update_with_query_conf(user_args=args)
    CONF = config.conf_dict
    timestamp = datetime.now().strftime('%Y-%m-%d_%H%M%S')
    config.print_conf_table(CONF)
    K.clear_session()  # remove the model loaded for prediction
    train_fn(TIMESTAMP=timestamp, CONF=CONF)

    # Sync with NextCloud folders (if NextCloud is available)
    try:
        mount_nextcloud(paths.get_models_dir(), 'ncplants:/models')
    except Exception as e:
        print(e)
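
Example #2 is a leaner variant of the same entry point: the overrides arrive as keyword arguments and the project-specific helper update_with_query_conf merges them into config.conf_dict. A hypothetical call sketch; the accepted keyword names and the JSON-string value encoding are assumptions carried over from Example #1's docstring, not confirmed here.

import json

# Hypothetical keyword names, reused from Example #1's docstring example
train(lr_step_decay=json.dumps(0.1),
      use_early_stopping=json.dumps(False))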
Example #3
def load_inference_model():
    """
    Load a model for prediction.

    If several timestamps are available in `./models` it will load `./models/api`, or the last timestamp if `api` is
    not available.
    If several checkpoints are available in `./models/[timestamp]/ckpts` it will load
    `./models/[timestamp]/ckpts/model.pb`, or the last `.pb` checkpoint if `model.pb` is not available.
    """
    global loaded, conf, MODEL_NAME, LABELS_FILE

    # Set the timestamp
    timestamps = next(os.walk(paths.get_models_dir()))[1]
    if not timestamps:
        raise Exception(
            "You have no models in your `./models` folder to be used for inference. "
            "This module does not come with a pretrained model so you have to train a model to use it for prediction."
        )
    else:
        if 'api' in timestamps:
            TIMESTAMP = 'api'
        else:
            TIMESTAMP = sorted(timestamps)[-1]
        paths.timestamp = TIMESTAMP
        print('Using TIMESTAMP={}'.format(TIMESTAMP))

        # Set the checkpoint model to use to make the prediction
        ckpts = os.listdir(paths.get_checkpoints_dir())
        if not ckpts:
            raise Exception(
                "You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference. "
                .format(TIMESTAMP) +
                "Therefore the API can only be used for training.")
        else:
            if 'model.pb' in ckpts:
                MODEL_NAME = 'model.pb'
            else:
                # str.endswith takes a plain suffix, not a glob pattern
                MODEL_NAME = sorted(
                    [name for name in ckpts if name.endswith('.pb')])[-1]
            print('Using MODEL_NAME={}'.format(MODEL_NAME))

            if 'conv_labels.txt' in ckpts:
                LABELS_FILE = 'conv_labels.txt'
            else:
                LABELS_FILE = sorted(
                    [name for name in ckpts if name.endswith('.txt')])[-1]
            print('Using LABELS_FILE={}'.format(LABELS_FILE))

            # Clear the previous loaded model
            K.clear_session()

            # Build absolute paths to the checkpoint and the labels file
            ckpts_dir = paths.get_checkpoints_dir()
            MODEL_NAME = os.path.join(ckpts_dir, MODEL_NAME)
            LABELS_FILE = os.path.join(ckpts_dir, LABELS_FILE)

            # Load training configuration
            conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
            with open(conf_path) as f:
                conf = json.load(f)

    # Set the model as loaded
    loaded = True
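
After load_inference_model() returns, the module-level MODEL_NAME and LABELS_FILE globals point at the chosen checkpoint and labels file. A minimal sketch of consuming them for prediction, assuming (not confirmed by the example) that the checkpoint is a TensorFlow 1.x frozen graph and that conv_labels.txt stores one class name per line; it would run inside the same module, where those globals live.

import tensorflow as tf

load_inference_model()

# Class names, assumed to be stored one label per line
with open(LABELS_FILE) as f:
    labels = [line.strip() for line in f if line.strip()]

# Import the frozen graph into the default TF graph (TensorFlow 1.x API)
with tf.gfile.GFile(MODEL_NAME, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')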