def check_sync_config(rsync_url, timeout=20):
    """Poll the rsync sync server until a dry run succeeds or *timeout* expires.

    Args:
        rsync_url: rsync destination URL to probe.
        timeout: seconds to keep retrying (one attempt per second after a
            failure) before giving up.

    On final failure the combined rsync output is passed to ``handle_error``.
    """
    print('Checking connection to sync server %s for at most %ss... ' % (rsync_url, timeout),
          end='', flush=True)
    # The probe command never changes between attempts, so build it once.
    sync_cmd = [
        get_rsync_path(), '--dry-run', '--timeout=10', '--contimeout=10', '.', rsync_url
    ]
    start = time.time()
    while True:
        proc = subprocess.Popen(sync_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        # communicate() drains the pipe while waiting; wait() with stdout=PIPE
        # can deadlock if rsync writes more than the OS pipe buffer holds.
        output, _ = proc.communicate()
        if proc.returncode == 0:
            print('Success!')
            return
        if time.time() - start < timeout:
            time.sleep(1)
            continue
        handle_error('Could not connect to sync server: %s'
                     % output.decode('utf-8', errors='replace'))
        # Stop retrying even if handle_error returns instead of exiting.
        return
def sync_project(project_root, sync_path):
    """Rsync the project directory to ``sync_path`` on the configured sync server.

    Args:
        project_root: local directory that is synced (used as rsync's cwd).
        sync_path: path appended to the sync server URL.

    Files listed in ``<project_root>/.risemlignore`` are excluded, as are
    ``.git`` and ``riseml*.yml``. A non-zero rsync exit status is reported
    through ``handle_error`` with the same exit code.
    """
    o = urlparse(get_sync_url())
    rsync_url = '%s://%s%s/%s' % (o.scheme, o.netloc, o.path, sync_path)
    exclude_file = os.path.join(project_root, '.risemlignore')
    sync_cmd = [
        get_rsync_path(), '-rlpt', '--exclude=.git', '--exclude=riseml*.yml',
        '--delete-during', '.', rsync_url
    ]
    if os.path.exists(exclude_file):
        sync_cmd.insert(2, '--exclude-from=%s' % exclude_file)

    project_size = get_project_size(sync_cmd, project_root)
    if project_size is not None:
        num_files, size = project_size
        warn_project_size(size, num_files)
        sys.stdout.write('Syncing project (%s, %d files)...' % (get_readable_size(size), num_files))
    else:
        sys.stdout.write('Syncing project...')
    sys.stdout.flush()

    # Context manager guarantees the pipe is closed and the child reaped.
    with subprocess.Popen(sync_cmd, cwd=project_root,
                          stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
        for buf in proc.stdout:
            # The pipe yields bytes; sys.stdout is a text stream, so decode
            # (tolerantly, since filenames may not be valid UTF-8).
            sys.stdout.write(buf.decode('utf-8', errors='replace'))
            sys.stdout.flush()
        res = proc.wait()
    if res != 0:
        handle_error('Push code failed, rsync error', exit_code=res)
def get_project(name):
    """Return the repository whose name equals *name*.

    If no repository matches, ``handle_error`` is called with a
    "project not found" message.
    """
    client = DefaultApi(ApiClient())
    found = next(
        (repo for repo in client.get_repositories() if repo.name == name),
        None,
    )
    if found is not None:
        return found
    handle_error("project not found: %s" % name)
def run_init(args):
    """Create ``args.config_file`` from the project init template.

    Reports an error via ``handle_error`` if ``create_config`` returns a
    falsy value (the message says the file already exists).
    """
    config_file = args.config_file
    created = create_config(config_file, project_init_template, args.project_name)
    if created:
        print('%s successfully created' % config_file)
        return
    handle_error('%s already exists' % config_file)
def _generate_project_name(): cwd = os.getcwd() project_name = os.path.basename(cwd) if not project_name: project_name = input('Please type project name: ') if not project_name: handle_error('Invalid project name') return project_name
def main():
    """Build the CLI argument parser, parse ``sys.argv`` and run the chosen subcommand."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', help="show endpoints", action='store_const', const=True)
    parser.add_argument('--version', '-V', help="show version", action='version',
                        version='RiseML CLI {}'.format(VERSION))
    subparsers = parser.add_subparsers()

    # Subcommand registrars, kept in the original registration order.
    registrars = (
        # user ops
        add_whoami_parser,
        add_user_parser,
        # system ops
        add_system_parser,
        add_account_parser,
        # data ops
        add_ls_parser,
        add_cp_parser,
        add_rm_parser,
        # workflow ops (exec and deploy parsers are currently disabled)
        add_init_parser,
        add_train_parser,
        add_monitor_parser,
        add_logs_parser,
        add_kill_parser,
        add_status_parser,
    )
    for register in registrars:
        register(subparsers)

    args = parser.parse_args(sys.argv[1:])

    if args.v:
        endpoints = (
            ('api_url: %s', get_api_url),
            ('sync_url: %s', get_sync_url),
            ('stream_url: %s', get_stream_url),
            ('git_url: %s', get_git_url),
        )
        for template, getter in endpoints:
            print(template % getter())

    if not hasattr(args, 'run'):
        parser.print_usage()
        return

    # Every command except login needs an existing client config file.
    if not (config_file_exists() or args.run.__name__ == 'run_login'):
        handle_error('Client configuration file %s not found' % get_config_file())
    try:
        args.run(args)
    except HTTPError as e:
        # all uncaught http errors goes here
        handle_error(str(e))
    except KeyboardInterrupt:
        print('\nAborting...')
def stream_stats(url, job_id_stats, stream_meta=None):
    """Open a websocket to *url* and feed utilization/state updates into job_id_stats.

    Args:
        url: websocket URL of the monitor stream.
        job_id_stats: mapping job_id -> stats object exposing ``update`` and
            ``update_job_state``; only job ids already present are updated.
        stream_meta: metadata forwarded to ``print_user_exit`` when the user
            interrupts the stream (defaults to an empty dict).

    The websocket runs in a daemon thread stored in the module-level
    ``monitor_stream``. If no connection is established within ~5 seconds
    (10 polls of 0.5s), ``handle_error`` is called.
    """
    global monitor_stream
    if stream_meta is None:
        # Avoid a shared mutable default argument.
        stream_meta = {}
    stream_connected = False

    def on_message(ws, message):
        try:
            msg = json.loads(message)
            if msg['type'] == 'utilization':
                stats = msg['data']
                job_id = stats['job_id']
                if job_id in job_id_stats:
                    with stats_lock:
                        job_stats = job_id_stats[job_id]
                        job_stats.update(stats)
            elif msg['type'] == 'state':
                job_id = msg['job_id']
                if job_id in job_id_stats:
                    with stats_lock:
                        job_stats = job_id_stats[job_id]
                        job_stats.update_job_state(msg['state'])
        except Exception:
            handle_error(traceback.format_exc())

    def on_error(ws, e):
        if isinstance(e, (KeyboardInterrupt, SystemExit)):
            print_user_exit(stream_meta)
        else:
            # all other Exception based stuff goes to `handle_error`
            handle_error(e)

    def on_close(ws):
        time.sleep(2)
        # Hard-exit the whole process once the stream closes.
        os._exit(0)

    def on_open(ws):
        nonlocal stream_connected
        stream_connected = True

    ws = websocket.WebSocketApp(url,
                                on_message=on_message,
                                on_error=on_error,
                                on_close=on_close,
                                on_open=on_open)
    # FIXME: {'Authorization': os.environ.get('RISEML_APIKEY')}
    monitor_stream = threading.Thread(target=ws.run_forever)
    monitor_stream.daemon = True
    monitor_stream.start()

    conn_timeout = 10
    time.sleep(0.1)
    # Bound the wait: previously conn_timeout was decremented but never
    # checked, so an unreachable server made this loop spin forever.
    while not stream_connected and conn_timeout > 0:
        time.sleep(0.5)
        conn_timeout -= 1
    if not stream_connected:
        handle_error('Unable to connect to monitor stream')
def _on_error(self, _, error): if isinstance(error, (KeyboardInterrupt, SystemExit)): any_id = self.stream_meta.get('experiment_id') or self.stream_meta.get('job_id') print() # newline after ^C if self.stream_meta.get('experiment_id'): print('Experiment will continue in background') else: print('Job will continue in background') if any_id: print('Type `riseml logs %s` to connect to log stream again' % any_id) else: # all other Exception based stuff goes to `handle_error` handle_error(error)
def run_section(args):
    """Push the project and start a job using the config's ``deploy`` section.

    Calls ``handle_error`` if the loaded config has no ``deploy`` attribute.
    """
    config_file = args.config_file
    config = load_config(config_file)
    project_name = config.project
    try:
        deploy_section = config.deploy
    except AttributeError:
        handle_error('no `deploy` section in {}'.format(config_file))
    revision = push_project(get_user(), project_name, config_file)
    run_job(project_name, revision, args.config_section,
            json.dumps(dict(deploy_section)))
def kill_experiment(client, experiment_id, force):
    """Kill *experiment_id* through the API and print what was killed.

    An experiment with children is reported as a "series".
    """
    killed = call_api(
        lambda: client.kill_experiment(experiment_id, force=force),
        not_found=lambda: handle_error("Could not find experiment!"))
    template = "killed series {}" if killed.children else "killed experiment {}"
    print(template.format(killed.short_id))
def entrypoint():
    """Console entry point.

    Installs an encoding-safe ``print``, then runs ``main()``. Outside the
    development/test environments, unexpected exceptions are reported to
    rollbar before ``handle_error`` is called.
    """
    builtins.print = safely_encoded_print(print)
    if get_environment() in ('development', 'test'):
        main()
        return
    cluster_id = get_cluster_id()
    rollbar.init(cluster_id or '00000000-0000-0000-0000-000000000000',
                 get_environment(),
                 endpoint=get_rollbar_endpoint(),
                 root=os.path.dirname(os.path.realpath(__file__)))
    try:
        main()
    except Exception:
        rollbar.report_exc_info()
        handle_error("An unexpected error occured.")
def on_message(ws, message):
    """Apply one websocket message to the tracked job stats.

    'utilization' messages update stats and 'state' messages update job state,
    but only for job ids already present in ``job_id_stats``. Any failure is
    reported through ``handle_error`` with the full traceback.
    """
    try:
        payload = json.loads(message)
        kind = payload['type']
        if kind == 'utilization':
            data = payload['data']
            jid = data['job_id']
            if jid in job_id_stats:
                with stats_lock:
                    job_id_stats[jid].update(data)
        elif kind == 'state':
            jid = payload['job_id']
            if jid in job_id_stats:
                with stats_lock:
                    job_id_stats[jid].update_job_state(payload['state'])
    except Exception:
        handle_error(traceback.format_exc())
def run(args):
    """Monitor an experiment or job.

    With ``args.id``, monitor that experiment or job. Without it, monitor the
    first experiment returned by ``get_experiments``. ``args.gpu`` selects
    detailed output.
    """
    api_client = ApiClient()
    client = DefaultApi(api_client)
    if args.id:
        if is_experiment_id(args.id):
            experiment = call_api(
                lambda: client.get_experiment(args.id),
                not_found=lambda: handle_error("Could not find experiment %s" % args.id))
            monitor_experiment(
                experiment, detailed=args.gpu,
                stream_meta={"experiment_id": experiment.short_id})
        elif is_job_id(args.id):
            job = call_api(
                lambda: client.get_job(args.id),
                not_found=lambda: handle_error("Could not find job!"))
            monitor_job(job, detailed=args.gpu)
        else:
            handle_error("Id is neither an experiment id nor a job id!")
    else:
        experiments = call_api(lambda: client.get_experiments())
        if not experiments:
            handle_error('No experiments to monitor!')
        experiment = call_api(
            lambda: client.get_experiment(experiments[0].short_id))
        # Pass stream_meta here as well, so an interrupted monitor behaves the
        # same as when the experiment id was given explicitly.
        monitor_experiment(
            experiment, detailed=args.gpu,
            stream_meta={"experiment_id": experiment.short_id})
def run(args):
    """Show status for an experiment, job, user or the experiment list.

    Without an id, list experiments (all users if ``args.all_users``). An id
    accepted by ``util.is_user_id`` lists that user's experiments, with the
    first character of the id stripped.
    """
    client = DefaultApi(ApiClient())
    active_states = 'CREATED|PENDING|STARTING|BUILDING|RUNNING'

    def _list_query(query):
        # --all caps the listing by count; otherwise only active states show.
        if args.all:
            query['count'] = args.num_last
        else:
            query['states'] = active_states
        return query

    if not args.id:
        query_args = _list_query({'all_users': args.all_users})
        experiments = util.call_api(
            lambda: client.get_experiments(**query_args))
        show_experiments(experiments, all=args.all, collapsed=not args.long,
                         users=args.all_users)
    elif util.is_experiment_id(args.id):
        experiment = util.call_api(
            lambda: client.get_experiment(args.id),
            not_found=lambda: handle_error("Could not find experiment!"))
        if experiment.children:
            show_experiment_group(experiment)
        else:
            show_experiment(experiment)
    elif util.is_job_id(args.id):
        job = util.call_api(
            lambda: client.get_job(args.id),
            not_found=lambda: handle_error("Could not find job!"))
        show_job(job)
    elif util.is_user_id(args.id):
        query_args = _list_query({'user': args.id[1:]})
        experiments = util.call_api(
            lambda: client.get_experiments(**query_args))
        show_experiments(experiments, all=args.all, collapsed=not args.long)
    else:
        handle_error("Id does not identify any RiseML entity!")
def run(args):
    """Stream logs for an experiment or job.

    Without an id, stream the logs of the first experiment returned by
    ``get_experiments``.
    """
    client = DefaultApi(ApiClient())
    target = args.id

    if not target:
        experiments = call_api(lambda: client.get_experiments())
        if not experiments:
            handle_error('No experiment logs to show!')
        latest = call_api(
            lambda: client.get_experiment(experiments[0].short_id))
        stream_experiment_log(latest)
    elif is_experiment_id(target):
        experiment = call_api(
            lambda: client.get_experiment(target),
            not_found=lambda: handle_error("Could not find experiment!"))
        stream_experiment_log(experiment)
    elif is_job_id(target):
        job = call_api(
            lambda: client.get_job(target),
            not_found=lambda: handle_error("Could not find job!"))
        stream_job_log(job)
    else:
        handle_error("Can only show logs for jobs or experiments!")
def run(args):
    """Stream logs for an experiment or job.

    Jobs whose role is 'tf-hrvd-master' or 'tf-hrvd-worker' are streamed
    through their parent experiment, filtered to that job. Without an id, the
    first experiment returned by ``get_experiments`` is streamed.
    """
    client = DefaultApi(ApiClient())
    target = args.id

    if not target:
        experiments = call_api(lambda: client.get_experiments())
        if not experiments:
            handle_error('No experiment logs to show!')
        latest = call_api(lambda: client.get_experiment(experiments[0].short_id))
        stream_experiment_log(latest)
        return

    if is_experiment_id(target):
        stream_experiment(client, target)
    elif is_job_id(target):
        job = call_api(lambda: client.get_job(target),
                       not_found=lambda: handle_error("Could not find job!"))
        # tf-hrvd-* roles (presumably Horovod jobs) are shown via the experiment.
        if job.role in ['tf-hrvd-master', 'tf-hrvd-worker']:
            stream_experiment(client, get_experiment_id(target), filter_job=job)
        else:
            stream_job_log(job)
    else:
        handle_error("Can only show logs for jobs or experiments!")
def load_config(config_file, config_section=None):
    """Load a RepositoryConfig from *config_file*.

    Returns the whole config, or only the attribute named *config_section*
    when one is given. A missing file, an invalid config or a missing section
    is reported through ``handle_error``.
    """
    if not config_exists(config_file):
        handle_error("%s does not exist" % config_file)
    try:
        config = RepositoryConfig.from_yml_file(config_file)
    except ConfigError as err:
        handle_error("invalid config {}\n{}".format(config_file, str(err)))
        return None
    if config_section is None:
        return config
    try:
        return getattr(config, config_section)
    except AttributeError:
        handle_error("config doesn't contain section for %s" % config_section)
def run(args):
    """Kill the experiments in ``args.ids``, or the most recent one if none are given.

    Ids that are not experiment ids are rejected. Without ids, the first
    listed experiment is killed unless it is already FINISHED, FAILED or KILLED.
    """
    client = DefaultApi(ApiClient(host=get_api_url()))
    ids = args.ids

    if not ids:
        experiments = call_api(lambda: client.get_experiments())
        if not experiments:
            handle_error('No experiments to kill!')
        latest = experiments[0]
        if latest.state in ('FINISHED', 'FAILED', 'KILLED'):
            handle_error('No experiments to kill!')
        kill_experiment(client, latest.id, args.force)
        return

    if not all(is_experiment_id(candidate) for candidate in ids):
        handle_error("Can only kill experiments!")
    for experiment_id in ids:
        kill_experiment(client, experiment_id, args.force)
def call_api(api_fn, not_found=None):
    """Call *api_fn* and turn API/transport failures into user-facing errors.

    Args:
        api_fn: zero-argument callable that performs the API request.
        not_found: optional zero-argument callback invoked for a 404 response
            instead of the generic HTTP error handler.

    Returns:
        The value returned by ``api_fn``; ``None`` if an error path is taken
        and its handler returns instead of exiting.

    Raises:
        ApiException: re-raised unchanged when its ``status`` is 0.
    """
    try:
        return api_fn()
    except ApiException as e:
        if e.status == 0:
            # Bare raise keeps the original traceback intact.
            raise
        if e.status == 401:
            handle_error("You are not authorized!")
        elif e.status == 404 and not_found:
            not_found()
        else:
            # 403 and every other status share the generic handler.
            handle_http_error(e.body, e.status)
    except LocationValueError:
        handle_error(
            "RiseML is not configured! Please run 'riseml user login' first!")
    except HTTPError as e:
        # Only some urllib3 HTTPError subclasses carry `pool`/`url`; reading
        # them unconditionally raised AttributeError and hid the real error.
        pool = getattr(e, 'pool', None)
        if pool is not None:
            location = ' ({host}:{port}{url})'.format(
                host=pool.host, port=pool.port, url=getattr(e, 'url', '') or '')
        else:
            location = ''
        handle_error('Could not connect to API{location} — {exc_type}'.format(
            location=location, exc_type=e.__class__.__name__))
def on_error(ws, e):
    """Websocket error callback.

    A user interrupt (KeyboardInterrupt/SystemExit) is passed to
    ``print_user_exit`` with ``stream_meta``. Any other error goes to
    ``handle_error``.
    """
    user_exit = isinstance(e, (KeyboardInterrupt, SystemExit))
    if user_exit:
        print_user_exit(stream_meta)
        return
    # all other Exception based stuff goes to `handle_error`
    handle_error(e)
def validate_email(email):
    """Minimal email check: call ``handle_error`` unless *email* contains '@'."""
    if '@' in email:
        return
    handle_error('Invalid email')
def validate_username(username): if not re.match(r'^[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9]$', username): handle_error( 'Username must start and end with an alphanumeric character and may additionally consist out of hyphens inbetween.' )
def stream_experiment(client, experiment_id, filter_job=None):
    """Fetch experiment *experiment_id* and stream its logs.

    *filter_job* is forwarded to ``stream_experiment_log``. A 404 reports
    "Could not find experiment!" through ``handle_error``.
    """
    def _missing():
        return handle_error("Could not find experiment!")

    experiment = call_api(lambda: client.get_experiment(experiment_id),
                          not_found=_missing)
    stream_experiment_log(experiment, filter_job=filter_job)