Ejemplo n.º 1
0
def get_predict_args():
    """
    Build the arguments accepted by the prediction endpoint.

    Only the 'testing' group of `config.CONF` is exposed. Its 'timestamp'
    option is updated in place (the entry is the same object held by
    `config.CONF`) so that it defaults to the last model folder, in sorted
    order, found in the models directory.

    Returns
    -------
    The result of `populate_parser` applied to the 'files'/'urls' fields and
    the 'testing' configuration group.
    """
    testing_conf = OrderedDict([('testing', config.CONF['testing'])])

    # Offer the model folders found on disk as choices for the timestamp
    ts_option = testing_conf['testing']['timestamp']
    available_ts = sorted(next(os.walk(paths.get_models_dir()))[1])
    if available_ts:
        ts_option['value'] = available_ts[-1]
        ts_option['choices'] = available_ts
    else:
        # No trained models on disk: leave the default empty
        ts_option['value'] = ''

    parser = OrderedDict()

    # Image uploaded as a file through the form
    parser['files'] = fields.Field(
        required=False,
        missing=None,
        type="file",
        data_key="data",
        location="form",
        description="Select the image you want to classify.",
    )

    # A plain String (not a Url field) so that base 64 encoded data strings
    # can also be sent through this argument
    parser['urls'] = fields.String(
        required=False,
        missing=None,
        description="Select an URL of the image you want to classify.",
    )

    # NOTE: action="append" is not set here, so more than one url cannot be
    # appended

    return populate_parser(parser, testing_conf)
Ejemplo n.º 2
0
def train(user_conf):
    """
    Train a model with the default configuration overridden by the user.

    Parameters
    ----------
    user_conf : dict
        Json dict (created with json.dumps) with the user's configuration parameters that will replace the defaults.
        Every value is decoded with json.loads().
        For example:
            user_conf={'num_classes': 'null', 'lr_step_decay': '0.1', 'lr_step_schedule': '[0.7, 0.9]', 'use_early_stopping': 'false'}

    Raises
    ------
    BadRequest
        If `config.check_conf` rejects the resulting configuration.
    """
    raw_conf = config.CONF

    # Overwrite the value of every option, group by group, with the user input
    for group_name in sorted(raw_conf):
        options = raw_conf[group_name]
        for option_name in sorted(options):
            options[option_name]['value'] = json.loads(user_conf[option_name])

    # Validate before launching anything; report failures as a bad request
    try:
        config.check_conf(conf=raw_conf)
    except Exception as e:
        raise BadRequest(e)

    run_conf = config.conf_dict(conf=raw_conf)
    run_timestamp = datetime.now().strftime('%Y-%m-%d_%H%M%S')

    config.print_conf_table(run_conf)
    K.clear_session()  # drop the model that was loaded for prediction
    train_fn(TIMESTAMP=run_timestamp, CONF=run_conf)

    # Sync with NextCloud folders (if NextCloud is available); a failure here
    # is only printed
    try:
        mount_nextcloud(paths.get_models_dir(), 'ncplants:/models')
    except Exception as e:
        print(e)
Ejemplo n.º 3
0
Archivo: api.py Proyecto: lmc00/TFG
def train(**args):
    """
    Train an image classifier.

    The keyword arguments are passed to `update_with_query_conf` before the
    run configuration is read from `config.conf_dict`. The run is labelled
    with the current date and time.
    """
    update_with_query_conf(user_args=args)

    run_conf = config.conf_dict
    run_timestamp = datetime.now().strftime('%Y-%m-%d_%H%M%S')
    config.print_conf_table(run_conf)

    # Drop the model that was loaded for prediction before training
    K.clear_session()
    train_fn(TIMESTAMP=run_timestamp, CONF=run_conf)

    # Sync with NextCloud folders (if NextCloud is available); a failure here
    # is only printed
    try:
        mount_nextcloud(paths.get_models_dir(), 'rshare:/models')
    except Exception as sync_error:
        print(sync_error)
Ejemplo n.º 4
0
def load_inference_model(timestamp=None, ckpt_name=None):
    """
    Load a model for prediction and store it in the module-level globals.

    Parameters
    ----------
    * timestamp: str
        Name of the timestamp to use. The default is the last timestamp in `./models`.
    * ckpt_name: str
        Name of the checkpoint to use. The default is the last checkpoint in `./models/[timestamp]/ckpts`.

    Raises
    ------
    Exception
        If there are no timestamp folders, or no `.h5` checkpoints for the
        selected timestamp.
    ValueError
        If the requested timestamp or checkpoint is not among the available ones.
    """
    global loaded_ts, loaded_ckpt
    global graph, model, conf, class_names, class_info

    # Resolve the timestamp (one sub-folder of the models dir per timestamp)
    available_ts = sorted(next(os.walk(paths.get_models_dir()))[1])
    if not available_ts:
        raise Exception(
            "You have no models in your `./models` folder to be used for inference. "
            "Therefore the API can only be used for training.")
    if timestamp is None:
        timestamp = available_ts[-1]
    elif timestamp not in available_ts:
        raise ValueError(
            "Invalid timestamp name: {}. Available timestamp names are: {}".
            format(timestamp, available_ts))
    # Must be set before the checkpoint folder is listed below
    paths.timestamp = timestamp
    print('Using TIMESTAMP={}'.format(timestamp))

    # Resolve the checkpoint, considering only the `.h5` files
    available_ckpts = sorted(
        fname for fname in os.listdir(paths.get_checkpoints_dir())
        if fname.endswith('.h5'))
    if not available_ckpts:
        raise Exception(
            "You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference. "
            .format(timestamp) +
            "Therefore the API can only be used for training.")
    if ckpt_name is None:
        ckpt_name = available_ckpts[-1]
    elif ckpt_name not in available_ckpts:
        raise ValueError(
            "Invalid checkpoint name: {}. Available checkpoint names are: {}".
            format(ckpt_name, available_ckpts))
    print('Using CKPT_NAME={}'.format(ckpt_name))

    # Get rid of whatever model was loaded before
    K.clear_session()

    # Class names are always loaded; class info only if 'info.txt' exists and
    # has the same number of entries as the class names
    splits_dir = paths.get_ts_splits_dir()
    class_names = load_class_names(splits_dir=splits_dir)
    class_info = None
    if 'info.txt' in os.listdir(splits_dir):
        class_info = load_class_info(splits_dir=splits_dir)
        if len(class_names) != len(class_info):
            warnings.warn(
                """The 'classes.txt' file has a different length than the 'info.txt' file.
            If a class has no information whatsoever you should leave that classes row empty or put a '-' symbol.
            The API will run with no info until this is solved.""")
            class_info = None
    if class_info is None:
        # One empty info entry per class
        class_info = [''] * len(class_names)

    # Read the configuration saved with the training run
    conf_file = os.path.join(paths.get_conf_dir(), 'conf.json')
    with open(conf_file) as f:
        conf = json.load(f)
        update_with_saved_conf(conf)

    # Load the network itself and keep a handle on the default graph
    ckpt_path = os.path.join(paths.get_checkpoints_dir(), ckpt_name)
    model = load_model(ckpt_path, custom_objects=utils.get_custom_objects())
    graph = tf.get_default_graph()

    # Record what is currently loaded
    loaded_ts = timestamp
    loaded_ckpt = ckpt_name
Ejemplo n.º 5
0
def load_inference_model():
    """
    Load a model for prediction and store it in the module-level globals.

    If several timestamps are available in `./models` it will load `.models/api` or the last timestamp if `api` is not
    available.
    If several checkpoints are available in `./models/[timestamp]/ckpts` it will load
    `.models/[timestamp]/ckpts/final_model.h5` or the last `.h5` checkpoint if `final_model.h5` is not available.

    Raises
    ------
    BadRequest
        If there are no timestamp folders, or the selected timestamp has no
        `.h5` checkpoint to load.
    """
    global loaded, graph, model, conf, class_names, class_info

    # Set the timestamp
    timestamps = next(os.walk(paths.get_models_dir()))[1]
    if not timestamps:
        raise BadRequest(
            """You have no models in your `./models` folder to be used for inference.
            Therefore the API can only be used for training.""")

    timestamp = 'api' if 'api' in timestamps else sorted(timestamps)[-1]
    # Must be set before the checkpoint folder is listed below
    paths.timestamp = timestamp
    print('Using TIMESTAMP={}'.format(timestamp))

    # Set the checkpoint model to use to make the prediction
    ckpts = os.listdir(paths.get_checkpoints_dir())
    if 'final_model.h5' in ckpts:
        model_name = 'final_model.h5'
    else:
        # `str.endswith` matches a literal suffix, not a glob pattern, so the
        # suffix must be '.h5' (not '*.h5') for any checkpoint to be found.
        h5_ckpts = sorted(name for name in ckpts if name.endswith('.h5'))
        if not h5_ckpts:
            # Covers both an empty folder and a folder without `.h5` files
            raise BadRequest(
                """You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference.
                Therefore the API can only be used for training.""".format(
                    timestamp))
        model_name = h5_ckpts[-1]
    print('Using MODEL_NAME={}'.format(model_name))

    # Clear the previous loaded model
    K.clear_session()

    # Class names are always loaded; class info only if 'info.txt' exists and
    # has the same number of entries as the class names
    splits_dir = paths.get_ts_splits_dir()
    class_names = load_class_names(splits_dir=splits_dir)
    class_info = None
    if 'info.txt' in os.listdir(splits_dir):
        class_info = load_class_info(splits_dir=splits_dir)
        if len(class_info) != len(class_names):
            warnings.warn(
                        """The 'classes.txt' file has a different length than the 'info.txt' file.
                    If a class has no information whatsoever you should leave that classes row empty or put a '-' symbol.
                    The API will run with no info until this is solved.""")
            class_info = None
    if class_info is None:
        # One empty info entry per class
        class_info = ['' for _ in range(len(class_names))]

    # Load training configuration
    conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
    with open(conf_path) as f:
        conf = json.load(f)

    # Load the model and keep a handle on the default graph
    model = load_model(os.path.join(paths.get_checkpoints_dir(), model_name),
                       custom_objects=utils.get_custom_objects())
    graph = tf.get_default_graph()

    # Set the model as loaded
    loaded = True