def get_predict_args():
    """
    Build the argument parser for the predict endpoint.

    Only the 'testing' group of the default configuration is exposed,
    together with a file-upload field and a URL field for the image
    to classify.
    """
    default_conf = OrderedDict([('testing', config.CONF['testing'])])

    # Offer every trained model timestamp as a choice, defaulting to the
    # most recent one (empty value when no model has been trained yet).
    ts_option = default_conf['testing']['timestamp']
    available_ts = sorted(next(os.walk(paths.get_models_dir()))[1])
    if available_ts:
        ts_option['value'] = available_ts[-1]
        ts_option['choices'] = available_ts
    else:
        ts_option['value'] = ''

    parser = OrderedDict()

    # Image uploaded as raw form data
    parser['files'] = fields.Field(
        required=False,
        missing=None,
        type="file",
        data_key="data",
        location="form",
        description="Select the image you want to classify.")

    # Use field.String instead of field.Url because I also want to allow
    # uploading of base 64 encoded data strings
    parser['urls'] = fields.String(
        required=False,
        missing=None,
        description="Select an URL of the image you want to classify.")
    # missing action="append" --> append more than one url

    return populate_parser(parser, default_conf)
def train(user_conf):
    """
    Train an image classifier with the user-supplied configuration.

    Parameters
    ----------
    user_conf : dict
        Json dict (created with json.dumps) with the user's configuration
        parameters that will replace the defaults. Each value must itself be
        a JSON-encoded string (loaded here with json.loads).
        For example:
            user_conf = {'num_classes': 'null', 'lr_step_decay': '0.1',
                         'lr_step_schedule': '[0.7, 0.9]',
                         'use_early_stopping': 'false'}

    Raises
    ------
    BadRequest
        If a configuration parameter is missing, is not valid JSON, or
        fails the consistency checks in ``config.check_conf``.
    """
    CONF = config.CONF

    # Update the conf with the user input.
    # FIX: a missing key or a malformed JSON value used to escape as a raw
    # KeyError/ValueError (server error); report it as a client error instead,
    # consistent with the check_conf handling below.
    try:
        for group, val in sorted(CONF.items()):
            for g_key, g_val in sorted(val.items()):
                g_val['value'] = json.loads(user_conf[g_key])
    except KeyError as e:
        raise BadRequest('Missing configuration parameter: {}'.format(e))
    except ValueError as e:
        # json.JSONDecodeError is a subclass of ValueError
        raise BadRequest('Invalid JSON value in configuration: {}'.format(e))

    # Check the configuration
    try:
        config.check_conf(conf=CONF)
    except Exception as e:
        raise BadRequest(e)

    CONF = config.conf_dict(conf=CONF)
    timestamp = datetime.now().strftime('%Y-%m-%d_%H%M%S')
    config.print_conf_table(CONF)
    K.clear_session()  # remove the model loaded for prediction
    train_fn(TIMESTAMP=timestamp, CONF=CONF)

    # Sync with NextCloud folders (if NextCloud is available)
    try:
        mount_nextcloud(paths.get_models_dir(), 'ncplants:/models')
    except Exception as e:
        print(e)
def train(**args):
    """
    Train an image classifier
    """
    update_with_query_conf(user_args=args)
    CONF = config.conf_dict
    ts = datetime.now().strftime('%Y-%m-%d_%H%M%S')
    config.print_conf_table(CONF)

    # Drop any model currently loaded for prediction before training starts
    K.clear_session()
    train_fn(TIMESTAMP=ts, CONF=CONF)

    # Best-effort sync with NextCloud folders (if NextCloud is available)
    try:
        mount_nextcloud(paths.get_models_dir(), 'rshare:/models')
    except Exception as exc:
        print(exc)
def load_inference_model(timestamp=None, ckpt_name=None):
    """
    Load a model for prediction.

    Parameters
    ----------
    * timestamp: str
        Name of the timestamp to use. The default is the last timestamp in `./models`.
    * ckpt_name: str
        Name of the checkpoint to use. The default is the last checkpoint in `./models/[timestamp]/ckpts`.
    """
    global loaded_ts, loaded_ckpt
    global graph, model, conf, class_names, class_info

    # Resolve the timestamp to use (latest one by default)
    available_ts = sorted(next(os.walk(paths.get_models_dir()))[1])
    if not available_ts:
        raise Exception(
            "You have no models in your `./models` folder to be used for inference. "
            "Therefore the API can only be used for training.")
    if timestamp is None:
        timestamp = available_ts[-1]
    elif timestamp not in available_ts:
        raise ValueError(
            "Invalid timestamp name: {}. Available timestamp names are: {}".
            format(timestamp, available_ts))
    paths.timestamp = timestamp
    print('Using TIMESTAMP={}'.format(timestamp))

    # Resolve the checkpoint to use (latest .h5 file by default)
    available_ckpts = sorted(
        name for name in os.listdir(paths.get_checkpoints_dir())
        if name.endswith('.h5'))
    if not available_ckpts:
        raise Exception(
            "You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference. "
            .format(timestamp) +
            "Therefore the API can only be used for training.")
    if ckpt_name is None:
        ckpt_name = available_ckpts[-1]
    elif ckpt_name not in available_ckpts:
        raise ValueError(
            "Invalid checkpoint name: {}. Available checkpoint names are: {}".
            format(ckpt_name, available_ckpts))
    print('Using CKPT_NAME={}'.format(ckpt_name))

    # Clear the previous loaded model
    K.clear_session()

    # Load the class names and info
    splits_dir = paths.get_ts_splits_dir()
    class_names = load_class_names(splits_dir=splits_dir)
    class_info = None
    if 'info.txt' in os.listdir(splits_dir):
        class_info = load_class_info(splits_dir=splits_dir)
        if len(class_info) != len(class_names):
            warnings.warn(
                """The 'classes.txt' file has a different length than the 'info.txt' file.
                If a class has no information whatsoever you should leave that classes row empty or put a '-' symbol.
                The API will run with no info until this is solved.""")
            class_info = None
    if class_info is None:
        class_info = ['' for _ in range(len(class_names))]

    # Load training configuration
    conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
    with open(conf_path) as f:
        conf = json.load(f)
        update_with_saved_conf(conf)

    # Load the model
    model = load_model(os.path.join(paths.get_checkpoints_dir(), ckpt_name),
                       custom_objects=utils.get_custom_objects())
    graph = tf.get_default_graph()

    # Set the model as loaded
    loaded_ts = timestamp
    loaded_ckpt = ckpt_name
def load_inference_model():
    """
    Load a model for prediction.

    If several timestamps are available in `./models` it will load
    `.models/api` or the last timestamp if `api` is not available.
    If several checkpoints are available in `./models/[timestamp]/ckpts` it
    will load `.models/[timestamp]/ckpts/final_model.h5` or the last
    checkpoint if `final_model.h5` is not available.

    Raises
    ------
    BadRequest
        If there are no models, or no usable `.h5` checkpoints, to load.
    """
    global loaded, graph, model, conf, class_names, class_info

    # Set the timestamp
    timestamps = next(os.walk(paths.get_models_dir()))[1]
    if not timestamps:
        raise BadRequest(
            """You have no models in your `./models` folder to be used for inference.
            Therefore the API can only be used for training.""")
    if 'api' in timestamps:
        TIMESTAMP = 'api'
    else:
        TIMESTAMP = sorted(timestamps)[-1]
    paths.timestamp = TIMESTAMP
    print('Using TIMESTAMP={}'.format(TIMESTAMP))

    # Set the checkpoint model to use to make the prediction.
    # FIX: the original filtered with `name.endswith('*.h5')`; `str.endswith`
    # does not expand globs, so no normal checkpoint matched and
    # `sorted([])[-1]` raised IndexError whenever `final_model.h5` was absent.
    # Filter on the real '.h5' suffix and fail with BadRequest when nothing
    # loadable is found.
    ckpts = os.listdir(paths.get_checkpoints_dir())
    if 'final_model.h5' in ckpts:
        MODEL_NAME = 'final_model.h5'
    else:
        h5_ckpts = sorted(name for name in ckpts if name.endswith('.h5'))
        if not h5_ckpts:
            raise BadRequest(
                """You have no checkpoints in your `./models/{}/ckpts` folder to be used for inference.
            Therefore the API can only be used for training.""".format(
                    TIMESTAMP))
        MODEL_NAME = h5_ckpts[-1]
    print('Using MODEL_NAME={}'.format(MODEL_NAME))

    # Clear the previous loaded model
    K.clear_session()

    # Load the class names and info
    splits_dir = paths.get_ts_splits_dir()
    class_names = load_class_names(splits_dir=splits_dir)
    class_info = None
    if 'info.txt' in os.listdir(splits_dir):
        class_info = load_class_info(splits_dir=splits_dir)
        if len(class_info) != len(class_names):
            warnings.warn(
                """The 'classes.txt' file has a different length than the 'info.txt' file.
            If a class has no information whatsoever you should leave that classes row empty or put a '-' symbol.
            The API will run with no info until this is solved.""")
            class_info = None
    if class_info is None:
        class_info = ['' for _ in range(len(class_names))]

    # Load training configuration
    conf_path = os.path.join(paths.get_conf_dir(), 'conf.json')
    with open(conf_path) as f:
        conf = json.load(f)

    # Load the model
    model = load_model(os.path.join(paths.get_checkpoints_dir(), MODEL_NAME),
                       custom_objects=utils.get_custom_objects())
    graph = tf.get_default_graph()

    # Set the model as loaded
    loaded = True