def multi_training_plots(timestamps, legend_loc='upper right'):
    """
    Compare the loss and accuracy metrics of one or several timestamped trainings.

    For every timestamp, the training statistics (``stats.json``) and the
    training configuration (``conf.json``) are loaded and drawn on a 2x2 grid:
    training loss, training accuracy, validation loss and validation accuracy.
    Validation curves are only drawn for trainings whose configuration has
    ``conf['training']['use_validation']`` set.

    Parameters
    ----------
    timestamps : str, or list of strs
        Timestamp (or list of timestamps) identifying the trainings to compare.
    legend_loc: str
        Legend position (forwarded as ``loc`` to the legend of the
        training-loss subplot).

    Notes
    -----
    This sets ``paths.timestamp`` to each timestamp in turn, so after the call
    it holds the last timestamp processed.
    """
    # A single timestamp must be wrapped in a list; otherwise the loop below
    # would iterate over its characters. (`timestamps is str` compared the
    # value against the `str` type object and was never true.)
    if isinstance(timestamps, str):
        timestamps = [timestamps]

    _, axs = plt.subplots(2, 2, figsize=(16, 16))
    axs = axs.flatten()

    for ts in timestamps:

        # Set the timestamp so that the paths helpers point to this training
        paths.timestamp = ts

        # Load training statistics
        stats_path = os.path.join(paths.get_stats_dir(), 'stats.json')
        with open(stats_path) as f:
            stats = json.load(f)

        # Load training configuration
        conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
        with open(conf_path) as f:
            conf = json.load(f)

        # Training
        axs[0].plot(stats['epoch'], stats['loss'], label=ts)
        axs[1].plot(stats['epoch'], stats['acc'], label=ts)

        # Validation (only recorded when the training used a validation split)
        if conf['training']['use_validation']:
            axs[2].plot(stats['epoch'], stats['val_loss'], label=ts)
            axs[3].plot(stats['epoch'], stats['val_acc'], label=ts)

    # Accuracies live in [0, 1]
    axs[1].set_ylim([0, 1])
    axs[3].set_ylim([0, 1])

    # Label the x axis of every subplot (previously only axs[0] was labelled).
    for ax in axs:
        ax.set_xlabel('Epochs')

    axs[0].set_title('Training Loss')
    axs[1].set_title('Training Accuracy')
    axs[2].set_title('Validation Loss')
    axs[3].set_title('Validation Accuracy')

    axs[0].legend(loc=legend_loc)
def save_conf(conf):
    """
    Save CONF to a txt file to ease the reading and to a json file to ease the parsing.

    Both files are written to ``paths.get_conf_dir()`` as ``conf.json`` and
    ``conf.txt``.

    Parameters
    ----------
    conf : dict of dicts
        Two-level configuration of the form ``{group: {key: value}}``. Every
        top-level value must itself be a dict (its ``.items()`` are listed),
        keys at both levels must be sortable, and the whole structure must be
        JSON-serializable.
    """
    save_dir = paths.get_conf_dir()

    # Save dict as json file (sorted keys give a stable, diff-friendly output)
    with open(os.path.join(save_dir, 'conf.json'), 'w') as outfile:
        json.dump(conf, outfile, sort_keys=True, indent=4)

    # Save dict as txt file for easier readability. A context manager ensures
    # the file is closed even if writing or formatting a value raises.
    with open(os.path.join(save_dir, 'conf.txt'), 'w') as txt_file:
        txt_file.write("{:<25}{:<30}{:<30} \n".format('group', 'key', 'value'))
        txt_file.write('=' * 75 + '\n')
        for key, val in sorted(conf.items()):
            for g_key, g_val in sorted(val.items()):
                txt_file.write("{:<25}{:<30}{:<15} \n".format(key, g_key, str(g_val)))
            # Separator line after each group
            txt_file.write('-' * 75 + '\n')
# if location == "no":
#     return False

# NOTE: use_location() is not needed here: the model was already trained with
# (or without) location, and that choice is stored in its saved configuration.

# Strip surrounding whitespace so an accidental trailing space/newline in the
# typed or pasted value does not produce a non-existent path.
TIMESTAMP = input("Indica el timestamp. Sin espacios. Mismo formato que en models: ").strip()  # timestamp of the model
MODEL_NAME = input("Indica el nombre del modelo que se encuentra en ckpts. Sin espacios: ").strip()  # model (in ckpts) used to make the prediction
TOP_K = 8  # number of top class predictions to save

# Set the timestamp
paths.timestamp = TIMESTAMP

# Load the data
print(paths.get_ts_splits_dir())
# NOTE: these are the classes the model had at training time (given by the
# timestamp), not the ones currently in data/dataset_files.
class_names = load_class_names(splits_dir=paths.get_ts_splits_dir())

# Load training configuration
conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
with open(conf_path) as f:
    conf = json.load(f)

use_lo = conf['training']['use_location']
if use_lo:
    print("Estás usando un modelo con localización, necesitas que los splits de validación, train etc que uses tengan dicha etiqueta")

# Load the model.
# compile=False: apparently the intermediate checkpoints are saved in a
# different format, and load_model then fails to recognise the custom
# optimizer (customAdam). Compiling is not needed to predict (and can be done
# afterwards if required), so it is skipped for now.
model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME),
                   custom_objects=utils.get_custom_objects(), compile=False)