def main(self, args):
    """Upload a local weights file to a job.

    Command line: ``<prog> upload-weights <id> <weights> [--api-key KEY]
    [--kpi KPI] [--latest]``.

    ``id`` is either a job reference or a model name. An id that contains a
    ``/`` but no ``@`` is treated as a model name: ``JobBackend.create`` is
    called for it before the job is loaded. The file is then handed to
    ``JobBackend.upload_weights`` with ``'weights.hdf5'`` as the upload name.

    :param args: Argument list to parse (everything after the sub-command).
    :raises ValueError: If ``--kpi`` is given but is not a valid float.
    :raises Exception: If the weights path is not an existing file, or the
        job could not be loaded.
    """
    import aetros.const
    from aetros.backend import JobBackend

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        prog=aetros.const.__prog__ + ' upload-weights')
    parser.add_argument('id', help='model name or job id')
    parser.add_argument('weights', help="Weights path")
    parser.add_argument(
        '--api-key',
        help="Secure key. Alternatively use API_KEY environment variable.")
    parser.add_argument(
        '--kpi', help="You can overwrite or set the KPI for this job")
    # NOTE(review): --latest is parsed but never read below; it is kept so
    # existing command lines that pass it continue to work.
    parser.add_argument(
        '--latest', action="store_true",
        help="Instead of best epoch we upload latest weights.")

    parsed_args = parser.parse_args(args)

    # argparse requires both positionals, but they may still be empty strings.
    if not parsed_args.id or not parsed_args.weights:
        parser.print_help()
        return

    # Validate all local input before any backend call, so bad input is
    # rejected before job_backend.create() can run.
    kpi = float(parsed_args.kpi) if parsed_args.kpi else None

    weights_path = parsed_args.weights
    # isfile rather than exists: a directory passes exists() but is not a
    # weights file.
    if not os.path.isfile(weights_path):
        raise Exception('Weights file does not exist in ' + weights_path)

    job_backend = JobBackend(api_key=parsed_args.api_key)

    # An id with a '/' but without '@' is treated as a model name, so a job
    # is created for it first.
    if '/' in parsed_args.id and '@' not in parsed_args.id:
        job_backend.create(parsed_args.id)

    job_backend.load(parsed_args.id)

    if job_backend.job is None:
        raise Exception("Job not found")

    print("Uploading weights to %s of %s ..." % (job_backend.job_id, job_backend.model_id))

    job_backend.upload_weights('weights.hdf5', weights_path, kpi)
def main(self, args):
    """Validate a weights file against the job's model, then upload it.

    Command line: ``<prog> upload-weights [id] [--secure-key KEY]
    [--weights PATH] [--accuracy ACC] [--latest]``.

    The job is loaded through ``JobBackend``. An id that contains a ``/`` but
    no ``@`` is treated as a model name, and ``JobBackend.create`` is called
    for it first.

    Unless ``--weights`` is given, the path comes from
    ``job_model.get_weights_filepath_best()``. The model is then built and
    compiled through the job's model provider. The weights are loaded into
    it as a sanity check before being passed to ``upload_weights`` with
    ``'best.hdf5'`` as the upload name.

    :param args: Argument list to parse (everything after the sub-command).
    :raises ValueError: If ``--accuracy`` is given but is not a valid float.
    :raises Exception: If the job could not be loaded or the weights path is
        not an existing file.
    """
    import os

    from aetros import keras_model_utils
    import aetros.const
    from aetros.backend import JobBackend
    from aetros.logger import GeneralLogger
    from aetros.Trainer import Trainer

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        prog=aetros.const.__prog__ + ' upload-weights')
    parser.add_argument('id', nargs='?', help='model name or job id')
    parser.add_argument(
        '--secure-key',
        help="Secure key. Alternatively use API_KEY environment variable.")
    parser.add_argument(
        '--weights',
        help="Weights path. Per default we try to find it in the ./weights/ folder.")
    parser.add_argument(
        '--accuracy',
        help="If you specified model name, you should also specify the accuracy this weights got.")
    # NOTE(review): --latest is parsed but not honoured; the default path is
    # always get_weights_filepath_best(). Kept for CLI compatibility.
    parser.add_argument(
        '--latest', action="store_true",
        help="Instead of best epoch we upload latest weights.")

    parsed_args = parser.parse_args(args)

    # 'id' is declared with nargs='?', so argparse accepts a missing id. The
    # substring tests below would then raise TypeError on None.
    if not parsed_args.id:
        parser.print_help()
        return

    # Parse before any backend call so a malformed value fails fast.
    accuracy = float(parsed_args.accuracy) if parsed_args.accuracy else None

    # NOTE(review): elsewhere JobBackend is constructed with api_key=; confirm
    # that api_token= is still the accepted keyword.
    job_backend = JobBackend(api_token=parsed_args.secure_key)

    # An id with a '/' but without '@' is treated as a model name, so a job
    # is created for it first.
    if '/' in parsed_args.id and '@' not in parsed_args.id:
        job_backend.create(parsed_args.id)

    job_backend.load(parsed_args.id)

    if job_backend.job is None:
        raise Exception("Job not found")

    job_model = job_backend.get_job_model()

    weights_path = parsed_args.weights or job_model.get_weights_filepath_best()

    # Fail before the (potentially slow) model build/compile below.
    if not weights_path or not os.path.isfile(weights_path):
        raise Exception('Weights file does not exist in %s' % (weights_path,))

    print("Validate weights in %s ..." % (weights_path,))
    keras_model_utils.job_prepare(job_model)

    general_logger = GeneralLogger()
    trainer = Trainer(job_backend, general_logger)
    job_model.set_input_shape(trainer)

    print("Loading model ...")
    model_provider = job_model.get_model_provider()
    model = model_provider.get_model(trainer)
    loss = model_provider.get_loss(trainer)
    optimizer = model_provider.get_optimizer(trainer)

    print("Compiling ...")
    model_provider.compile(trainer, model, loss, optimizer)

    print("Validate weights %s ..." % (weights_path,))
    # Loading into the compiled model is the validation step; presumably
    # load_weights raises if the file does not fit the model -- confirm.
    job_model.load_weights(model, weights_path)
    print("Validated.")

    print("Uploading weights to %s of %s ..." % (job_backend.job_id, job_backend.model_id))
    job_backend.upload_weights('best.hdf5', weights_path, accuracy)
    print("Done")