def save_conf(conf):
    """
    Save CONF to a txt file to ease the reading and to a json file to ease
    the parsing.

    Parameters
    ----------
    conf : 1-level nested dict
    """
    from imgclas import paths
    save_dir = paths.get_conf_dir()

    # Save dict as json file
    with open(os.path.join(save_dir, 'conf.json'), 'w') as outfile:
        json.dump(conf, outfile, sort_keys=True, indent=4)

    # Save dict as txt file for easier readability.
    # FIX: use a context manager so the file is closed even if a write raises
    # (the original left the handle open on error).
    with open(os.path.join(save_dir, 'conf.txt'), 'w') as txt_file:
        txt_file.write("{:<25}{:<30}{:<30} \n".format('group', 'key', 'value'))
        txt_file.write('=' * 75 + '\n')
        for key, val in sorted(conf.items()):
            for g_key, g_val in sorted(val.items()):
                txt_file.write("{:<25}{:<30}{:<15} \n".format(
                    key, g_key, str(g_val)))
            txt_file.write('-' * 75 + '\n')
def multi_training_plots(timestamps, legend_loc='upper right'):
    """
    Compare the loss and accuracy metrics for timestamped trainings.

    Parameters
    ----------
    timestamps : str, or list of strs
        Timestamp(s) of the training(s) to plot.
    legend_loc : str
        Legend position.
    """
    # FIX: `timestamps is str` compared the value against the type object and
    # was always False, so a bare string was iterated character by character.
    if isinstance(timestamps, str):
        timestamps = [timestamps]

    fig, axs = plt.subplots(2, 2, figsize=(16, 16))
    axs = axs.flatten()

    for ts in timestamps:
        # Set the timestamp
        paths.timestamp = ts

        # Load training statistics
        stats_path = os.path.join(paths.get_stats_dir(), 'stats.json')
        with open(stats_path) as f:
            stats = json.load(f)

        # Load training configuration
        conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
        with open(conf_path) as f:
            conf = json.load(f)

        # Training curves
        axs[0].plot(stats['epoch'], stats['loss'], label=ts)
        axs[1].plot(stats['epoch'], stats['acc'], label=ts)

        # Validation curves (only for runs trained with a validation split)
        if conf['training']['use_validation']:
            axs[2].plot(stats['epoch'], stats['val_loss'], label=ts)
            axs[3].plot(stats['epoch'], stats['val_acc'], label=ts)

    axs[1].set_ylim([0, 1])
    axs[3].set_ylim([0, 1])

    # FIX: the loop labelled axs[0] four times instead of each subplot.
    for i in range(4):
        axs[i].set_xlabel('Epochs')
    axs[0].set_title('Training Loss')
    axs[1].set_title('Training Accuracy')
    axs[2].set_title('Validation Loss')
    axs[3].set_title('Validation Accuracy')
    axs[0].legend(loc=legend_loc)
def load_inference_model(timestamp=None, ckpt_name=None):
    """
    Load a model for prediction.

    Parameters
    ----------
    * timestamp: str
        Name of the timestamp to use. The default is the last timestamp
        in `./models`.
    * ckpt_name: str
        Name of the checkpoint to use. The default is the last checkpoint
        in `./models/[timestamp]/ckpts`.
    """
    global loaded_ts, loaded_ckpt
    global graph, model, conf, class_names, class_info

    # ---- Resolve which timestamp to use ---------------------------------
    available_ts = sorted(next(os.walk(paths.get_models_dir()))[1])
    if not available_ts:
        raise Exception(
            "You have no models in your `./models` folder to be used for inference. "
            "Therefore the API can only be used for training.")
    if timestamp is None:
        timestamp = available_ts[-1]
    elif timestamp not in available_ts:
        raise ValueError(
            "Invalid timestamp name: {}. Available timestamp names are: {}".
            format(timestamp, available_ts))
    paths.timestamp = timestamp
    print('Using TIMESTAMP={}'.format(timestamp))

    # ---- Resolve which checkpoint to use --------------------------------
    checkpoints = sorted(
        name for name in os.listdir(paths.get_checkpoints_dir())
        if name.endswith('.h5'))
    if not checkpoints:
        raise Exception(
            "You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference. "
            .format(timestamp) +
            "Therefore the API can only be used for training.")
    if ckpt_name is None:
        ckpt_name = checkpoints[-1]
    elif ckpt_name not in checkpoints:
        raise ValueError(
            "Invalid checkpoint name: {}. Available checkpoint names are: {}".
            format(ckpt_name, checkpoints))
    print('Using CKPT_NAME={}'.format(ckpt_name))

    # Drop any model previously loaded in this session
    K.clear_session()

    # ---- Class names and (optional) per-class info ----------------------
    splits_dir = paths.get_ts_splits_dir()
    class_names = load_class_names(splits_dir=splits_dir)
    class_info = None
    if 'info.txt' in os.listdir(splits_dir):
        class_info = load_class_info(splits_dir=splits_dir)
        if len(class_info) != len(class_names):
            warnings.warn(
                """The 'classes.txt' file has a different length than the 'info.txt' file.
                If a class has no information whatsoever you should leave that classes row empty or put a '-' symbol.
                The API will run with no info until this is solved.""")
            class_info = None
    if class_info is None:
        # Fall back to empty info strings, one per class
        class_info = ['' for _ in range(len(class_names))]

    # ---- Training configuration -----------------------------------------
    conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
    with open(conf_path) as f:
        conf = json.load(f)
    update_with_saved_conf(conf)

    # ---- Load the model --------------------------------------------------
    ckpt_path = os.path.join(paths.get_checkpoints_dir(), ckpt_name)
    model = load_model(ckpt_path, custom_objects=utils.get_custom_objects())
    graph = tf.get_default_graph()

    # Record what was loaded
    loaded_ts = timestamp
    loaded_ckpt = ckpt_name
# User parameters to set TIMESTAMP = input("Indica el timestamp. Sin espacios. Mismo formato que en models: ") # timestamp of the model MODEL_NAME = 'final_model.h5' # model to use to make the prediction TOP_K = 2 # number of top classes predictions to save # Set the timestamp paths.timestamp = TIMESTAMP # Load the data print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++antes de class names") class_names = load_class_names(splits_dir=paths.get_ts_splits_dir()) # INCISO: Estas son las clases que había en el modelo # en el momento en el que estrenaste (dado por el timestamp). No las que tienes en data/dataset_files print("----------------------------------------------------------------------despues de class names") # Load training configuration conf_path = os.path.join(paths.get_conf_dir(), 'conf.json') with open(conf_path) as f: conf = json.load(f) # Load the model print("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------antes") model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME)) #model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME), custom_objects=utils.get_custom_objects()) print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++después") # INCISO: Ahora la parte que continúa está basada en el predicting a datasplit txt file que incluye Ignacio en el notebook # 3.0 . Esta preparación previa es necesaria para computar la matriz de confusión. # # OJO: ahora lo que le vas a dar para testear el modelo dado por el timestamp SÍ se encuentra en data/dataset_files # Y ES CON LO QUE TÚ QUIERES TESTEAR EL MODELO. SPLIT_NAME = input("Indica el nombre del split con el que evaluas. Es de data/dataset_files. Ejemplos: val train ...: ") # Load the data
def load_inference_model():
    """
    Load a model for prediction.

    If several timestamps are available in `./models` it will load `.models/api`
    or the last timestamp if `api` is not available.
    If several checkpoints are available in `./models/[timestamp]/ckpts` it will
    load `.models/[timestamp]/ckpts/final_model.h5` or the last checkpoint if
    `final_model.h5` is not available.
    """
    global loaded, graph, model, conf, class_names, class_info

    # Set the timestamp
    timestamps = next(os.walk(paths.get_models_dir()))[1]
    if not timestamps:
        raise BadRequest(
            """You have no models in your `./models` folder to be used for inference.
            Therefore the API can only be used for training.""")
    if 'api' in timestamps:
        TIMESTAMP = 'api'
    else:
        TIMESTAMP = sorted(timestamps)[-1]
    paths.timestamp = TIMESTAMP
    print('Using TIMESTAMP={}'.format(TIMESTAMP))

    # Set the checkpoint model to use to make the prediction.
    # FIX: the fallback filtered with endswith('*.h5'), which only matches
    # filenames literally ending in "*.h5", so sorted([])[-1] raised an
    # IndexError whenever 'final_model.h5' was absent. Filter on '.h5' and
    # raise the intended BadRequest when no usable checkpoint exists.
    ckpts = [name for name in os.listdir(paths.get_checkpoints_dir())
             if name.endswith('.h5')]
    if not ckpts:
        raise BadRequest(
            """You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference.
            Therefore the API can only be used for training.""".format(TIMESTAMP))
    if 'final_model.h5' in ckpts:
        MODEL_NAME = 'final_model.h5'
    else:
        MODEL_NAME = sorted(ckpts)[-1]
    print('Using MODEL_NAME={}'.format(MODEL_NAME))

    # Clear the previous loaded model
    K.clear_session()

    # Load the class names and info
    splits_dir = paths.get_ts_splits_dir()
    class_names = load_class_names(splits_dir=splits_dir)
    class_info = None
    if 'info.txt' in os.listdir(splits_dir):
        class_info = load_class_info(splits_dir=splits_dir)
        if len(class_info) != len(class_names):
            warnings.warn(
                """The 'classes.txt' file has a different length than the 'info.txt' file.
                If a class has no information whatsoever you should leave that classes row empty or put a '-' symbol.
                The API will run with no info until this is solved.""")
            class_info = None
    if class_info is None:
        class_info = ['' for _ in range(len(class_names))]

    # Load training configuration
    conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
    with open(conf_path) as f:
        conf = json.load(f)

    # Load the model
    model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME),
                       custom_objects=utils.get_custom_objects())
    graph = tf.get_default_graph()

    # Set the model as loaded
    loaded = True