def adjust_job_for_direct_run(self, job_id, job_runs, cmds, using_hp, experiment, service_type, snapshot_dir, search_style, args):
    '''Write the first run's context file into the snapshot dir.

    Returns a one-element list containing the path of the context file written.
    '''
    # the context file goes into the snapshot (first run's context)
    context_fn = "{}/{}".format(snapshot_dir, constants.FN_RUN_CONTEXT)
    file_utils.ensure_dir_exists(file=context_fn)

    first_box_runs = job_runs[0]
    core = self.create_context_file_core(first_box_runs, 0, job_id, using_hp,
        app_info=None, exper_name=experiment, args=args)

    # serialize just the first run's context as pretty-printed JSON
    payload = json.dumps(core["runs"][0], indent=4)
    with open(context_fn, "wt") as outfile:
        outfile.write(payload)

    utils.copy_to_submit_logs(args, context_fn)
    return [context_fn]
def __init__(self, store, run_dir, mirror_dest, wildcard_path, grok_url, ws_name, run_name):
    '''Watch a directory (optionally filtered by a wildcard, e.g.
    "*.tfevents.*") and mirror file changes to the given destination.'''
    self.run_dir = run_dir

    # normalize the wildcard path: expand ~, force forward slashes,
    # and anchor relative paths under run_dir
    norm_path = os.path.expanduser(wildcard_path).replace("\\", "/")
    if not norm_path.startswith("/"):
        norm_path = os.path.join(run_dir, norm_path)

    # split into the directory to watch and the wildcard filter (if any)
    if "*" in norm_path:
        watch_dir = os.path.dirname(norm_path)
        wildcard = os.path.basename(norm_path)
    else:
        watch_dir, wildcard = norm_path, None

    watch_dir = file_utils.fix_slashes(watch_dir)
    console.print("MirrorWorker: path={}, wildcard={}".format(watch_dir, wildcard))

    # in case program will create dir, but it hasn't yet been created
    file_utils.ensure_dir_exists(watch_dir)

    self.event_handler = MyHandler(store, mirror_dest, grok_url, ws_name, run_name, watch_dir, wildcard)
    self.observer = Observer()
    self.observer.schedule(self.event_handler, watch_dir, recursive=True)
def set_console_fn(self, console_fn):
    '''Record the console log path (with ~ expanded), removing any stale copy.'''
    expanded_fn = os.path.expanduser(console_fn)
    self.console_fn = expanded_fn

    # start fresh: delete any leftover file from a previous run
    if os.path.exists(expanded_fn):
        os.remove(expanded_fn)

    file_utils.ensure_dir_exists(file=expanded_fn)
def poll_for_tensorboard_files(self, last_changed, blob_path, start_index, tb_path, run_name):
    '''Poll run storage for new/changed tensorboard event files and download them.

    Args:
        last_changed: dict of blob rel-path -> last-modified timestamp; updated in place
            and used to detect which blobs changed since the previous poll
        blob_path: storage path of the run's output dir to scan
        start_index: index into blob.name where the run-relative part begins
        tb_path: local dest path template; may contain a "{logdir}" placeholder
        run_name: name of the run whose files are downloaded

    Returns:
        the number of files downloaded during this poll
    '''
    # get all blobs in the run's output dir
    blobs = self.store.list_blobs(self.ws_name, blob_path, return_names=False)
    download_count = 0
    #console.print("blob_names=", blob_names)

    for blob in blobs:
        # is this a tensorboard file?
        basename = os.path.basename(blob.name)
        if not basename.startswith("events.out.tfevents"):
            continue

        # get interesting part of blob's path (after run_name/)
        bn = blob.name[start_index:]
        modified = blob.properties.last_modified

        # only download blobs that are new or whose timestamp changed
        if not bn in last_changed or last_changed[bn] != modified:
            last_changed[bn] = modified

            if "{logdir}" in tb_path:
                # extract parent dir of blob (e.g. a test/train node name)
                test_train_node = os.path.basename(os.path.dirname(blob.name))
                console.print("tb_path=", tb_path, ", test_train_node=", test_train_node, ", basename=", basename)

                # apply to remaining template
                tb_path_full = tb_path.format(**{"logdir": test_train_node})
                #console.print("tb_path_full=", tb_path_full)
                local_fn = file_utils.path_join(tb_path_full, basename)
            else:
                local_fn = tb_path

            # all downloads land under a local "logs" dir
            local_fn = os.path.join("logs", local_fn)
            console.print("our local_fn=", local_fn)

            # download the new/changed blob; failures are logged but do not
            # abort the poll (remaining blobs are still processed)
            try:
                console.print("downloading bn={}, local_fn={}".format(bn, local_fn))
                file_utils.ensure_dir_exists(file=local_fn)
                self.store.download_file_from_run(self.ws_name, run_name, bn, local_fn)
                download_count += 1

                if self.print_progress:
                    console.print("d", end="", flush=True)
            except BaseException as ex:
                logger.exception("Error in download_file_from_run, from tensorboard_reader, ex={}".format(ex))

    return download_count
def generate(count, ext, subdir):
    '''Create *count* small test files named "test<i><ext>" under *subdir*,
    cycling through a fixed set of sample texts.'''
    sample_texts = ["", "this is a test", "how about that?\nthis is a 2nd line\nthis is 3rd", "huh"]

    for index in range(count):
        out_fn = "{}test{}{}".format(subdir, index, ext)
        file_utils.ensure_dir_exists(file=out_fn)

        # cycle through the sample texts
        with open(out_fn, "wt") as outfile:
            outfile.write(sample_texts[index % len(sample_texts)])
def create_blob(self, container, blob_path, text, fail_if_exists=False):
    '''Write *text* as the blob container/blob_path; optionally fail if it exists.'''
    full_path = self._make_path(container, blob_path)
    file_utils.ensure_dir_exists(file=full_path)

    # honor the caller's no-overwrite request
    if fail_if_exists and os.path.exists(full_path):
        errors.service_error("blob already exists: " + blob_path)

    with open(full_path, "wt") as blob_file:
        blob_file.write(text)
    return True
def overwrite_default_config():
    '''Replace the user's default config file with the pristine packaged copy.'''
    resource_dir = get_resource_dir()
    dest_fn = os.path.join(resource_dir, constants.FN_DEFAULT_CONFIG)

    # remove any existing copy first
    if is_default_config_present():
        file_utils.zap_file(dest_fn)

    file_utils.ensure_dir_exists(resource_dir)

    # copy the version shipped in the package's helpers dir
    source_fn = os.path.join(file_utils.get_xtlib_dir(), "helpers", constants.FN_DEFAULT_CONFIG)
    shutil.copyfile(source_fn, dest_fn)
def make_local_snapshot(self, snapshot_dir, code_dir, dest_name, omit_list):
    ''' keep code simple (and BEFORE upload fast):
        - always copy code dir to temp dir
        - if needed, copy xtlib subdir
        - later: if needed, add 2 extra controller files
        - later: zip the whole thing at once & upload

    Returns the (possibly extended) snapshot_dir the files were copied into.
    '''
    # an explicit dest_name nests the snapshot one level deeper
    if dest_name and dest_name != ".":
        snapshot_dir += "/" + dest_name

    console.diag("before create local snapshot")

    # fixup slashes for good comparison
    snapshot_dir = os.path.realpath(snapshot_dir)

    # fully qualify path to code_dir for simpler code & more informative logging
    code_dir = os.path.realpath(code_dir)

    # trailing "**" means recursive copy; a single trailing "*" means flat copy
    recursive = True
    if code_dir.endswith("**"):
        code_dir = code_dir[:-2]  # drop the **
    elif code_dir.endswith("*"):
        recursive = False

    # copy user's source dir (as per config file options)
    if True:
        omit_list = utils.parse_list_option_value(omit_list)

        # build list of files matching both criteria
        filenames = file_helper.get_filenames_from_include_lists(None, omit_list, recursive=recursive, from_dir=code_dir)

        file_utils.ensure_dir_exists(snapshot_dir)

        # NOTE(review): code_dir was realpath'd above, so it is absolute here
        # and the "." branch looks unreachable — confirm before removing
        prefix_len = 2 if code_dir == "." else len(code_dir)
        copy_count = 0

        # copy files recursively, preserving subdir names
        for fn in filenames:
            fn = os.path.realpath(fn)  # fix slashes

            if fn.startswith(code_dir) and fn != code_dir:
                # file lives under code_dir: mirror its relative subpath
                fn_dest = snapshot_dir + "/" + fn[prefix_len:]
                file_utils.ensure_dir_exists(file=fn_dest)
                shutil.copyfile(fn, fn_dest)
            else:
                # file from outside code_dir lands flat in the snapshot root
                shutil.copy(fn, snapshot_dir)

            copy_count += 1

        #console.diag("after snapshot copy of {} files".format(copy_count))
    else:
        shutil.copytree(code_dir, snapshot_dir)

    return snapshot_dir
def init_logging(fn, logger, title):
    '''Route root logging to the file *fn* and emit a startup banner for *title*.'''
    log_fn = os.path.expanduser(fn)
    file_utils.ensure_dir_exists(file=log_fn)

    # timestamped, single-line records written to log_fn
    logging.basicConfig(
        format='%(asctime)s.%(msecs)03d, %(levelname)s, %(name)s: %(message)s',
        datefmt='%Y-%m-%d, %H:%M:%S',
        level=logging.INFO,
        filename=log_fn)

    logger.info("---------------------------")
    logger.info("new {} started".format(title))
def create_blob_from_path(self, container, blob_path, source_fn, progress_callback=None):
    '''Copy the local file *source_fn* into the store as container/blob_path.

    NOTE: the file could be binary (don't assume it is text).
    '''
    dest_path = self._make_path(container, blob_path)
    file_utils.ensure_dir_exists(file=dest_path)

    # binary-safe whole-file copy
    shutil.copyfile(source_fn, dest_path)
    return True
def main():
    '''Run the storage-provider tests in a scratch dir; return the assert count.'''
    # init environment
    config = xt_config.get_merged_config()
    file_utils.ensure_dir_exists(TEST_DIR)

    with DirChange(TEST_DIR):
        tester = StorageProviderTests()
        # exercise both provider implementations
        for provider_name in ["xtsandboxstorage", "filestorage"]:
            tester.test_impl(provider_name)

    file_utils.ensure_dir_deleted(TEST_DIR)
    return tester._assert_count
def zip_up_filenames(fn_zip, filenames, compress=True, remove_prefix_len=None):
    '''Create the zip archive *fn_zip* containing *filenames*.

    Args:
        fn_zip: path of the archive to create (~ is expanded)
        filenames: paths of files to add
        compress: deflate entries when True; store them uncompressed otherwise
        remove_prefix_len: if truthy, strip this many leading characters from
            each filename to form its name inside the archive
    '''
    fn_zip = os.path.expanduser(fn_zip)
    file_utils.ensure_dir_exists(file=fn_zip)

    mode = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED

    with zipfile.ZipFile(fn_zip, "w", compression=mode) as archive:
        # add each file under its (optionally prefix-stripped) name
        for fn in filenames:
            arc_name = fn[remove_prefix_len:] if remove_prefix_len else fn
            archive.write(fn, arcname=arc_name)
def export_run_storage_blobs(self, workspace, run_id, temp_store_path):
    '''Download all of a run's storage files into *temp_store_path*.'''
    file_utils.ensure_dir_exists(temp_store_path)

    # pull every blob belonging to the run, without console feedback
    self.download("**", temp_store_path, share=None, workspace=workspace,
        experiment=None, job=None, run=run_id, feedback=False,
        snapshot=True, show_output=False)
def append_blob(self, container, blob_path, text, append_with_rewrite=False):
    '''Append *text* to the blob at container/blob_path.

    We ignore the *append_with_rewrite* request here, since it is an azure
    limitation workaround and not needed in a file-system provider.
    '''
    target = self._make_path(container, blob_path)
    file_utils.ensure_dir_exists(file=target)

    with open(target, "at") as appendfile:
        appendfile.write(text)
    return True
def download_file(self, fn, dest_fn, progress_callback=None, use_snapshot=False): container, path, wc_target = self._get_container_path_target(fn) #console.print("container=", container, ", path=", path) # ensure blob exists ourselves so we can issue a friendly error if not self.store.provider.does_blob_exist(container, path): errors.store_error("Blob not found: container={}, path={}".format( container, path)) # ensure the directory of the dest_fn exists file_utils.ensure_dir_exists(file=dest_fn) if use_snapshot: # create temp. snapshot if progress_callback: progress_callback(status="creating-snapshot") props = self.store.provider.snapshot_blob(container, path) snapshot_id = props.snapshot # download the snapshot if progress_callback: progress_callback(status="downloading-snapshot") text = self.store.provider.get_blob_to_path( container, path, dest_fn, snapshot=snapshot_id, progress_callback=progress_callback) # delete the snapshot if progress_callback: progress_callback(status="deleting-snapshot") self.store.provider.delete_blob(container, path, snapshot=snapshot_id) if progress_callback: progress_callback(status="deleted-snapshot") else: # normal download text = self.store.provider.get_blob_to_path( container, path, dest_fn, progress_callback=progress_callback) return text
def test_cmd(self, i, cmd, logs_dir, fake):
    '''Run one xt command as test #*i*, then optionally diff its submit logs.'''
    self.cmd_count += 1

    # a non-fake test must actually submit, so rewrite the embedded flag
    actual_cmd = cmd if fake else cmd.replace("--fake-submit=True", "--fake-submit=False")

    print("-------------------------------------")
    header = "runTests: testing (# {}, errors: {}/{}): {}".format(i, self.file_compare_errors, self.file_count, actual_cmd)
    print(header)
    #console.set_level("diagnostics")

    file_utils.ensure_dir_exists(logs_dir)
    xt_cmds.main(actual_cmd)

    if self.compare:
        self.compare_submit_logs(logs_dir)
def set_secret(name, value):
    '''Store *value* under (normalized) *name* in the local secrets file.

    Existing secrets are preserved; the named entry is added or overwritten.
    '''
    name = correct_name(name)

    # SECURITY FIX: the secret value was previously included in this
    # diagnostic message; log only the name so secrets never reach logs
    console.diag("set_secret: name={}".format(name))

    file_utils.ensure_dir_exists(file=FN_SECRETS)
    secrets = {}

    # read existing secrets, if any
    if os.path.exists(FN_SECRETS):
        text = file_utils.read_text_file(FN_SECRETS)
        secrets = json.loads(text)

    secrets[name] = value

    # write updated secrets
    text = json.dumps(secrets)
    file_utils.write_text_file(FN_SECRETS, text)
def get_default_config_path():
    '''Return the path of "default_config.yaml", installing it on first use.

    Always call this function to locate the file: it guarantees the file has
    been copied out of its package location into the resource dir.
    '''
    resource_dir = get_resource_dir()
    config_fn = os.path.join(resource_dir, constants.FN_DEFAULT_CONFIG)

    if not os.path.exists(config_fn):
        # first use: copy it from its helpers dir in the installed package
        # (or dev directory)
        file_utils.ensure_dir_exists(resource_dir)
        packaged_fn = os.path.join(file_utils.get_xtlib_dir(), "helpers", constants.FN_DEFAULT_CONFIG)
        shutil.copyfile(packaged_fn, config_fn)

        # make file readonly
        file_utils.make_readonly(config_fn)

    return config_fn
def _start_xt_cache_server(self):
    '''Launch cache_server.py as a detached background process.

    Three launch modes: a visible console (debug only), hidden/detached on
    Windows, hidden/detached on Linux. Server output goes to a log file
    under ~/.xt/tmp. Sleeps briefly so the server is ready for commands.
    '''
    import subprocess

    # Windows process-creation flags (see CreateProcess dwCreationFlags)
    DETACHED_PROCESS = 0x00000008
    CREATE_NO_WINDOW = 0x08000000

    # launch in visible window for debugging
    MAKE_SERVER_VISIBLE = False

    # the server script lives next to this module
    xtlib_dir = os.path.dirname(__file__)
    fn_script = "{}/cache_server.py".format(xtlib_dir)

    fn_log = os.path.expanduser("~/.xt/tmp/cache_server.log")
    file_utils.ensure_dir_exists(file=fn_log)

    if MAKE_SERVER_VISIBLE:
        #subprocess.Popen(parts, cwd=".", creationflags=DETACHED_PROCESS)
        cmd = "start python " + fn_script
        os.system(cmd)
    elif pc_utils.is_windows():
        # run detached, hidden for WINDOWS
        parts = ["cmd", "/c", "python", fn_script]
        flags = CREATE_NO_WINDOW

        # redirect both stdout and stderr into the log file
        with open(fn_log, 'w') as output:
            subprocess.Popen(parts, cwd=".", creationflags=flags, stdout=output, stderr=subprocess.STDOUT)
    else:
        # run detached, hidden for LINUX
        parts = ["python", fn_script]

        with open(fn_log, 'w') as output:
            subprocess.Popen(parts, cwd=".", stdout=output, stderr=subprocess.STDOUT)

    # give it time to start-up and receive commands
    time.sleep(2)
def init_dirs(self, args):
    '''Resolve and create the output/data directories for a run.

    Returns:
        (mnt_output_dir, local_output_dir, data_dir)
    '''
    # set mnt_output_dir (using environment variable setting from xt)
    mnt_output_dir = os.getenv("XT_OUTPUT_MNT", "output")
    mnt_output_dir = os.path.expanduser(mnt_output_dir)
    file_utils.ensure_dir_exists(mnt_output_dir)
    print("writing mnt_output to: " + mnt_output_dir)

    # set local_output_dir (using environment variable setting from xt)
    local_output_dir = "output"
    file_utils.ensure_dir_exists(local_output_dir)
    print("writing local_output to: " + local_output_dir)

    # set data_dir (allowing overridden by environment variable)
    data_dir = os.getenv("XT_DATA_DIR", args.data)
    data_dir = os.path.expanduser(data_dir)
    file_utils.ensure_dir_exists(data_dir)
    print("getting data from: " + data_dir)

    # report whether the MNIST data files are already present
    fn_test = data_dir + "/MNIST/processed/test.pt"
    exists = os.path.exists(fn_test)
    print("fn_test={}, exists={}".format(fn_test, exists))

    fn_train = data_dir + "/MNIST/processed/training.pt"
    exists = os.path.exists(fn_train)
    print("fn_train={}, exists={}".format(fn_train, exists))

    if args.download_only:
        print("miniMnist (ensuring data is downloaded)")
        # BUG FIX: this first call was garbled as "self.get_dataget_dataset(...)";
        # the intent (paired with the train=False call below) is two
        # get_dataset calls: train split then test split
        self.get_dataset(data_dir, True, True)
        self.get_dataset(data_dir, False, True)

    return mnt_output_dir, local_output_dir, data_dir
def generate_help(self, dest_dir):
    '''Write one .rst help file per visible command into *dest_dir*.'''
    file_utils.ensure_dir_exists(dest_dir)

    generated = 0
    for cmd in qfe.get_commands():
        # skip commands not meant to be documented
        if cmd["hidden"]:
            continue

        # file name mirrors the command, with spaces as underscores
        rst_fn = "{}/{}.rst".format(dest_dir, cmd["name"].replace(" ", "_"))
        help_text = self.generate_help_cmd(cmd)

        # write text to .RST file
        with open(rst_fn, "wt") as outfile:
            outfile.write(help_text)
        generated += 1

    console.print("{} files generated to: {}".format(generated, dest_dir))
def write_script_file(script_lines, fn, for_windows):
    '''
    args:
        - script_lines: a list of strings (NOT newline terminated)
        - fn: path of file to create
        - for_windows: if True, lines will be written to end with CR + NEWLINE

    return:
        - the update filename (with "~" expanded)
    '''
    fn = os.path.expanduser(fn)
    file_utils.ensure_dir_exists(file=fn)

    # choose line endings for the target OS and join once
    joiner = "\r\n" if for_windows else "\n"
    body = joiner.join(script_lines)
    if not for_windows:
        # remove any rouge CR characters
        body = body.replace("\r", "")

    # newline="" stops open() from translating our explicit line endings
    with open(fn, "wt", newline="") as outfile:
        outfile.write(body)

    if not for_windows:
        # paranoia check: confirm no CR bytes made it to disk
        with open(fn, "rb") as infile:
            if 13 in infile.read():
                console.print("WARNING: write_script_file failed to remove all CR chars")

    return fn
def __init__(self, wildcard_path):
    '''Watch the directory named by *wildcard_path* (e.g. path='.',
    wildcard='*.tfevents.*') for file-system events.'''
    # normalize: expand ~, force forward slashes
    clean_path = os.path.expanduser(wildcard_path).replace("\\", "/")

    # split into the directory to watch and the wildcard part (if any)
    if "*" in clean_path:
        watch_dir = os.path.dirname(clean_path)
        wildcard = os.path.basename(clean_path)
    else:
        watch_dir, wildcard = clean_path, None

    watch_dir = file_utils.fix_slashes(watch_dir)
    #console.print("WatchWorker: path={}, wildcard={}".format(watch_dir, wildcard))

    # in case program will create dir, but it hasn't yet been created
    file_utils.ensure_dir_exists(watch_dir)

    self.event_handler = MyHandler()
    self.observer = Observer()
    #console.print("WATCHING: " + watch_dir)
    self.observer.schedule(self.event_handler, watch_dir, recursive=True)
def _download_files(self, container, path, wc_target, dest_folder):
    '''Download every blob under *path* matching *wc_target* into *dest_folder*.

    Returns the list of local file paths written.
    '''
    #console.print("ws_name=", ws_name, ", ws_wildcard=", ws_wildcard)
    names = self._list_wild_blobs(container, path, wc_target, include_folder_names=True)
    console.diag("_download_files: names=", names)

    # strip "<path>/" (incl. trailing slash) to get each blob's relative part
    strip_len = 1 + len(path)
    #console.print("blob_dir=", blob_dir, ", bd_index=", bd_index)
    copied = []

    for blob_name in names:
        rel_name = blob_name[strip_len:]
        dest_fn = dest_folder + "/" + rel_name
        console.detail("_download_files: bn=", blob_name, ", dest_fn=", dest_fn)

        file_utils.ensure_dir_exists(file=dest_fn)
        self.provider.get_blob_to_path(container, blob_name, dest_fn)
        copied.append(dest_fn)

    return copied
def adjust_job_for_controller_run(self, job_id, job_runs, cmds, using_hp, experiment, service_type, snapshot_dir, search_style, args):
    ''' submit direct job:
        - backend.commands: some internal prep cmds and the run command(s) specified by the user
        - backend.env_var: these are set to pass a small subset of the context for the runs
        - backend.source_files: the user's source files, the run's context file

        submit controller job:
        - backend.commands: some internal prep cmds and a command line to run the XT controller
        - backend.env_var: not used
        - backend.source_files: the user's source files, controller MULTI_RUN_CONTEXT file, controller script file

    Returns the list of snapshot files written: [MRC file, controller script].
    '''
    # for EACH NODE, collect and adjust runs
    context_by_nodes = {}

    for i, box_runs in enumerate(job_runs):
        node_context = self.create_context_file_core(box_runs, i, job_id, using_hp, app_info=None, exper_name=experiment, args=args)

        # node ids are "node0", "node1", ...
        node_id = "node" + str(i)
        context_by_nodes[node_id] = node_context

        # job_runs is rewritten in place with the controller-adjusted runs
        new_box_runs = self.adjust_box_runs_for_controller(i, box_runs)
        job_runs[i] = new_box_runs

    # write 1st file to SNAPSHOT (MRC file)
    fn_context = snapshot_dir + "/" + constants.FN_MULTI_RUN_CONTEXT
    file_utils.ensure_dir_exists(file=fn_context)

    mrc_data = {
        "context_by_nodes": context_by_nodes,
        "cmds": cmds,
        "search_style": search_style
    }
    text = json.dumps(mrc_data, indent=4)

    with open(fn_context, "wt") as tfile:
        tfile.write(text)

    # also write the MRC file to the job store (to support wrapup of runs after job is cancelled)
    self.store.create_job_file(job_id, constants.FN_MULTI_RUN_CONTEXT, text)

    # write 2nd file to SNAPSHOT: a tiny bootstrap script that launches
    # the XT controller with the MRC file written above
    fn_script = snapshot_dir + "/" + constants.PY_RUN_CONTROLLER
    is_aml = (service_type == "aml")

    with open(fn_script, "wt") as outfile:
        external_controller_port = constants.CONTROLLER_PORT

        text = ""
        text += "import sys\n"
        text += "sys.path.insert(0, '.') # support for --xtlib-upload \n"
        text += "from xtlib.controller import run\n"
        text += "run(multi_run_context_fn='{}', port={}, is_aml={})\n".format(constants.FN_MULTI_RUN_CONTEXT, external_controller_port, is_aml)

        outfile.write(text)

    utils.copy_to_submit_logs(args, fn_context)
    utils.copy_to_submit_logs(args, fn_script)

    return [fn_context, fn_script]
def download(self, store_path, local_path, share, workspace, experiment, job, run, feedback, snapshot, show_output=True):
    '''Download one or many blobs/files from the store to *local_path*.

    Chooses single vs. multi download: a wildcard-free store_path that exists
    as a single file is downloaded directly; otherwise a (possibly recursive)
    multi-file download is performed.

    Returns:
        the number of files downloaded
    '''
    use_blobs = True
    use_multi = True  # default until we test if store_path exists as a file/blob
    download_count = 0

    fs = self.create_file_accessor(use_blobs, share, workspace, experiment, job, run)

    # test for existance of store_path as a blob/file
    if not "*" in store_path and not "?" in store_path:
        if fs.does_file_exist(store_path):
            use_multi = False

    if local_path:
        # exapnd ~/ in front of local path
        local_path = os.path.expanduser(local_path)
    else:
        # path not specified for local: default by download style
        if use_multi:
            local_path = "."
        else:
            local_path = "./" + os.path.basename(store_path)

    uri = fs.get_uri(store_path)

    # default store folder to recursive
    if use_multi and not "*" in store_path and not "?" in store_path:
        store_path += "/**"

    use_snapshot = snapshot
    feedback_progress = FeedbackProgress(feedback, show_output)
    progress_callback = feedback_progress.progress if feedback else None

    if use_multi:
        # download MULTI blobs/files
        what = "blobs" if use_blobs else "files"
        single_what = what[0:-1]

        if show_output:
            console.print("collecting {} names from: {}...".format(single_what, uri), end="")

        _, blob_names = fs.get_filenames(store_path, full_paths=False)

        if show_output:
            console.print()

        if len(blob_names) == 0:
            console.print("no matching {} found in: {}".format(what, uri))
            return 0
        elif len(blob_names) == 1:
            # use singular wording for a single match
            what = "blob" if use_blobs else "file"

        if show_output:
            console.print("\ndownloading {} {}...:".format(len(blob_names), what))

        file_utils.ensure_dir_exists(local_path)

        # column width for aligned per-file progress output
        max_name_len = max([len(local_path + "/" + name) for name in blob_names])
        name_width = 1 + max_name_len
        #console.print("max_name_len=", max_name_len, ", name_width=", name_width)

        for f, bn in enumerate(blob_names):
            dest_fn = file_utils.fix_slashes(local_path + "/" + bn)

            if show_output:
                file_msg = "file {}/{}".format(1 + f, len(blob_names))
                console.print(" {2:}: {1:<{0:}} ".format(name_width, dest_fn + ":", file_msg), end="", flush=True)

            feedback_progress.start()
            full_bn = uri + "/" + bn if uri else bn
            fs.download_file(full_bn, dest_fn, progress_callback=progress_callback, use_snapshot=use_snapshot)
            feedback_progress.end()

            download_count += 1
    else:
        # download SINGLE blobs/files
        what = "blob" if use_blobs else "file"

        if not fs.does_file_exist(store_path):
            errors.store_error("{} not found: {}".format(what, uri))

        local_path = file_utils.fix_slashes(local_path)

        if show_output:
            console.print("\nfrom {}, downloading {}:".format(uri, what))
            console.print(" {}: ".format(local_path), end="", flush=True)

        feedback_progress.start()
        fs.download_file(store_path, local_path, progress_callback=progress_callback, use_snapshot=use_snapshot)
        feedback_progress.end()

        download_count += 1

    return download_count
def __init__(self, storage_creds):
    '''File-system storage provider rooted at the creds' "path" entry.'''
    root_path = storage_creds["path"]
    self.path = os.path.expanduser(root_path)
    self.retry = None

    # make sure the root directory exists before any blob operations
    file_utils.ensure_dir_exists(self.path)
def main():
    '''miniMnist entry point: parse args, build model/loaders, train or
    evaluate, optionally save the model, and report elapsed time.'''
    started = time.time()

    #print("args=", sys.argv)
    args = parse_args()

    run, model, device, train_loader, test_loader, train_writer, test_writer =\
        init_stuff(args)

    start_epoch = 1

    # log hyperparameters to xt
    hp_dict = {
        "seed": args.seed, "batch-size": args.batch_size,
        "epochs": args.epochs, "lr": args.lr, "momentum": args.momentum,
        "channels1": args.channels1, "channels2": args.channels2,
        "kernel_size": args.kernel_size, "mlp-units": args.mlp_units,
        "weight-decay": args.weight_decay, "optimizer": args.optimizer,
        "mid-conv": args.mid_conv, "gpu": args.gpu,
        "parallel": args.parallel, "distributed": args.distributed
    }

    if run:
        run.log_hparams(hp_dict)

    # console.print hyperparameters
    print("hyperparameters:", hp_dict)
    print()

    # see if we are resuming a preempted run
    if run and run.resume_name:
        print("resuming from run=", run.resume_name)
        dd = run.get_checkpoint(fn_checkpoint)

        if dd and dd["epoch"]:
            model.load_state_dict(torch.load(fn_checkpoint))
            start_epoch = 1 + dd["epoch"]

    if args.optimizer == "sgd":
        #print("using SGD optimizer")
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        #print("using Adam optimizer")
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.distributed:
        # horovod: wrap optimizer and sync initial weights across workers
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
        # Broadcast parameters from rank 0 to all other processes.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    checkpoint_freq = 0
    checkpoint_units = ""
    last_checkpoint = time.time()
    checkpoint_count = 0

    # force a ML app error to kill the app
    #x = foo/bar

    # parse checkpoint arg (currently disabled via "if False")
    #print("args.checkpoint=", args.checkpoint, ", type(args.checkpoint)", type(args.checkpoint))
    if False:   # args.checkpoint:
        if type(args.checkpoint) in ["int", "float"]:
            checkpoint_freq = int(args.checkpoint)
            checkpoint_units = "epochs"
        elif isinstance(args.checkpoint, str):
            parts = args.checkpoint.split(' ')

            if len(parts) == 2:
                checkpoint_freq, checkpoint_units = parts
                checkpoint_freq = float(checkpoint_freq)
                checkpoint_units = checkpoint_units.strip().lower()
            else:
                checkpoint_freq = float(args.checkpoint)
                checkpoint_units = "epochs"

    model_dir = os.getenv("XT_MODEL_DIR", "models/miniMnist")
    fn_model = model_dir + "/mnist_cnn.pt"

    if args.eval_model:
        # load model and evaluate it
        print("loading existing MODEL and evaluating it, fn=", fn_model)
        model.load_state_dict(torch.load(fn_model))

        train_test_loop(run, model, device, train_loader, test_loader, optimizer, start_epoch, checkpoint_freq,
            train_writer, test_writer, test_only=True, args=args)
    else:
        train_test_loop(run, model, device, train_loader, test_loader, optimizer, start_epoch, checkpoint_freq,
            train_writer, test_writer, test_only=False, args=args)

    if (args.save_model):
        file_utils.ensure_dir_exists(model_dir)
        save_model(model, fn_model)

    if args.clear_checkpoint_at_end:
        if checkpoint_freq and run and run.store:
            run.clear_checkpoint()

    # console.print speed test at end
    # started = time.time()
    # text = ""
    # for i in range(100):
    #     text += "$$: this is to test # {} out console.print speed output. it seems to be much faster on the CONTROLLER console that on the ATTACHED console. ".format(i+1)
    # print(text)
    # elapsed = time.time() - started
    # print("console.print test ended (elapsed: {:2f} secs)".format(elapsed))

    if train_writer:
        train_writer.close()
        test_writer.close()

    if run:
        # ensure we log end of run for AML
        run.close()

    elapsed = time.time() - started
    print("\n--- miniMnist elapsed: {:.0f} secs ---".format(elapsed))
def run(self):
    '''miniMnist class entry point: parse args, apply runset overrides,
    train or evaluate, save models/log files, and close the xt run.'''
    print("args=", sys.argv)
    self.args = parse_cmdline_args()
    args = self.args

    # apply optional runset overrides, if the file is present
    fn_runset = "runset.yaml"
    if os.path.exists(fn_runset):
        self.apply_runset_file(args, fn_runset)

    model, device, mnt_output_dir, local_output_dir = self.init_stuff()

    start_epoch = 1
    # NOTE(review): self.run shadows this method's name — presumably an xt
    # Run object attribute set during init; verify
    run = self.run

    if args.raise_error:
        #errors.internal_error("Raising an intentional error")
        # try a different type of error (intentional AttributeError)
        abc.foo = 1

    # BUG FIX: hp_dict was previously built only inside "if run:", but it is
    # printed unconditionally below — that raised NameError whenever run was
    # not set. Build it unconditionally (matching the module-level main()).
    hp_dict = {
        "seed": args.seed, "batch-size": args.batch_size,
        "epochs": args.epochs, "lr": args.lr, "momentum": args.momentum,
        "channels1": args.channels1, "channels2": args.channels2,
        "kernel_size": args.kernel_size, "mlp-units": args.mlp_units,
        "weight-decay": args.weight_decay, "optimizer": args.optimizer,
        "mid-conv": args.mid_conv, "gpu": args.gpu,
        "log-interval": args.log_interval
    }

    # log hyperparameters to xt
    if run:
        run.log_hparams(hp_dict)

    if args.cuda:
        # if on linux, show GPU info
        if os.name != "nt":
            os.system("nvidia-smi")

    # print hyperparameters
    print("hyperparameters:", hp_dict)
    print()

    # see if we are resuming a preempted run
    if run and run.resume_name:
        print("resuming from run=", run.resume_name)
        dd = run.get_checkpoint(fn_checkpoint)

        if dd and dd["epoch"]:
            model.load_state_dict(torch.load(fn_checkpoint))
            start_epoch = 1 + dd["epoch"]

    if args.optimizer == "sgd":
        #print("using SGD optimizer")
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        #print("using Adam optimizer")
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.distributed:
        # horovod: wrap optimizer and sync initial weights across workers
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
        # Broadcast parameters from rank 0 to all other processes.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    checkpoint_freq = 0
    checkpoint_units = ""
    last_checkpoint = time.time()
    checkpoint_count = 0

    # force a ML app error to kill the app
    #x = foo/bar

    # parse checkpoint arg (currently disabled via "if False")
    #print("args.checkpoint=", args.checkpoint, ", type(args.checkpoint)", type(args.checkpoint))
    if False:   # args.checkpoint:
        if type(args.checkpoint) in ["int", "float"]:
            checkpoint_freq = int(args.checkpoint)
            checkpoint_units = "epochs"
        elif isinstance(args.checkpoint, str):
            parts = args.checkpoint.split(' ')

            if len(parts) == 2:
                checkpoint_freq, checkpoint_units = parts
                checkpoint_freq = float(checkpoint_freq)
                checkpoint_units = checkpoint_units.strip().lower()
            else:
                checkpoint_freq = float(args.checkpoint)
                checkpoint_units = "epochs"

    model_dir = os.getenv("XT_MODEL_DIR", "models/miniMnist")
    fn_model = model_dir + "/mnist_cnn.pt"
    self.fn_text_log = mnt_output_dir + "/text_log.txt"

    if args.eval_model:
        # load model and evaluate it
        print("loading existing MODEL and evaluating it, fn=", fn_model)
        exists = os.path.exists(fn_model)
        print("model exists=", exists)
        model.load_state_dict(torch.load(fn_model))
        print("model loaded!")

        # just test model
        self.test_model_and_log_metrics(run, model, device, epoch=1, args=args)
    else:
        self.train_test_loop(run, model, device, optimizer, 1, checkpoint_freq, args=args)

    if (args.save_model):
        file_utils.ensure_dir_exists(model_dir)
        self.save_model(model, fn_model)

        # always save a copy of model in the AFTER FILES
        self.save_model(model, "output/mnist_cnn.pt")

    if args.clear_checkpoint_at_end:
        if checkpoint_freq and run and run.store:
            run.clear_checkpoint()

    # create a file to be captured in OUTPUT FILES
    fn_app_log = os.path.join(local_output_dir, "miniMnist_log.txt")
    with open(fn_app_log, "wt") as outfile:
        outfile.write("This is a log for miniMnist app\n")
        outfile.write("miniMnist app completed\n")

    # create a file to be ignored in OUTPUT FILES
    fn_app_log = os.path.join(local_output_dir, "test.junk")
    with open(fn_app_log, "wt") as outfile:
        outfile.write("This is a file that should be omitted from AFTER upload\n")
        outfile.write("end of junk file\n")

    if run:
        # ensure we close all logging
        run.close()