示例#1
0
文件: plot_utils.py 项目: lmc00/TFG
def multi_training_plots(timestamps, legend_loc='upper right'):
    """
    Compare the loss and accuracy metrics of one or more timestamped trainings.

    For every timestamp, ``paths.timestamp`` is set and the training
    statistics (``stats.json`` in ``paths.get_stats_dir()``) and training
    configuration (``conf.json`` in ``paths.get_conf_dir()``) are loaded.
    The figure is a 2x2 grid:

    * top-left: training loss
    * top-right: training accuracy
    * bottom-left: validation loss
    * bottom-right: validation accuracy

    Each curve is labelled with its timestamp. Validation curves are only
    drawn for runs whose configuration has
    ``conf['training']['use_validation']`` enabled.

    Parameters
    ----------
    timestamps : str, or list of strs
        Timestamp(s) of the training run(s) to compare. A single string is
        treated as a one-element list.
    legend_loc: str
        Legend position, passed as ``loc`` to the legend of the
        training-loss panel (the only panel that gets a legend).

    Notes
    -----
    Sets ``paths.timestamp`` as a side effect. After the call it holds the
    last timestamp in ``timestamps``.
    """
    # Wrap a bare string; otherwise the loop below would iterate over its
    # individual characters. (`timestamps is str` never matched a value.)
    if isinstance(timestamps, str):
        timestamps = [timestamps]

    _, axs = plt.subplots(2, 2, figsize=(16, 16))
    axs = axs.flatten()

    for ts in timestamps:

        # Point the path helpers at this training run
        paths.timestamp = ts

        # Load training statistics
        stats_path = os.path.join(paths.get_stats_dir(), 'stats.json')
        with open(stats_path) as f:
            stats = json.load(f)

        # Load training configuration
        conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
        with open(conf_path) as f:
            conf = json.load(f)

        # Training
        axs[0].plot(stats['epoch'], stats['loss'], label=ts)
        axs[1].plot(stats['epoch'], stats['acc'], label=ts)

        # Validation curves are only plotted for runs trained with validation
        if conf['training']['use_validation']:
            axs[2].plot(stats['epoch'], stats['val_loss'], label=ts)
            axs[3].plot(stats['epoch'], stats['val_acc'], label=ts)

    # Accuracy panels use a fixed [0, 1] y-range
    axs[1].set_ylim([0, 1])
    axs[3].set_ylim([0, 1])

    # Label the x-axis of every panel (previously only axs[0] was labelled)
    for ax in axs:
        ax.set_xlabel('Epochs')

    axs[0].set_title('Training Loss')
    axs[1].set_title('Training Accuracy')
    axs[2].set_title('Validation Loss')
    axs[3].set_title('Validation Accuracy')

    axs[0].legend(loc=legend_loc)
示例#2
0
文件: model_utils.py 项目: lmc00/TFG
def save_conf(conf):
    """
    Save CONF to a json file (easy to parse) and to a txt file (easy to read).

    Both files are written to ``paths.get_conf_dir()`` as ``conf.json`` and
    ``conf.txt``. Existing files are overwritten.

    Parameters
    ----------
    conf : dict of dicts
        Two-level configuration of the form ``{group: {key: value}}``.
        Values must be JSON-serializable for ``conf.json``. In ``conf.txt``
        they are rendered with ``str``.
    """
    save_dir = paths.get_conf_dir()

    # Save dict as json file
    with open(os.path.join(save_dir, 'conf.json'), 'w') as outfile:
        json.dump(conf, outfile, sort_keys=True, indent=4)

    # Save dict as txt file for easier readability. Write one row per
    # (group, key), sorted, with a dashed separator after each group.
    # The context manager closes the file even if a write fails.
    with open(os.path.join(save_dir, 'conf.txt'), 'w') as txt_file:
        txt_file.write("{:<25}{:<30}{:<30} \n".format('group', 'key', 'value'))
        txt_file.write('=' * 75 + '\n')
        for key, val in sorted(conf.items()):
            for g_key, g_val in sorted(val.items()):
                txt_file.write("{:<25}{:<30}{:<15} \n".format(key, g_key, str(g_val)))
            txt_file.write('-' * 75 + '\n')
示例#3
0
# NOTE: there is no need to ask the user whether to use location
# (the old `use_lo = use_location()`). The model was already trained with
# that setting and it is stored in its configuration (read below).
TIMESTAMP = input("Indica el timestamp. Sin espacios. Mismo formato que en models: ").strip()  # timestamp of the model
MODEL_NAME = input("Indica el nombre del modelo que se encuentra en ckpts. Sin espacios: ").strip()  # model to use to make the prediction
# Stripping guards against stray whitespace in the pasted values, which would
# otherwise break the path lookups below.
TOP_K = 8  # number of top class predictions to save

# Set the timestamp so the `paths` helpers point at this training run
paths.timestamp = TIMESTAMP

# Load the data
print(paths.get_ts_splits_dir())
# NOTE: these are the classes the model had at training time (given by the
# timestamp), not the ones currently in data/dataset_files.
class_names = load_class_names(splits_dir=paths.get_ts_splits_dir())

# Load training configuration
conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
with open(conf_path) as f:
    conf = json.load(f)

use_lo = conf['training']['use_location']
if use_lo:
    print("Estás usando un modelo con localización, necesitas que los splits de validación, train etc que uses tengan dicha etiqueta") 

# Load the model.
# compile=False: the intermediate checkpoints seem to be saved in a format for
# which load_model does not recognise the custom optimizer (customAdam).
# Compilation is not needed for prediction and can be done later if required.
model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME), custom_objects=utils.get_custom_objects(), compile=False)