def save_conf(conf):
    """
    Save CONF to a txt file to ease the reading and to a json file to ease the parsing.

    Parameters
    ----------
    conf : 1-level nested dict
        Configuration options grouped as {group: {key: value}}.
    """
    from imgclas import paths
    save_dir = paths.get_conf_dir()

    # Save dict as json file
    with open(os.path.join(save_dir, 'conf.json'), 'w') as outfile:
        json.dump(conf, outfile, sort_keys=True, indent=4)

    # Save dict as txt file for easier readability.
    # Use a context manager so the file is closed even if formatting fails
    # (the original left the handle open on exception).
    with open(os.path.join(save_dir, 'conf.txt'), 'w') as txt_file:
        txt_file.write("{:<25}{:<30}{:<30} \n".format('group', 'key', 'value'))
        txt_file.write('=' * 75 + '\n')
        for key, val in sorted(conf.items()):
            for g_key, g_val in sorted(val.items()):
                # Column width fixed to 30 to line up with the header row
                # (original used 15 for the value column).
                txt_file.write("{:<25}{:<30}{:<30} \n".format(
                    key, g_key, str(g_val)))
            txt_file.write('-' * 75 + '\n')
# --- Esempio n. 2 --- (scrape artifact: example separator and vote count)
def multi_training_plots(timestamps, legend_loc='upper right'):
    """
    Compare the loss and accuracy metrics for timestamped trainings.

    Parameters
    ----------
    timestamps : str, or list of strs
        Timestamp(s) of the training(s) to plot.
    legend_loc: str
        Legend position
    """
    # Bug fix: `timestamps is str` compared the object against the `str` type
    # itself and was always False, so a bare string was iterated char by char.
    if isinstance(timestamps, str):
        timestamps = [timestamps]

    fig, axs = plt.subplots(2, 2, figsize=(16, 16))
    axs = axs.flatten()

    for ts in timestamps:

        # Set the timestamp
        paths.timestamp = ts

        # Load training statistics
        stats_path = os.path.join(paths.get_stats_dir(), 'stats.json')
        with open(stats_path) as f:
            stats = json.load(f)

        # Load training configuration
        conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
        with open(conf_path) as f:
            conf = json.load(f)

        # Training
        axs[0].plot(stats['epoch'], stats['loss'], label=ts)
        axs[1].plot(stats['epoch'], stats['acc'], label=ts)

        # Validation (only if this training actually used a validation split)
        if conf['training']['use_validation']:
            axs[2].plot(stats['epoch'], stats['val_loss'], label=ts)
            axs[3].plot(stats['epoch'], stats['val_acc'], label=ts)

    # Accuracies live in [0, 1]
    axs[1].set_ylim([0, 1])
    axs[3].set_ylim([0, 1])

    # Bug fix: original set the xlabel on axs[0] four times; label each subplot.
    for i in range(4):
        axs[i].set_xlabel('Epochs')

    axs[0].set_title('Training Loss')
    axs[1].set_title('Training Accuracy')
    axs[2].set_title('Validation Loss')
    axs[3].set_title('Validation Accuracy')

    axs[0].legend(loc=legend_loc)
# --- Esempio n. 3 --- (scrape artifact: example separator and vote count)
def load_inference_model(timestamp=None, ckpt_name=None):
    """
    Load a model for prediction.

    Parameters
    ----------
    * timestamp: str
        Name of the timestamp to use. The default is the last timestamp in `./models`.
    * ckpt_name: str
        Name of the checkpoint to use. The default is the last checkpoint in `./models/[timestamp]/ckpts`.
    """
    global loaded_ts, loaded_ckpt
    global graph, model, conf, class_names, class_info

    # Resolve which timestamp to use (default: the most recent one)
    available_ts = sorted(next(os.walk(paths.get_models_dir()))[1])
    if not available_ts:
        raise Exception(
            "You have no models in your `./models` folder to be used for inference. "
            "Therefore the API can only be used for training.")
    if timestamp is None:
        timestamp = available_ts[-1]
    elif timestamp not in available_ts:
        raise ValueError(
            "Invalid timestamp name: {}. Available timestamp names are: {}".
            format(timestamp, available_ts))
    paths.timestamp = timestamp
    print('Using TIMESTAMP={}'.format(timestamp))

    # Resolve which checkpoint file to load (default: the last .h5 in ckpts)
    available_ckpts = sorted(
        name for name in os.listdir(paths.get_checkpoints_dir())
        if name.endswith('.h5'))
    if not available_ckpts:
        raise Exception(
            "You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference. "
            .format(timestamp) +
            "Therefore the API can only be used for training.")
    if ckpt_name is None:
        ckpt_name = available_ckpts[-1]
    elif ckpt_name not in available_ckpts:
        raise ValueError(
            "Invalid checkpoint name: {}. Available checkpoint names are: {}".
            format(ckpt_name, available_ckpts))
    print('Using CKPT_NAME={}'.format(ckpt_name))

    # Drop whatever model was loaded before
    K.clear_session()

    # Load the class names and, when available, the per-class info
    splits_dir = paths.get_ts_splits_dir()
    class_names = load_class_names(splits_dir=splits_dir)
    class_info = None
    if 'info.txt' in os.listdir(splits_dir):
        class_info = load_class_info(splits_dir=splits_dir)
        if len(class_info) != len(class_names):
            warnings.warn(
                """The 'classes.txt' file has a different length than the 'info.txt' file.
            If a class has no information whatsoever you should leave that classes row empty or put a '-' symbol.
            The API will run with no info until this is solved.""")
            class_info = None
    if class_info is None:
        class_info = ['' for _ in range(len(class_names))]

    # Load the saved training configuration and apply it
    conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
    with open(conf_path) as f:
        conf = json.load(f)
        update_with_saved_conf(conf)

    # Load the Keras model and keep a handle to the TF graph
    model = load_model(os.path.join(paths.get_checkpoints_dir(), ckpt_name),
                       custom_objects=utils.get_custom_objects())
    graph = tf.get_default_graph()

    # Record which timestamp/checkpoint is currently loaded
    loaded_ts = timestamp
    loaded_ckpt = ckpt_name
# User parameters to set
TIMESTAMP = input("Indica el timestamp. Sin espacios. Mismo formato que en models: ")                       # timestamp of the model
MODEL_NAME = 'final_model.h5'                           # model to use to make the prediction
TOP_K = 2                                               # number of top classes predictions to save

# Set the timestamp
paths.timestamp = TIMESTAMP

# Load the data
print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++antes de class names")
class_names = load_class_names(splits_dir=paths.get_ts_splits_dir()) # NOTE: these are the classes the model knew
# at the time it was trained (given by the timestamp), NOT the ones currently in data/dataset_files
print("----------------------------------------------------------------------despues de class names")
# Load training configuration
conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
with open(conf_path) as f:
    conf = json.load(f)
    
# Load the model
print("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------antes")
model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME))
#model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME), custom_objects=utils.get_custom_objects())
print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++después")
# NOTE: what follows is based on the "predicting a datasplit txt file" section that Ignacio includes in notebook
# 3.0 . This preliminary preparation is needed to compute the confusion matrix.
#
# CAUTION: the split you now feed in to test the model given by the timestamp IS located in data/dataset_files
# AND IT IS WHAT YOU WANT TO TEST THE MODEL WITH.
SPLIT_NAME = input("Indica el nombre del split con el que evaluas. Es de data/dataset_files. Ejemplos: val train ...: ")
# Load the data
# --- Esempio n. 5 --- (scrape artifact: example separator and vote count)
def load_inference_model():
    """
    Load a model for prediction.

    If several timestamps are available in `./models` it will load `.models/api` or the last timestamp if `api` is not
    available.
    If several checkpoints are available in `./models/[timestamp]/ckpts` it will load
    `.models/[timestamp]/ckpts/final_model.h5` or the last checkpoint if `final_model.h5` is not available.

    Raises
    ------
    BadRequest
        If there are no timestamps in `./models` or no checkpoints for the chosen timestamp.
    """
    global loaded, graph, model, conf, class_names, class_info

    # Set the timestamp (guard clauses instead of the original deep nesting)
    timestamps = next(os.walk(paths.get_models_dir()))[1]
    if not timestamps:
        raise BadRequest(
            """You have no models in your `./models` folder to be used for inference.
            Therefore the API can only be used for training.""")

    # Prefer the dedicated 'api' timestamp when present, else the latest one
    TIMESTAMP = 'api' if 'api' in timestamps else sorted(timestamps)[-1]
    paths.timestamp = TIMESTAMP
    print('Using TIMESTAMP={}'.format(TIMESTAMP))

    # Set the checkpoint model to use to make the prediction
    ckpts = os.listdir(paths.get_checkpoints_dir())
    if not ckpts:
        raise BadRequest(
            """You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference.
                Therefore the API can only be used for training.""".format(
                TIMESTAMP))

    if 'final_model.h5' in ckpts:
        MODEL_NAME = 'final_model.h5'
    else:
        # Bug fix: the original filtered with endswith('*.h5'), where '*' is a
        # literal character (not a wildcard), so no file ever matched and the
        # [-1] index crashed. Filter on the '.h5' suffix instead.
        MODEL_NAME = sorted(
            [name for name in ckpts if name.endswith('.h5')])[-1]
    print('Using MODEL_NAME={}'.format(MODEL_NAME))

    # Clear the previous loaded model
    K.clear_session()

    # Load the class names and info
    splits_dir = paths.get_ts_splits_dir()
    class_names = load_class_names(splits_dir=splits_dir)
    class_info = None
    if 'info.txt' in os.listdir(splits_dir):
        class_info = load_class_info(splits_dir=splits_dir)
        if len(class_info) != len(class_names):
            warnings.warn(
                """The 'classes.txt' file has a different length than the 'info.txt' file.
                    If a class has no information whatsoever you should leave that classes row empty or put a '-' symbol.
                    The API will run with no info until this is solved.""")
            class_info = None
    if class_info is None:
        class_info = ['' for _ in range(len(class_names))]

    # Load training configuration
    conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
    with open(conf_path) as f:
        conf = json.load(f)

    # Load the model
    model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME),
                       custom_objects=utils.get_custom_objects())
    graph = tf.get_default_graph()

    # Set the model as loaded
    loaded = True