def get_user():
    """Fetch the user record exposed by the default API.

    Returns:
        The first element of whatever ``call_api`` returns for
        ``client.get_user()``.
    """
    client = DefaultApi(ApiClient())
    # The wrapped call yields an indexable result; only entry 0 is used.
    return call_api(lambda: client.get_user())[0]
def run(args):
    """Monitor an experiment or a job selected by ``args.id``.

    Without an id, the first experiment returned by ``get_experiments`` is
    monitored. ``args.gpu`` is forwarded as the ``detailed`` flag.

    Args:
        args: Parsed command-line namespace providing ``id`` and ``gpu``.
    """
    client = DefaultApi(ApiClient())

    if not args.id:
        # No id given: fall back to the first listed experiment.
        experiments = call_api(lambda: client.get_experiments())
        if not experiments:
            handle_error('No experiments to monitor!')
        experiment = call_api(
            lambda: client.get_experiment(experiments[0].short_id))
        monitor_experiment(experiment, detailed=args.gpu)
        return

    if is_experiment_id(args.id):
        experiment = call_api(
            lambda: client.get_experiment(args.id),
            not_found=lambda: handle_error(
                "Could not find experiment %s" % args.id))
        monitor_experiment(
            experiment,
            detailed=args.gpu,
            stream_meta={"experiment_id": experiment.short_id})
        return

    if is_job_id(args.id):
        job = call_api(
            lambda: client.get_job(args.id),
            not_found=lambda: handle_error("Could not find job!"))
        monitor_job(job, detailed=args.gpu)
        return

    handle_error("Id is neither an experiment id nor a job id!")
def run(args):
    """Create or update a user through the admin API and print the result.

    Args:
        args: Parsed command-line namespace providing ``username`` and
            ``email``.
    """
    client = AdminApi(ApiClient(host=API_URL))
    result = call_api(
        lambda: client.update_or_create_user(
            username=args.username, email=args.email))
    # Only the first entry of the result is shown.
    print(result[0])
def get_project(name):
    """Return the first repository whose ``name`` attribute equals ``name``.

    Args:
        name: Project name to look for.

    Returns:
        The matching repository object. When nothing matches,
        ``handle_error`` is called and, if it returns, the result is ``None``.
    """
    client = DefaultApi(ApiClient())
    # The generator stops at the first match, like an early-returning loop.
    match = next(
        (repo for repo in client.get_repositories() if repo.name == name),
        None)
    if match is None:
        handle_error("project not found: %s" % name)
    return match
def run_train(args):
    """Push the project and start a ``train`` experiment from its config.

    With ``args.logs`` set, the experiment log is streamed; otherwise a
    short summary (including a TensorBoard URL when available) is printed.

    Args:
        args: Parsed command-line namespace providing ``config_file`` and
            ``logs``.
    """
    config = load_config(args.config_file)
    project_name = config.project
    revision = push_project(get_user(), project_name, args.config_file)

    client = DefaultApi(ApiClient())
    experiment = call_api(
        lambda: client.create_experiment(
            project_name,
            revision,
            kind='train',
            config=json.dumps(config.train.as_dict())))

    if args.logs:
        stream_experiment_log(experiment)
        return

    # Background mode: tell the user how to follow the run.
    print('Started experiment %s in background...' % (experiment.short_id))
    if util.has_tensorboard(experiment):
        board_job = util.tensorboard_job(experiment)
        if board_job:
            board_url = util.tensorboard_job_url(board_job)
            print('TensorBoard: {}'.format(board_url))
    print('Type `riseml logs %s` to connect to log stream.'
          % (experiment.short_id))
def run(args):
    """Stream logs for an experiment or a job selected by ``args.id``.

    Without an id, the log of the first experiment returned by
    ``get_experiments`` is streamed.

    Args:
        args: Parsed command-line namespace providing ``id``.
    """
    client = DefaultApi(ApiClient())

    if not args.id:
        # No id given: fall back to the first listed experiment.
        experiments = call_api(lambda: client.get_experiments())
        if not experiments:
            handle_error('No experiment logs to show!')
        experiment = call_api(
            lambda: client.get_experiment(experiments[0].short_id))
        stream_experiment_log(experiment)
        return

    if is_experiment_id(args.id):
        experiment = call_api(
            lambda: client.get_experiment(args.id),
            not_found=lambda: handle_error("Could not find experiment!"))
        stream_experiment_log(experiment)
        return

    if is_job_id(args.id):
        job = call_api(
            lambda: client.get_job(args.id),
            not_found=lambda: handle_error("Could not find job!"))
        stream_job_log(job)
        return

    handle_error("Can only show logs for jobs or experiments!")
def create_project(config_file):
    """Write the project config and register the project with the API.

    Args:
        config_file: Path handed to ``create_config`` and
            ``get_project_name``.
    """
    # Original note: "create, if not exists" -- presumably create_config
    # leaves an existing file alone; confirm against its implementation.
    create_config(config_file, project_template)
    project_name = get_project_name(config_file)

    client = DefaultApi(ApiClient())
    created = client.create_project(project_name)
    project = created[0]
    print("project created: %s (%s)" % (project.name, project.id))
def run_update(args):
    """Validate the given username/email and update that user.

    Args:
        args: Parsed command-line namespace providing ``username`` and
            ``email``.
    """
    client = AdminApi(ApiClient())
    validate_username(args.username)
    validate_email(args.email)

    updated = call_api(
        lambda: client.update_user(username=args.username, email=args.email))

    print('Updated user {}'.format(updated.username))
    print(' email: {}'.format(updated.email))
def run_create(args):
    """Validate the given username/email, create the user and print it.

    The printed summary includes the user's plaintext API key.

    Args:
        args: Parsed command-line namespace providing ``username`` and
            ``email``.
    """
    client = AdminApi(ApiClient())
    validate_username(args.username)
    validate_email(args.email)

    result = call_api(
        lambda: client.create_user(username=args.username, email=args.email))
    created = result[0]

    print('Created user %s' % created.username)
    print(' email: %s' % created.email)
    print(' api_key: %s' % created.api_key_plaintext)
def run_sync(args):
    """Sync account info via the admin API and report the outcome.

    Args:
        args: Parsed command-line namespace (not read by this command).
    """
    client = AdminApi(ApiClient())
    account = call_api(lambda: client.sync_account_info())

    if account.name is not None:
        print('Successfully synced account info.'
              ' Account name: %s' % account.name)
        return

    # A missing name is reported as "not registered".
    print('You have not registered with an account. '
          'Please run ' + bold('riseml account register'))
def read_and_register_account_key():
    """Prompt for an account key and register it through the admin API.

    The key is read from standard input, stripped of surrounding
    whitespace and sent via ``update_account``. A response whose ``name``
    is ``None`` is reported as an invalid key; otherwise the account name
    is printed.
    """
    account_key = input('Please enter your account key: ').strip()
    client = AdminApi(ApiClient())
    res = call_api(lambda: client.update_account(account_key=account_key))
    if res.name is None:
        # NOTE(review): "[email protected]" looks like a redaction
        # placeholder rather than a real address -- confirm the intended
        # support contact before release.
        print('Invalid account key. Please verify that your key is correct '
              'or ask for support via [email protected]. '
              'Your cluster is not registered with an account.')
    else:
        # Fixed the spelling of "successfully" in the user-facing message.
        print('Registered successfully! Account name: %s' % res.name)
def run_job(project_name, revision, kind, config):
    """Create a job for a project revision and stream the first job's log.

    Args:
        project_name: Name of the project the job belongs to.
        revision: Revision identifier passed to ``create_job``.
        kind: Job kind, forwarded as the ``kind`` keyword.
        config: Job configuration, forwarded as the ``config`` keyword.
    """
    client = DefaultApi(ApiClient())
    try:
        jobs = client.create_job(project_name, revision, kind=kind,
                                 config=config)
    except ApiException as e:
        handle_http_error(e.body, e.status)
    else:
        # Stream only when creation succeeded: if handle_http_error returns
        # instead of aborting, ``jobs`` would otherwise be unbound here.
        stream_job_log(jobs[0])
def run_display(args):
    """Look up a user by username and print its details.

    The output includes the user's plaintext API key.

    Args:
        args: Parsed command-line namespace providing ``username``.
    """
    client = AdminApi(ApiClient())
    matches = call_api(lambda: client.get_users(username=args.username))

    if matches:
        found = matches[0]
        print('username: %s' % found.username)
        print('email: %s' % found.email)
        print('api_key: %s' % found.api_key_plaintext)
    else:
        print('User %s not found.' % args.username)
def run_list(args):
    """Print a table of all users with username, email and enabled flag.

    Args:
        args: Parsed command-line namespace (not read by this command).
    """
    client = AdminApi(ApiClient())
    users = call_api(lambda: client.get_users())

    # One table row per user; the enabled flag is rendered as text.
    table_rows = [
        [user.username, user.email, str(user.is_enabled)] for user in users
    ]
    print_table(
        header=['Username', 'Email', 'Enabled'],
        min_widths=[12, 6, 9],
        column_spaces=2,
        rows=table_rows)
def run(args):
    """Display cluster nodes in long, GPU or short form.

    ``args.long`` takes precedence over ``args.gpus``. With neither set,
    cluster infos are fetched and shown above the short node listing.

    Args:
        args: Parsed command-line namespace providing ``long`` and ``gpus``.
    """
    client = AdminApi(ApiClient())
    nodes = call_api(lambda: client.get_nodes())

    if args.long:
        display_long(nodes)
        return
    if args.gpus:
        display_gpus(nodes)
        return

    # Default view: cluster summary, a blank line, then the short node list.
    cluster_infos = call_api(lambda: client.get_cluster_infos())
    display_clusterinfos(cluster_infos)
    print('')
    display_short(nodes)
def run_upgrade(args):
    """Send the user to the account upgrade page.

    When the account has no key, a registration hint is printed. Otherwise
    the upgrade URL is opened in a browser tab if ``browser_available()``
    is true, or printed if not.

    Args:
        args: Parsed command-line namespace (not read by this command).
    """
    client = AdminApi(ApiClient())
    account = call_api(lambda: client.get_account_info())

    if account.key is None:
        print('You have not registered with an account. '
              'Please run ' + bold('riseml account register'))
        return

    base_url = get_riseml_url()
    register_url = base_url + 'upgrade/%s' % account.key
    if not browser_available():
        print('Please visit this URL and follow instructions'
              ' to upgrade your account: %s' % register_url)
        return
    webbrowser.open_new_tab(register_url)
def run(args):
    """Start ``args.num_jobs`` stress jobs for the current user.

    The job configuration is printed once, before the first job starts,
    and a progress line is printed for every job.

    Args:
        args: Parsed command-line namespace providing ``num_jobs``,
            ``request_cpus``, ``request_mem``, ``force_build_steps``,
            ``num_cpus``, ``mem`` and ``nodename``.
    """
    # The previously constructed AdminApi client was never used here and
    # has been dropped; get_user() and start_job() build their own clients.
    user = get_user()
    for i in range(args.num_jobs):
        job_config = get_job_config(args.request_cpus,
                                    args.request_mem,
                                    args.force_build_steps)
        stress_script = get_script(args.num_cpus, args.mem)
        if i == 0:
            print('Job configuration:\n\n%s' % job_config)
        print('Starting job %s of %s to stress %s CPUs and %s MB of memory.'
              % (i + 1, args.num_jobs, args.num_cpus, args.mem))
        start_job(user, args.nodename, job_config, stress_script)
def start_job(user, nodename, job_config, stress_script):
    """Prepare a project directory, push it and start a ``train`` experiment.

    The directory is removed again when the function leaves, including
    when loading the config, pushing the project or the API call fails.

    Args:
        user: User object handed to ``push_project``.
        nodename: Optional node name; when truthy, the experiment gets a
            ``kubernetes.io/hostname`` node selector for it.
        job_config: Job configuration passed to ``prepare_project_dir``.
        stress_script: Stress script passed to ``prepare_project_dir``.
    """
    config_path = prepare_project_dir(job_config, stress_script)
    try:
        config = load_config(config_path)
        revision = push_project(user, PROJECT_NAME, config_path)
        client = DefaultApi(ApiClient())
        node_selector = ''
        if nodename:
            node_selector = 'kubernetes.io/hostname: %s' % nodename
        # The created experiment object is not needed afterwards.
        call_api(lambda: client.create_experiment(
            PROJECT_NAME,
            revision,
            kind='train',
            config=json.dumps(config.train.as_dict()),
            node_selectors=node_selector))
    finally:
        # Clean up even on failure so that no project dir is left behind.
        remove_project_dir(config_path)
def run_disable(args):
    """Ask for confirmation, then disable a user through the admin API.

    Any answer other than ``y`` (after stripping whitespace), an interrupt
    (Ctrl-C) or end-of-input cancels the operation and exits with status 0.

    Args:
        args: Parsed command-line namespace providing ``username``.
    """
    sys.stdout.write("Are you sure you want to disable user %s? [y/n]: "
                     % args.username)

    def user_exit():
        print("Apparently not...")
        # sys.exit instead of the site-provided exit(), which is meant for
        # interactive sessions and is not guaranteed to exist.
        sys.exit(0)

    try:
        choice = input()
    except (KeyboardInterrupt, EOFError):
        # Closed stdin is treated like an interrupt: cancel, don't crash.
        user_exit()
    if choice.strip() != 'y':
        user_exit()

    client = AdminApi(ApiClient())
    call_api(lambda: client.delete_user(username=args.username))
    print('User %s disabled.' % args.username)
def run(args):
    """Kill the experiments named in ``args.ids``, or the first listed one.

    Without ids, the first experiment returned by ``get_experiments`` is
    killed unless its state is already FINISHED, FAILED or KILLED.

    Args:
        args: Parsed command-line namespace providing ``ids`` and ``force``.
    """
    client = DefaultApi(ApiClient(host=get_api_url()))

    if not args.ids:
        experiments = call_api(lambda: client.get_experiments())
        if not experiments:
            handle_error('No experiments to kill!')
        first = experiments[0]
        if first.state in ('FINISHED', 'FAILED', 'KILLED'):
            handle_error('No experiments to kill!')
        kill_experiment(client, first.id, args.force)
        return

    # Every given id must be an experiment id before anything is killed.
    if not all(is_experiment_id(given_id) for given_id in args.ids):
        handle_error("Can only kill experiments!")
    for experiment_id in args.ids:
        kill_experiment(client, experiment_id, args.force)
def run_register(args):
    """Walk the user through registering the cluster with an account.

    If the cluster already has an account key, or the user says they have
    one, the key prompt is shown directly. Otherwise the registration URL
    is opened in a browser tab or printed, and then the key prompt is shown.

    Args:
        args: Parsed command-line namespace (not read by this command).
    """
    client = AdminApi(ApiClient())
    account = call_api(lambda: client.get_account_info())

    if account.key is not None:
        print('Note: this cluster is already registered with an account. '
              'You can continue and register with another account.')
        read_and_register_account_key()
        return

    if read_yes_no('Do you already have an account key'):
        read_and_register_account_key()
        return

    # No key yet: direct the user to the registration page first.
    base_url = get_riseml_url()
    register_url = base_url + 'register/basic/%s' % account.cluster_id
    if browser_available():
        webbrowser.open_new_tab(register_url)
    else:
        print('Please visit this URL and follow instructions'
              ' to register an account: %s' % register_url)
    read_and_register_account_key()
def run(args):
    """Show an experiment, a job, or a list of experiments.

    ``args.id`` selects what is shown: an experiment id (group view when
    the experiment has children), a job id, or a user id (experiments of
    that user). Without an id, experiments are listed, optionally for all
    users. Any other id is reported as an error.

    Args:
        args: Parsed command-line namespace providing ``id``, ``all``,
            ``num_last``, ``long`` and ``all_users``.
    """
    client = DefaultApi(ApiClient())

    def fetch_experiments(query_args):
        # With --all the result is capped by count; otherwise only the
        # listed non-terminal states are requested.
        if args.all:
            query_args['count'] = args.num_last
        else:
            query_args['states'] = 'CREATED|PENDING|STARTING|BUILDING|RUNNING'
        return util.call_api(lambda: client.get_experiments(**query_args))

    if not args.id:
        experiments = fetch_experiments({'all_users': args.all_users})
        show_experiments(experiments, all=args.all,
                         collapsed=not args.long, users=args.all_users)
    elif util.is_experiment_id(args.id):
        experiment = util.call_api(
            lambda: client.get_experiment(args.id),
            not_found=lambda: handle_error("Could not find experiment!"))
        if experiment.children:
            show_experiment_group(experiment)
        else:
            show_experiment(experiment)
    elif util.is_job_id(args.id):
        job = util.call_api(
            lambda: client.get_job(args.id),
            not_found=lambda: handle_error("Could not find job!"))
        show_job(job)
    elif util.is_user_id(args.id):
        # The first character of the id is dropped to get the user name.
        experiments = fetch_experiments({'user': args.id[1:]})
        show_experiments(experiments, all=args.all, collapsed=not args.long)
    else:
        handle_error("Id does not identify any RiseML entity!")
def check_api_config(api_url, api_key, timeout=180):
    """Try to log in to ``api_url`` with ``api_key``, retrying until timeout.

    The shared ``Configuration`` is temporarily pointed at the given host
    and key and is restored when the function leaves, whether it succeeds
    or fails.

    Args:
        api_url: API host to test.
        api_key: API key to test.
        timeout: Maximum number of seconds to keep retrying failed attempts.

    Returns:
        The value returned by ``client.login_user()`` on success.

    Raises:
        SystemExit: With status 1 when the key is rejected as unauthorized
            or the timeout expires without a successful login.
    """
    print('Trying to login to %s with API key \'%s\' for at most %ss... '
          % (api_url, api_key, timeout), end='', flush=True)
    config = Configuration()
    old_api_host = config.host
    old_api_key = config.api_key['api_key']
    config.host = api_url
    config.api_key['api_key'] = api_key
    try:
        client = AdminApi(ApiClient())
        # A monotonic clock keeps the timeout immune to wall-clock jumps.
        deadline = time.monotonic() + timeout
        while True:
            try:
                cluster_config = client.login_user()
                print('Success!')
                return cluster_config
            except ApiException as exc:
                if exc.reason == 'UNAUTHORIZED':
                    # Retrying cannot help with a rejected key.
                    print(exc.status, 'Unauthorized - wrong api key?')
                    sys.exit(1)
                if time.monotonic() >= deadline:
                    print(exc.status, exc.reason)
                    sys.exit(1)
            except HTTPError as e:
                if time.monotonic() >= deadline:
                    print('Unable to connect to %s ' % api_url)
                    # all uncaught http errors goes here
                    print(e.reason)
                    sys.exit(1)
            # Still within the timeout: wait a second and retry.
            time.sleep(1)
    finally:
        # Restore the shared configuration on every exit path, not only
        # after a successful login.
        config.api_key['api_key'] = old_api_key
        config.host = old_api_host
def check_api_config(api_url, api_key, timeout=180):
    """Wait for a successful API call against ``api_url`` with ``api_key``.

    The shared ``Configuration`` is temporarily pointed at the given host
    and key and is restored when the function leaves, whether it succeeds
    or fails.

    Args:
        api_url: API host to test.
        api_key: API key to test.
        timeout: Maximum number of seconds to keep retrying failed attempts.

    Returns:
        The cluster id derived from the cluster infos via ``get_cluster_id``.

    Raises:
        SystemExit: With status 1 when the key is rejected as unauthorized
            or the timeout expires without a successful call.
    """
    print('Waiting %ss for successful login to %s with API key \'%s\' ...'
          % (timeout, api_url, api_key))
    config = Configuration()
    old_api_host = config.host
    old_api_key = config.api_key['api_key']
    config.host = api_url
    config.api_key['api_key'] = api_key
    try:
        client = AdminApi(ApiClient())
        start = time.time()
        while True:
            try:
                cluster_infos = client.get_cluster_infos()
                cluster_id = get_cluster_id(cluster_infos)
                print('Success! Cluster ID: %s' % cluster_id)
                return cluster_id
            except ApiException as exc:
                if exc.reason == 'UNAUTHORIZED':
                    # Retrying cannot help with a rejected key.
                    print(exc.status, 'Unauthorized - wrong api key?')
                    sys.exit(1)
                if time.time() - start >= timeout:
                    print(exc.status, exc.reason)
                    sys.exit(1)
            except HTTPError as e:
                if time.time() - start >= timeout:
                    # Fixed the "connecto to" typo in the message.
                    print('Unable to connect to %s ' % api_url)
                    # all uncaught http errors goes here
                    print(e.reason)
                    sys.exit(1)
            # Still within the timeout: wait a second and retry.
            time.sleep(1)
    finally:
        # Restore the shared configuration on every exit path, not only
        # after a successful call.
        config.api_key['api_key'] = old_api_key
        config.host = old_api_host
def run(args):
    """Stream logs for an experiment or a job selected by ``args.id``.

    Jobs whose role is ``tf-hrvd-master`` or ``tf-hrvd-worker`` are streamed
    through their experiment, filtered to that job. Without an id, the log
    of the first experiment returned by ``get_experiments`` is streamed.

    Args:
        args: Parsed command-line namespace providing ``id``.
    """
    client = DefaultApi(ApiClient())

    if not args.id:
        experiments = call_api(lambda: client.get_experiments())
        if not experiments:
            handle_error('No experiment logs to show!')
        experiment = call_api(
            lambda: client.get_experiment(experiments[0].short_id))
        stream_experiment_log(experiment)
        return

    if is_experiment_id(args.id):
        stream_experiment(client, args.id)
        return

    if not is_job_id(args.id):
        handle_error("Can only show logs for jobs or experiments!")
        return

    job = call_api(
        lambda: client.get_job(args.id),
        not_found=lambda: handle_error("Could not find job!"))
    if job.role in ['tf-hrvd-master', 'tf-hrvd-worker']:
        # These roles are streamed via the owning experiment's log.
        stream_experiment(client, get_experiment_id(args.id), filter_job=job)
    else:
        stream_job_log(job)
def run_info(args):
    """Print name, key, plan and enabled features of the registered account.

    When the account has no name, a registration hint is printed instead.

    Args:
        args: Parsed command-line namespace (not read by this command).
    """
    # Display names for known feature ids; unknown ids are shown as-is.
    feature_names = {'user_management': 'User Management'}

    client = AdminApi(ApiClient())
    account = call_api(lambda: client.get_account_info())

    if account.name is None:
        print('You have not registered with an account. '
              'Please run ' + bold('riseml account register'))
        return

    backend_info = get_account_info_backend(account.key)
    print('Name: %s' % account.name)
    print('Key: %s' % account.key)

    plan = backend_info['plan']
    # Only the basic plan gets the upgrade hint appended.
    upgrade_text = (
        ' (run ' + bold('riseml account upgrade') + ' to switch)'
        if plan == 'basic' else ''
    )
    print('Plan: %s%s' % (plan.title(), upgrade_text))

    labels = [feature_names.get(f, f) for f in account.enabled_features]
    for label in labels:
        print(' - %s' % label)