def root(self, name, value):
    """Apply a single root-level (global) XT command flag.

    Args:
        name: the flag name (without leading dashes), e.g. "console".
        value: the parsed value supplied for the flag.

    Unrecognized flag names are reported via errors.syntax_error().
    """
    # flags consumed elsewhere ("help" later in dispatch, "quick-start" earlier)
    if name in ("help", "quick-start"):
        return

    if name == "console":
        console.set_level(value)
        return

    if name == "stack-trace":
        utils.show_stack_trace = value
        return

    if name == "new":
        if value and process_utils.can_create_console_window():
            # re-launch this same command in a new console window, swapping the
            # first "--new" for "--echo" so the child echoes the command it runs
            original_cmd = qfe.current_dispatcher.dispatch_cmd
            relaunch_cmd = "xt " + original_cmd.replace("--new", "--echo", 1)
            process_utils.run_cmd_in_new_console(relaunch_cmd)
            errors.early_exit_without_error()
        return

    if name == "echo":
        if value:
            original_cmd = qfe.current_dispatcher.dispatch_cmd
            console.print("xt " + original_cmd, flush=True)
        return

    if name == "prep":
        self.prep_machine_for_controller()
        return

    errors.syntax_error("unrecognized root flag=" + name)
def _list_directories(self, container, path, wc_target, subdirs=0):
    """Build a folder listing of the store as a dict.

    Args:
        container: container name to list, or falsy to list the store root
            (all containers).
        path: container-relative path to list under (must be empty when
            container is falsy).
        wc_target: optional fnmatch-style pattern applied to blob names.
        subdirs: True for unlimited depth, otherwise a max folder "level".

    Returns:
        {"store_name": ..., "folders": [...]}, where folders come from
        _build_folders_from_blobs() (and _get_root_folders() for the root).
    """
    console.diag(
        "_list_directories: container={}, path={}, wc_target={}, subdirs={}"
        .format(container, path, wc_target, subdirs))

    service_name = self.provider.get_service_name()
    dd = {"store_name": "XT Store ({})".format(service_name)}

    if container:
        # get blobs from AZURE for a single container
        blob_list = list(self.provider.list_blobs(
            container, path=path or None, return_names=False))

        if wc_target:
            # apply wildcard filter to blob names
            blob_list = [b for b in blob_list if fnmatch(b.name, wc_target)]

        console.diag("list_blobs returned: len(blobs)={}".format(len(blob_list)))
        folders = self._build_folders_from_blobs(blob_list, container, path, subdirs)
    else:
        # listing all containers is a special case
        if path:
            errors.syntax_error(
                "path can not be set when the container is set to '/'")

        root_folder, ws_names = self._get_root_folders()
        folders = [root_folder]

        if subdirs:
            for ws_name in ws_names:
                # get blobs from AZURE for each workspace container
                console.diag("reading blobs for ws={}".format(ws_name))
                ws_blobs = list(self.provider.list_blobs(
                    ws_name, path=None, return_names=False))
                folders += self._build_folders_from_blobs(
                    ws_blobs, ws_name, "", subdirs)

    if subdirs is not True:
        # subdirs is an int depth: keep only folders at or above that level
        folders = [f for f in folders if f["level"] <= subdirs]

    dd["folders"] = folders
    return dd
def _get_container_path_target(self, path): ''' This helper function is the first step in converting all of our _xxx (internal) methods to accept as parameters: - container (optional - None or the Azure container name to use) - path (optional - None, or the container-relative path to the target store object) - wc_target (optional - the last node of the path containing the wildcard characters) The "container" should be None, or a simple container name. When set to None, the set of all container names is returned. The "path" should be None or a set of blob-folder-names separated by "/" chars. When set to None, the outer path of the container is targeted. The "wc_target" can be None, or a string containing the characters: *, ?, or **. ''' container = self.container if path: #console.print("path=", path) # path specified if path.startswith("blob-store://"): path = "/" + path[13:] if path.startswith("/"): # extract new value for container from left dir node path = path[1:] parts = path.split("/") container = parts[0] path = "/".join(parts[1:]) #console.print("container=", container, ", path=", path) elif self.base_path: # merge them container, base_path, path = self._process_parent_path( container, self.base_path, path) if base_path and path: path = base_path + "/" + path elif base_path: path = base_path else: # path not specified path = self.base_path #console.print("path=", path) if "*" in path or "?" in path: wc_target = os.path.basename(path) path = os.path.dirname(path) if "*" in path or "?" in path: errors.syntax_error( "wildcard characters are not allowed in directory part of store paths" ) else: wc_target = None #console.print("container=", container, ", path=", path, ", wc_target=", wc_target) return container, path, wc_target
def upload_file(self, fn, source_fn, progress_callback=None):
    """Upload the local file `source_fn` to the store path `fn`.

    Wildcards are rejected in `fn`. Returns whatever the store provider's
    create_blob_from_path() returns.
    """
    container, blob_path, wildcard = self._get_container_path_target(fn)
    if wildcard:
        errors.syntax_error("wildcard characters not allowed in filename")

    provider = self.store.provider
    return provider.create_blob_from_path(
        container, blob_path, source_fn, progress_callback=progress_callback)
def syntax_error(self, msg):
    """Print `msg` and the current command's syntax, then signal a syntax error.

    When self.raise_syntax_exception is set, errors.syntax_error() is called
    first; errors.syntax_error_exit() is then called.
    """
    console.print(msg)
    self.show_current_command_syntax()

    should_raise = self.raise_syntax_exception
    if should_raise:
        errors.syntax_error("syntax error")

    errors.syntax_error_exit()
def extract_dd_from_cmdline(self, cmd_line, option_prefix):
    '''
    Pull hyperparameter-search options out of a command line.

    args:
        cmd_line: <file> <args> <options>
        option_prefix: usually "-" or "--" (denotes the start of an argument)

    processing:
        split cmd_line on option_prefix; each option written as name=[...] or
        name [...] is parsed into a hp search directive (via
        hp_helper.parse_hp_dist) and removed from the command line.

    return:
        (dd, new_cmd_line): the dict of hp search directives, and the
        command line rebuilt from the options that were NOT hp directives.
    '''
    leading_text, *option_chunks = cmd_line.split(option_prefix)
    new_cmd_line = leading_text   # text before the first option is always kept
    dd = {}

    for chunk in option_chunks:
        opt_name = None
        if "=" in chunk:
            pieces = chunk.split("=")
            if len(pieces) == 2:
                opt_name, opt_value = pieces
        elif " " in chunk:
            opt_name, opt_value = chunk.split(" ", 1)

        is_hp_directive = False
        if opt_name:
            opt_name = opt_name.strip()
            if "[" in opt_name:
                errors.syntax_error(
                    "option name and value must be separated by a space or '=': {}"
                    .format(opt_name))

            # user may have added these thinking there were needed
            opt_value = utils.remove_surrounding_quotes(opt_value)

            if opt_value.startswith("[") and opt_value.endswith("]"):
                inner = opt_value[1:-1].strip()
                if "$" not in inner:
                    # plain value list: convert each item to its python value
                    inner = [utils.get_python_value_from_text(v) for v in inner.split(",")]
                dd[opt_name] = hp_helper.parse_hp_dist(inner)
                is_hp_directive = True

        if not is_hp_directive:
            new_cmd_line += " " + option_prefix + chunk

    return dd, new_cmd_line
def rerun(self, run_name, workspace, response):
    '''
    Re-submit a previous run, letting the user edit (or automate) its xt command.

    args:
        run_name: name of the run to rerun.
        workspace: workspace of the run (may be overridden by a "ws/run" name).
        response: if set, used instead of prompting; any "$cmd" in it is
            replaced by the run's original xt command.

    The run's "before" files are downloaded into a temp dir, and the (edited)
    command is dispatched from that dir; the previous cwd is always restored.
    '''
    # NOTE: validate_run_name() call must be AFTER we call process_named_options()
    run_name, workspace = run_helper.parse_run_name(workspace, run_name)

    # extract "prompt" and "args" from cmdline
    cmdline, xt_cmdline, _box_name, _parent_name, _node_index = self.get_info_for_run(
        workspace, run_name)

    prompt = ""
    if xt_cmdline:
        args = " " + xt_cmdline
    else:
        # legacy run; just use subset of xt cmd
        args = " xt " + cmdline

    console.print("edit/accept xt cmd for {}/{}".format(workspace, run_name))

    if response:
        # allow user to supplement the cmd with automation
        if "$cmd" in response:
            response = response.replace("$cmd", args)
        console.print(response)
    else:
        response = pc_utils.input_with_default(prompt, args)

    # keep RERUN cmd simple by reusing parse_python_or_run_cmd()
    full_cmd = response.strip()

    if not full_cmd.startswith("xt "):
        errors.syntax_error(
            "command must start with 'xt ': {}".format(full_cmd))

    # this temp dir cannot be removed immediately after job is submitted (TBD why)
    tmp_dir = file_utils.make_tmp_dir("rerun_cmd")

    job_id = self.store.get_job_id_of_run(workspace, run_name)
    capture.download_before_files(self.store, job_id, workspace, run_name,
                                  tmp_dir, silent=True, log_events=False)

    # move to tmp_dir so files get captured correctly
    prev_cwd = os.getcwd()
    os.chdir(tmp_dir)

    try:
        # recursive invoke of QFE parser to parse command (original + user additions);
        # split() (not split(" ")) so doubled spaces from editing don't produce empty args
        args = full_cmd.split()[1:]  # drop the "xt" at beginning
        inner_dispatch(args, is_rerun=True)
    finally:
        # change back to original dir
        os.chdir(prev_cwd)
def get_job_number(job):
    """Return the integer part of a job name such as "job123".

    An import prefix separated by "_" (e.g. "<prefix>_job123") is allowed;
    only the last "_"-separated node is used.
    """
    if not is_job_id(job):
        errors.syntax_error("illegal job name, must start with 'job'")

    last_node = job.rsplit("_", 1)[-1]
    digits = last_node[3:] if last_node.startswith("job") else last_node
    return int(digits)
def get_runs_by_boxes_from_job(self, job_id):
    """Return the "runs_by_box" entry from a job's info file.

    Args:
        job_id: the job id (converted with str() for validation).

    Returns:
        job_info["runs_by_box"] as loaded from the job's JSON info file.

    Raises (via errors.syntax_error) if job_id is not a valid job id.
    """
    if not job_helper.is_job_id(str(job_id)):
        errors.syntax_error("not a valid job id: " + str(job_id))

    text = self.store.read_job_info_file(job_id)
    job_info = json.loads(text)
    # previously returned the undefined name `runs_by_boxes` (NameError)
    return job_info["runs_by_box"]
def process_filter_list(self, filter_dict, filter_exp_list, report2filter):
    '''
    Translate filter expressions into entries of a mongo-style filter dict.

    Each expression is a dict with "prop", "op" and "value" keys, for forms like:
        - <prop name> <relational operator> <value>

    special values:
        - $exists   (does property exist)
        - $none     (None, useful for matching properties with no value)
        - $empty    (empty string)
        - $true     (True)
        - $false    (False)

    filter_dict is updated in place (one entry per prop; later expressions
    for the same prop overwrite earlier ones).
    '''
    # operators that map directly onto a single mongo operator
    mongo_ops = {
        "<": "$lt",
        "<=": "$lte",
        ">": "$gt",
        ">=": "$gte",
        "!=": "$ne",
        "<>": "$ne",
        ":regex:": "$regex",
        ":exists:": "$exists",
    }

    for exp in filter_exp_list:
        prop = exp["prop"]
        op = exp["op"]
        value = self.expand_special_symbols(exp["value"])

        # translate report column name to filter name, if needed
        if prop in report2filter:
            prop = report2filter[prop]

        value = utils.make_numeric_if_possible(value)

        if op in ("=", "=="):
            filter_dict[prop] = value
        elif op in mongo_ops:
            filter_dict[prop] = {mongo_ops[op]: value}
        elif op == ":mongo:":
            # raw filter dict, but we need to translate quotes and load as json
            filter_dict[prop] = json.loads(value.replace("`", "\""))
        else:
            errors.syntax_error(
                "filter operator not recognized/supported: {}".format(op))
def get_rightmost_run_num(run):
    """Split a run name into its rightmost number and the prefix before it.

    Examples:
        "run12"     -> (12, "run")
        "run12.3"   -> (3, "run12.")
        "run12.3.4" -> (4, "run12.3.")

    Returns:
        (num, prefix) where prefix + str(num) rebuilds the run name.

    Raises (via errors.syntax_error) when the name doesn't start with "run".
    """
    if not run.startswith("run"):
        errors.syntax_error("Illegal run name, must start with 'run'")

    if "." in run:
        # split on the LAST dot so nested child-run names don't raise
        prefix, num = run.rsplit(".", 1)
        prefix += "."
    else:
        prefix, num = "run", run[3:]

    return int(num), prefix
def create_demo(self, destination, response, overwrite):
    '''
    Copy the XT demo files into `destination` and ensure the "xt-demo"
    workspace exists.

    This command will remove the specified destination directory if it
    exists (prompting the user for approval, or using `response`).
    Specifying the current directory (or one of its parents) as the
    destination produces an error.

    args:
        destination: output directory for the demo files (required).
        response: pre-supplied answer for the delete prompt (or None).
        overwrite: passed to import_workspace() for the xt-demo archive.
    '''
    # set up from_dir
    from_dir = os.path.join(file_utils.get_xtlib_dir(), "demo_files")

    # set up dest_dir
    dest_dir = destination
    if not dest_dir:
        errors.syntax_error("An output directory must be specified")

    # refuse to delete the directory we are running in (or any parent of it)
    dest_abs = os.path.abspath(dest_dir)
    cwd_abs = os.path.abspath(os.getcwd())
    if cwd_abs == dest_abs or cwd_abs.startswith(dest_abs.rstrip(os.sep) + os.sep):
        errors.syntax_error(
            "The current directory (or a parent of it) cannot be used as the demo destination")

    console.print("creating demo files at: {}".format(dest_abs))

    create = True
    if os.path.exists(dest_dir):
        answer = pc_utils.input_response(
            "'{}' already exists; OK to delete? (y/n): ".format(dest_dir),
            response)
        create = answer == "y"

    if create:
        file_utils.ensure_dir_deleted(dest_dir)
        shutil.copytree(from_dir, dest_dir)

    if not self.store.does_workspace_exist("xt-demo"):
        # import xt-demo workspace from archive file
        console.print(
            "importing xt-demo workspace (usually takes about 30 seconds)")

        impl_storage_api = ImplStorageApi(self.config, self.store)
        fn_archive = os.path.join(file_utils.get_xtlib_dir(), "demo_files",
                                  "xt-demo-archive.zip")
        impl_storage_api.import_workspace(fn_archive, "xt-demo", "xtd",
                                          overwrite=overwrite,
                                          show_output=False)
def parse_job_helper(store,
                     job,
                     job_list,
                     actual_ws,
                     validate=True,
                     can_mix=True):
    """Append `job` to `job_list` and return the workspace to carry forward.

    When can_mix is False, the job is validated via validate_job_name_with_ws()
    and a syntax error is raised if its workspace differs from actual_ws.
    When can_mix is True the job name is not validated here and actual_ws
    is returned unchanged.
    """
    if can_mix:
        ws = actual_ws
    else:
        ws = validate_job_name_with_ws(store, job, validate)
        if actual_ws and actual_ws != ws:
            errors.syntax_error(
                "Cannot mix job_names from different workspaces for this command"
            )

    job_list.append(job)
    return ws or actual_ws
def validate_job_name_with_ws(store, job_name, validate):
    """Return the workspace of `job_name` (lower-cased before lookup).

    Reports an illegal name via errors.syntax_error(); when `validate` is set,
    reports a job with no workspace via errors.store_error().
    """
    name = job_name.lower()
    if not is_job_id(name):
        return errors.syntax_error("Illegal job name: {}".format(name))

    ws = store.get_job_workspace(name)
    if validate and not ws:
        errors.store_error("job '{}' does not exist".format(name))
    return ws
def parse_run_list(store, workspace, runs, validate=True):
    """Expand a list of run names / ranges into individual runs.

    Each entry of `runs` may be:
        - "run12" or "run12.3"              a single run
        - "run3-run7" / "run12.1-run12.4"   an inclusive range (prefixes must match)
        - "<ws>/<run or range>"             either form, workspace-qualified

    All workspace-qualified entries must name the same workspace.

    Returns:
        (run_names, actual_ws): run names collected by parse_run_helper(),
        and the resolved workspace (`workspace` when `runs` is empty).
    """
    run_names = []
    if not runs:
        return run_names, workspace

    actual_ws = None
    for run in runs:
        run = correct_slash(run.strip())

        # separate an optional "<ws>/" qualifier; checks below must only look
        # at the run part (workspace names may contain dashes, e.g. "xt-demo")
        ws_prefix = ""
        run_part = run
        if "/" in run:
            ws, run_part = run.split("/")
            if actual_ws and actual_ws != ws:
                errors.syntax_error(
                    "Cannot mix run_names from different workspaces for this command"
                )
            ws_prefix = ws + "/"

        if not run_part.startswith("run"):
            errors.argument_error("run name", run)

        if "-" in run_part:
            # parse run range
            low, high = run_part.split("-")
            low, low_prefix = get_rightmost_run_num(low)
            high, high_prefix = get_rightmost_run_num(high)

            if low_prefix != high_prefix:
                errors.syntax_error(
                    "for run name range, prefixes must match: {} vs. {}".
                    format(low_prefix, high_prefix))

            for rx in range(low, high + 1):
                rxx = ws_prefix + low_prefix + str(rx)
                actual_ws = parse_run_helper(store, workspace, rxx, validate,
                                             actual_ws, run_names)
        else:
            actual_ws = parse_run_helper(store, workspace, run, validate,
                                         actual_ws, run_names)

    return run_names, actual_ws
def range_runs(self, runs_dict, range):
    """Compute a (lower, upper) envelope across all runs, element-wise on axis 0.

    Args:
        runs_dict: mapping whose values are equal-length sequences of numbers.
        range: "min-max" (element-wise min/max), "std" (mean +/- population
            std) or "error" (mean +/- standard error of the mean, scipy).

    Returns:
        (min_values, max_values) as numpy arrays.
    """
    runs = list(runs_dict.values())

    if range == "min-max":
        return np.min(runs, axis=0), np.max(runs, axis=0)

    if range == "std":
        center = np.mean(runs, axis=0)
        spread = np.std(runs, axis=0)
    elif range == "error":
        from scipy import stats
        center = np.mean(runs, axis=0)
        spread = stats.sem(runs, axis=0)
    else:
        errors.syntax_error("unrecognized range value: {}".format(range))

    return center - spread, center + spread
def keysend(self, box):
    """Send the local public key to a remote box.

    syntax: xt keysend <box name>

    Local boxes and "azure-batch" are rejected via errors.syntax_error().
    """
    if not box:
        errors.syntax_error("must specify a box name/address")

    box_info = box_information.get_box_addr(self.config, box, self.store)
    box_addr = box_info["box_addr"]

    is_local_target = pc_utils.is_localhost(box, box_addr) or box == "azure-batch"
    if is_local_target:
        errors.syntax_error(
            "must specify a remote box name or address (e.g., xt keysend [email protected]"
        )

    console.print(
        "this will require 2 connections to the remote host, so you will be prompted for a password twice"
    )
    if self.core.keysend(box):
        console.print("public key successfully sent.")
def calc_actual_layout(self, count, layout):
    """Resolve an "RxC" layout string into concrete (rows, cols).

    Either dimension may be omitted ("2x" or "x3"); the missing one is
    computed so that all `count` plots fit.

    Args:
        count: number of plots to place.
        layout: "RxC", "Rx" or "xC".

    Returns:
        (rows, cols) as ints.

    Reports a syntax error for a malformed layout (no "x", or neither
    dimension given) and a combo error if the cells can't hold `count` plots.
    """
    layout_error = "layout string must be of form RxC (R=# rows, C=# cols)"
    if "x" not in layout:
        errors.syntax_error(layout_error)

    r_text, c_text = layout.split("x", 1)
    r_text, c_text = r_text.strip(), c_text.strip()

    if not r_text and not c_text:
        # "x" alone gives no dimension to derive the other one from
        errors.syntax_error(layout_error)

    if r_text:
        rows = int(r_text)
        cols = int(c_text) if c_text else math.ceil(count / rows)
    else:
        cols = int(c_text)
        rows = math.ceil(count / cols)

    full_count = rows * cols
    if full_count < count:
        errors.combo_error(
            "too many plots ({}) for layout cells ({})".format(
                count, full_count))

    return rows, cols
def validate_run_name(store,
                      ws,
                      run_name,
                      error_if_invalid=True,
                      parse_only=False):
    """Normalize a run name and (optionally) check that the run exists.

    A "ws/run" name overrides `ws`. The run part is lower-cased. Unless
    parse_only is set or the name contains "*", existence is checked with
    store.mongo.does_run_exist(); a missing run is reported via
    errors.store_error(), or (None, None, None) is returned when
    error_if_invalid is False.

    Returns:
        (ws, run_name, "ws/run_name")
    """
    run_name = correct_slash(run_name)

    if "/" in run_name:
        parts = run_name.split("/")
        if len(parts) != 2:
            errors.syntax_error("invalid format for run name: " + run_name)
        ws, run_name = parts

    run_name = run_name.lower()

    needs_lookup = not parse_only and "*" not in run_name
    if needs_lookup and not store.mongo.does_run_exist(ws, run_name):
        if not error_if_invalid:
            return None, None, None
        errors.store_error(
            "run '{}' does not exist in workspace '{}'".format(run_name, ws))

    return ws, run_name, ws + "/" + run_name
def monitor(self,
            name,
            escape=None,
            jupyter=None,
            log_name=None,
            node_index=None,
            sleep=1,
            workspace=None):
    """Monitor a job node or a run (resolved to its job and node).

    With `jupyter` set, delegates to monitor_with_jupyter(). A run name is
    mapped to its job_id/node_index via its run record. Any other name is
    reported via errors.syntax_error().
    """
    if jupyter:
        return self.monitor_with_jupyter(workspace, name)

    if job_helper.is_job_id(name):
        job_id = name
    elif name.startswith("run"):
        rr = run_helper.get_run_record(self.store, workspace, name)
        job_id, node_index = rr["job_id"], rr["node_index"]
    else:
        errors.syntax_error("name must be a job or run name: {}".format(name))
        return None

    return self.monitor_job_node(job_id, jupyter, sleep, node_index,
                                 log_name, escape)
def get_config_template(self, template):
    """Return the text of a starter local xt_config.yaml.

    Args:
        template: "empty" (or None/""), "philly", "batch", "aml", "pool" or "all".

    Returns:
        A comment header followed by the selected sections of the default
        config template (merged via copy_and_merge_sections()).
    """
    # load default config file as text, lines and parsed sections
    fn = get_default_config_template_path()
    default_text = file_utils.read_text_file(fn)
    default_lines = default_text.split("\n")
    sections = yaml.safe_load(default_text)

    # service sections common to every merged template (order matters)
    shared = [
        "external-services.phoenixkeyvault",
        "external-services.phoenixmongodb",
        "external-services.phoenixregistry",
        "external-services.phoenixstorage",
        "xt-services",
    ]

    # template name -> (header, section names, update_keys or None)
    merged_templates = {
        "philly": (
            "# local xt_config.yaml for Philly compute service\n\n",
            ["external-services.philly", "external-services.philly-registry"] + shared +
            ["compute-targets.philly", "setups.philly", "dockers.philly-pytorch", "general"],
            {"xt-services.target": "philly"}),
        "batch": (
            "# local xt_config.yaml (for Azure Batch compute services)\n\n",
            ["external-services.phoenixbatch"] + shared +
            ["compute-targets.batch", "azure-batch-images", "general"],
            {"xt-services.target": "batch"}),
        "aml": (
            "# local xt_config.yaml (for Azure ML compute service)\n\n",
            ["external-services.phoenixaml"] + shared +
            ["compute-targets.aml", "aml-options", "general"],
            {"xt-services.target": "aml"}),
        "pool": (
            "# local xt_config.yaml (for local machine and Pool compute service)\n\n",
            shared +
            ["compute-targets.local", "compute-targets.local-docker", "boxes", "setups.local",
             "dockers.pytorch-xtlib", "dockers.pytorch-xtlib-local", "general"],
            None),
    }

    if not template or template == "empty":
        hdr = ("# local xt_config.yaml\n"
               "# uncomment the below lines to start populating your config file\n\n")
        text = ("#general:\n"
                " #workspace: 'ws1'\n"
                " #experiment: 'exper1'\n")
    elif template in merged_templates:
        hdr, section_names, update_keys = merged_templates[template]
        if update_keys is None:
            text = self.copy_and_merge_sections(sections, section_names)
        else:
            text = self.copy_and_merge_sections(
                sections, section_names, update_keys=update_keys)
    elif template == "all":
        hdr = "# local xt_config.yaml (for all compute services)\n\n"
        text = "\n".join(default_lines)
    else:
        errors.syntax_error(
            "unrecognized --create value: {}".format(template))

    return hdr + text
def validate_job_name(job_id):
    """Report a syntax error if a non-empty `job_id` is not a valid job id.

    Empty/falsy values are accepted silently.
    """
    if not job_id:
        return
    job_text = str(job_id)
    if not is_job_id(job_text):
        errors.syntax_error("job id must start with 'job': " + job_text)
def config_error(self, msg):
    """Report a config-file problem as a syntax error naming self.config_fn."""
    errors.syntax_error("Error in XT config file: {} (config: {})".format(
        msg, self.config_fn))
def plot_inner(self,
               ax,
               run_name,
               col,
               x_col,
               x_label,
               line_index,
               x_values,
               y_values,
               color,
               alpha,
               use_y_label,
               y2_values=None,
               err_values=None):
    """Draw one line (or shaded range) for `run_name`/`col` onto axes `ax`.

    Args:
        ax: matplotlib axes to draw on.
        x_values: x data, or None to plot against the sample index.
        y_values: y data (lower bound when y2_values is given).
        y2_values: if given, draws a fill_between range from y_values to y2_values.
        err_values: optional y error bars for the line plot.
        use_y_label: when True, the line is labeled from self.legend_titles.

    Only self.plot_type == "line" is supported; other values are reported
    via errors.syntax_error().
    """
    # NOTE(review): seaborn is not referenced below; kept from the original
    # in case the import is relied on for its side effects -- confirm.
    import seaborn as sns
    from matplotlib.ticker import MaxNLocator

    if x_values is None:
        # no x data: plot against the sample index
        x_values = range(len(y_values))
    else:
        ax.set_xlabel(x_label)

    console.detail("x_values=", x_values)
    console.detail("y_values=", y_values)
    console.detail("y2_values=", y2_values)

    num_y_ticks = 10
    ax.get_yaxis().set_major_locator(MaxNLocator(num_y_ticks))

    if use_y_label:
        line_title = self.legend_titles[line_index % len(self.legend_titles)]
        line_title = self.fixup_text(line_title, run_name, col)
    else:
        line_title = None

    cap_size = 5
    is_range_plot = y2_values is not None

    # our default attributes
    kwargs = {"label": line_title, "color": color, "alpha": alpha}
    if not is_range_plot:
        # capsize is an errorbar option; not passed to fill_between
        kwargs["capsize"] = cap_size

    # let user override
    if self.plot_args and not is_range_plot:
        for name, value in self.plot_args.items():
            kwargs[name] = utils.make_numeric_if_possible(value)

    if self.plot_type == "line":
        if is_range_plot:
            # RANGE plot
            ax.fill_between(x_values, y_values, y2_values, **kwargs)
        else:
            # X/Y LINE plot (x_values is never None here; it was defaulted above)
            ax.errorbar(x_values, y_values, yerr=err_values, **kwargs)
    else:
        # for now, we can get lots of milage out of line plot (errorbars, scatter, scatter+line)
        # so keep things simple and just support 1 type well
        errors.syntax_error("unknown plot type={}".format(self.plot_type))

    if self.plot_titles:
        plot_title = self.plot_titles[line_index % len(self.plot_titles)]
        ax.set_title(self.fixup_text(plot_title, run_name, col))

    if self.show_legend:
        # one legend() call; legend_args (if any) customize it
        if self.legend_args:
            ax.legend(**self.legend_args)
        else:
            ax.legend()
def delete_file(self, filename):
    """Delete the blob at store path `filename`; wildcards are rejected.

    Returns True after the provider's delete_blob() call.
    """
    container, blob_path, wildcard = self._get_container_path_target(filename)
    if wildcard:
        errors.syntax_error("wildcard not supported here: " + filename)

    provider = self.store.provider
    provider.delete_blob(container, blob_path)
    return True
def process_args(self, args):
    """Resolve the target script/command and compute target for a run.

    Reads "script", "script_args", "code_upload", "parent_script", "target"
    (and the presence of "is_rerun") from `args`, and writes back "target",
    "compute_def", "service_type", "box" and "pool".

    Returns:
        (service_type, cmd_parts, ps_path, parent_script, target_file,
         run_script, run_cmd_from_script, compute, compute_def)
    """
    run_script = None
    parent_script = None
    run_cmd_from_script = None

    target_file = args["script"]
    target_args = args["script_args"]
    code_upload = args["code_upload"]

    # user may have wrong slashes for this OS
    target_file = file_utils.fix_slashes(target_file)

    if os.path.isabs(target_file):
        errors.syntax_error("path to app file must be specified with a relative path: {}".format(target_file))

    is_rerun = "is_rerun" in args
    if is_rerun:
        # will be running from script dir, so remove any path to script file
        self.script_dir = os.path.dirname(target_file)
        target_file = os.path.basename(target_file)

    if target_file.endswith(".py"):
        # PYTHON target
        cmd_parts = ["python", "-u", target_file]
    else:
        cmd_parts = [target_file]

    if target_args:
        # split on unquoted spaces
        cmd_parts += utils.cmd_split(target_args)

    if target_file == "docker":
        self.is_docker = True

    if not self.is_docker and code_upload and not os.path.exists(target_file):
        errors.env_error("script file not found: {}".format(target_file))

    ps_path = args["parent_script"]
    if ps_path:
        parent_script = file_utils.read_text_file(ps_path, as_lines=True)

    if target_file.endswith(".bat") or target_file.endswith(".sh"):
        # a RUN SCRIPT was specified as the target
        run_script = file_utils.read_text_file(target_file, as_lines=True)
        run_cmd_from_script = scriptor.get_run_cmd_from_script(run_script)

    compute = args["target"]
    box_def = self.config.get("boxes", compute, suppress_warning=True)
    setup = utils.safe_value(box_def, "setup")

    compute_def = self.config.get_compute_def(compute)
    if compute_def:
        # must be defined in [compute-targets]
        if not "service" in compute_def:
            errors.config_error("compute target '{}' must define a 'service' property".format(compute))

        service = compute_def["service"]
        if service in ["local", "pool"]:
            # its a list of box names
            boxes = compute_def["boxes"]
            if len(boxes) == 1 and boxes[0] == "localhost":
                pool = None
                box = "local"
            else:
                pool = compute
                box = None
            service_type = "pool"
        else:
            # it a set of compute service properties
            pool = compute
            box = None
            service_type = self.config.get_service_type(service)

    elif box_def:
        # translate single box name to a compute_def
        box = compute
        pool = None
        service_type = "pool"
        # key must be the string "setup" (was the `setup` variable's value)
        compute_def = {"service": service_type, "boxes": [box], "setup": setup}
    else:
        errors.config_error("unknown target or box: {}".format(compute))

    args["target"] = compute
    args["compute_def"] = compute_def
    args["service_type"] = service_type

    # for legacy code
    args["box"] = box
    args["pool"] = pool

    return service_type, cmd_parts, ps_path, parent_script, target_file, run_script, run_cmd_from_script, \
        compute, compute_def
def _check_ws_name(self, ws_name): if not self.is_legal_workspace_name(ws_name): errors.syntax_error( "error: Illegal Azure workspace name (must be >= 3 alphanumeric chars, dashes OK, no space or underscore chars)" )
def generate_hparam_args(self, orig_cmd_parts, max_gen=None, search_type="grid"): ''' this is the main function for this class. it parses the specified argument name/values of 'orig_cmd_parts', converts each search list/range into a hyperparameter generator, and then generates a set of cmd_parts that comprise the hyperparameter search. returns: 'arg_sets' - a set of command line argument VALUES (one for each run) 'cmd_parts' - a template to be used to create a command line for the app (when applied to one of the arg sets) ''' cmd_parts = copy.copy(orig_cmd_parts) sweeps_text = "" # cmdline = " ".join(sys.argv[1:]) # console.print("ORIG cmdline: ", cmdline) hp_sets = {} last_part = None for i, part in enumerate(cmd_parts): part = part.replace('"', '') # remove double quoted sub-parts #console.print("part=", part) found = False dist_name = None ''' we support two basic forms of specifying hparam distributions to search: name=[values] name=@disttype(values) the name, the "=", and the right-hand expression can be seen all in one part, or in 2 parts (no "="), or in 3 parts. 
''' if part == "=": # skip over optional "=" in its own part continue if "=[" in part and "]" in part: index = part.index("=[") name = part[:index] part = part[index + 2:] if part.endswith("]"): part = part[:-1].strip() cmd_parts[i] = cmd_parts[i][0:index + 1] + "{}" #console.print("part=", part, ", cmd_part=", cmd_parts[i]) found = True elif "=@" in part and "(" in part and part.endswith(")"): index = part.index("=@") name = part[:index] part = part[index + 2:] dist_name, value = part.split("(") if not dist_name in constants.distribution_types: errors.syntax_error("Unsupported distribution type: " + dist_name) part = value[:-1].strip() # remove ending paren cmd_parts[i] = "{}" found = True elif part.startswith("[") and part.endswith("]"): part = part[1:-1].strip() name = last_part cmd_parts[i] = "{}" found = True elif part.startswith("@") and "(" in part and part.endswith(")"): part = part[1:] # skip over "@" dist_name, value = part.split("(") if not dist_type in constants.distribution_types: errors.syntax_error("Unsupported distribution type: " + dist_name) name = last_part part = value[:-1].strip() # remove ending paren cmd_parts[i] = "{}" found = True if found: hp_set = self.parse_hp_set(part, dist_name, search_type) if not self.collect_only: text = self.hp_set_to_sweeps_line(name, hp_set) if text: sweeps_text += text + "\n" #console.print("hp_set=", hp_set) if hp_set: hp_sets[name] = hp_set last_part = part if sweeps_text: sweeps_text = "# sweeps.txt: generated from xt command line hyperparameter arguments\n" + \ "# note: this file is not sampled from directly\n" + \ sweeps_text # set the cycle len of each hp_set cycle_len = 1 for hp_set in hp_sets.values(): cycle_len = hp_set.set_cycle_len(cycle_len) # generate arg_sets from hp_sets arg_sets = [] #console.print("hp_sets=", hp_sets) using_hp = len(hp_sets) > 0 if hp_sets and not self.collect_only: if max_gen is None: max_gen = cycle_len else: max_gen = int(max_gen) while len(arg_sets) < max_gen: values = [] 
for name, hp_set in hp_sets.items(): value = hp_set.next() #console.print("generate value: ", value) values.append(value) arg_sets.append(values) return using_hp, hp_sets, arg_sets, cmd_parts, sweeps_text
def _get_creds_from_login(self, authentication, reason=None):
    """Log in to Azure interactively and load XT credentials from the key vault.

    Args:
        authentication: "auto", "browser" or "device-code" ("auto" picks
            "browser" when pc_utils.has_gui() is true, else "device-code").
        reason: unused here.

    Side effects: sets self.client (a SecretClient for self.vault_url),
    applies the "xt-keys" secret via apply_creds(), and adds
    "xt_server_cert" and "object_id" entries to self.keys.

    Returns:
        self.keys serialized as a JSON string.
    """
    # use normal Key Value
    from azure.keyvault.secrets import SecretClient

    if authentication == "auto":
        authentication = "browser" if pc_utils.has_gui() else "device-code"

    if authentication == "browser":
        console.print("authenticating with azure thru browser... ", flush=True, end="")
        from azure.identity import InteractiveBrowserCredential

        cred_kwargs = {}
        if self.azure_tenant_id is not None:
            cred_kwargs["tenant_id"] = self.azure_tenant_id
        credential = InteractiveBrowserCredential(**cred_kwargs)

    elif authentication == "device-code":
        from azure.identity import DeviceCodeCredential
        from azure.identity._constants import AZURE_CLI_CLIENT_ID

        console.print(
            "using device-code authorization (Azure AD currently requires 2-4 authenications here)"
        )
        cred_kwargs = {"client_id": AZURE_CLI_CLIENT_ID}
        if self.azure_tenant_id is not None:
            cred_kwargs["tenant_id"] = self.azure_tenant_id
        credential = DeviceCodeCredential(**cred_kwargs)

    else:
        errors.syntax_error(
            "unrecognized authentication type '{}'".format(authentication))

    # NOTE(review): azure-identity's get_token() normally requires at least one
    # scope argument; confirm this call works with the installed version.
    token = credential.get_token().token

    # get keys from keyvault
    self.client = SecretClient(self.vault_url, credential=credential)
    key_text = self.get_secret_live("xt-keys")
    console.print("authenticated successfully", flush=True)

    xt_server_cert = self.get_secret_live("xt-servercert")

    # write all our creds to self.keys
    self.apply_creds(key_text)
    self.keys["xt_server_cert"] = xt_server_cert
    self.keys["object_id"] = self.get_me_graph_property(token, "id")

    # return creds as json string
    return json.dumps(self.keys)