Example #1
    def adjust_job_for_direct_run(self, job_id, job_runs, cmds, using_hp,
                                  experiment, service_type, snapshot_dir,
                                  search_style, args):

        # write 1st file to SNAPSHOT (first run's context file)
        fn_run_context = snapshot_dir + "/" + constants.FN_RUN_CONTEXT
        file_utils.ensure_dir_exists(file=fn_run_context)

        box_runs = job_runs[0]
        cfc = self.create_context_file_core(box_runs,
                                            0,
                                            job_id,
                                            using_hp,
                                            app_info=None,
                                            exper_name=experiment,
                                            args=args)
        context_data = cfc["runs"][0]

        text = json.dumps(context_data, indent=4)
        with open(fn_run_context, "wt") as tfile:
            tfile.write(text)

        utils.copy_to_submit_logs(args, fn_run_context)

        return [fn_run_context]
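
Every example in this collection calls file_utils.ensure_dir_exists, either with a directory path or with a file= keyword argument. The xtlib implementation is not shown here; as a rough mental model it presumably behaves like the following sketch (an assumption, not the actual library code):

import os

def ensure_dir_exists(dir_path=None, file=None):
    # assumed behavior: with file=..., create the file's parent directory;
    # with a plain directory path, create that directory itself
    path = os.path.dirname(file) if file else dir_path
    if path and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)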
Example #2
    def __init__(self, store, run_dir, mirror_dest, wildcard_path, grok_url,
                 ws_name, run_name):
        # path = '.'
        # wildcard = "*.tfevents.*"

        self.run_dir = run_dir

        wildcard_path = os.path.expanduser(wildcard_path)
        wildcard_path = wildcard_path.replace("\\", "/")

        if not wildcard_path.startswith("/"):
            wildcard_path = os.path.join(run_dir, wildcard_path)

        if "*" in wildcard_path:
            path = os.path.dirname(wildcard_path)
            wildcard = os.path.basename(wildcard_path)
        else:
            path = wildcard_path
            wildcard = None

        path = file_utils.fix_slashes(path)
        console.print("MirrorWorker: path={}, wildcard={}".format(
            path, wildcard))

        # in case the program will create the dir but it hasn't been created yet
        file_utils.ensure_dir_exists(path)

        self.event_handler = MyHandler(store, mirror_dest, grok_url, ws_name,
                                       run_name, path, wildcard)
        self.observer = Observer()
        self.observer.schedule(self.event_handler, path, recursive=True)
    def set_console_fn(self, console_fn):
        console_fn = os.path.expanduser(console_fn)
        self.console_fn = console_fn

        if os.path.exists(console_fn):
            os.remove(console_fn)

        file_utils.ensure_dir_exists(file=console_fn)
    def poll_for_tensorboard_files(self, last_changed, blob_path, start_index,
                                   tb_path, run_name):
        # get all blobs in the run's output dir
        blobs = self.store.list_blobs(self.ws_name,
                                      blob_path,
                                      return_names=False)
        download_count = 0

        #console.print("blob_names=", blob_names)
        for blob in blobs:
            # is this a tensorboard file?
            basename = os.path.basename(blob.name)
            if not basename.startswith("events.out.tfevents"):
                continue

            # get interesting part of blob's path (after run_name/)
            bn = blob.name[start_index:]
            modified = blob.properties.last_modified

            if bn not in last_changed or last_changed[bn] != modified:
                last_changed[bn] = modified

                if "{logdir}" in tb_path:

                    # extract parent dir of blob
                    test_train_node = os.path.basename(
                        os.path.dirname(blob.name))
                    console.print("tb_path=", tb_path, ", test_train_node=",
                                  test_train_node, ", basename=", basename)

                    # apply to remaining template
                    tb_path_full = tb_path.format(
                        **{"logdir": test_train_node})
                    #console.print("tb_path_full=", tb_path_full)
                    local_fn = file_utils.path_join(tb_path_full, basename)
                else:
                    local_fn = tb_path

                local_fn = os.path.join("logs", local_fn)
                console.print("our local_fn=", local_fn)

                # download the new/changed blob
                try:
                    console.print("downloading bn={}, local_fn={}".format(
                        bn, local_fn))
                    file_utils.ensure_dir_exists(file=local_fn)
                    self.store.download_file_from_run(self.ws_name, run_name,
                                                      bn, local_fn)
                    download_count += 1

                    if self.print_progress:
                        console.print("d", end="", flush=True)
                except BaseException as ex:
                    logger.exception(
                        "Error in download_file_from_run, from tensorboard_reader, ex={}"
                        .format(ex))

        return download_count
def generate(count, ext, subdir):
    texts = ["", "this is a test", "how about that?\nthis is a 2nd line\nthis is 3rd", "huh"]

    for i in range(count):
        fn = subdir + "test" + str(i) + ext
        file_utils.ensure_dir_exists(file=fn)

        with open(fn, "wt") as outfile:
            text = texts[i % 4]
            outfile.write(text)
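
A hypothetical call to the generate() helper above (the directory name is illustrative): it writes count files under subdir, cycling through the four sample texts.

# writes test0.txt .. test9.txt under ./test_files/, creating the dir as needed
generate(10, ".txt", "test_files/")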
    def create_blob(self, container, blob_path, text, fail_if_exists=False):
        path = self._make_path(container, blob_path)
        file_utils.ensure_dir_exists(file=path)

        if fail_if_exists and os.path.exists(path):
            errors.service_error("blob already exists: " + blob_path)

        with open(path, "wt") as outfile:
            outfile.write(text)
        return True
def overwrite_default_config():
    default_config_path = os.path.join(get_resource_dir(),
                                       constants.FN_DEFAULT_CONFIG)
    if is_default_config_present():
        file_utils.zap_file(default_config_path)

    res_dir = get_resource_dir()
    file_utils.ensure_dir_exists(res_dir)
    fn_source = os.path.join(file_utils.get_xtlib_dir(), "helpers",
                             constants.FN_DEFAULT_CONFIG)
    shutil.copyfile(fn_source, default_config_path)
Example #8
    def make_local_snapshot(self, snapshot_dir, code_dir, dest_name, omit_list):
        '''
        keep the code simple (and the BEFORE-upload step fast):
            - always copy code dir to temp dir
            - if needed, copy xtlib subdir
            - later: if needed, add 2 extra controller files
            - later: zip the whole thing at once & upload 
        '''
        if dest_name and dest_name != ".":
            snapshot_dir += "/" + dest_name

        console.diag("before create local snapshot")

        # fixup slashes for good comparison
        snapshot_dir = os.path.realpath(snapshot_dir)

        # fully qualify path to code_dir for simpler code & more informative logging
        code_dir = os.path.realpath(code_dir)

        recursive = True

        if code_dir.endswith("**"):
            code_dir = code_dir[:-2]   # drop the **
        elif code_dir.endswith("*"):
            recursive = False

        # copy user's source dir (as per config file options)
        if True:    
            omit_list = utils.parse_list_option_value(omit_list)

            # build list of files matching both criteria
            filenames = file_helper.get_filenames_from_include_lists(None, omit_list, recursive=recursive, from_dir=code_dir)

            file_utils.ensure_dir_exists(snapshot_dir)
            prefix_len = 2 if code_dir == "." else len(code_dir)
            copy_count = 0

            # copy files recursively, preserving subdir names
            for fn in filenames:
                fn = os.path.realpath(fn)           # fix slashes

                if fn.startswith(code_dir) and fn != code_dir:
                    fn_dest = snapshot_dir + "/" + fn[prefix_len:]
                    file_utils.ensure_dir_exists(file=fn_dest)
                    shutil.copyfile(fn, fn_dest)
                else:
                    shutil.copy(fn, snapshot_dir)
                copy_count += 1

            #console.diag("after snapshot copy of {} files".format(copy_count))
        else:
            shutil.copytree(code_dir, snapshot_dir)  
            
        return snapshot_dir
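
A possible call to make_local_snapshot, assuming it is invoked from another method of the same class (hence self) and that the omit list comes from the config file; the paths and omit patterns here are illustrative only:

# hypothetical call: copy the current code dir into a temp snapshot dir,
# skipping files that match the omit list
snapshot = self.make_local_snapshot(snapshot_dir="/tmp/xt_code_snapshot",
                                    code_dir=".",
                                    dest_name="__src__",
                                    omit_list="*.pyc, .git")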
Example #9
def init_logging(fn, logger, title):
    fn_xt_info = os.path.expanduser(fn)
    file_utils.ensure_dir_exists(file=fn_xt_info)

    logging.basicConfig(
        format='%(asctime)s.%(msecs)03d, %(levelname)s, %(name)s: %(message)s',
        datefmt='%Y-%m-%d, %H:%M:%S',
        level=logging.INFO,
        filename=fn_xt_info)

    logger.info("---------------------------")
    logger.info("new {} started".format(title))
    def create_blob_from_path(self,
                              container,
                              blob_path,
                              source_fn,
                              progress_callback=None):
        '''
        NOTE: the file could be binary (don't assume it is text)
        '''
        path = self._make_path(container, blob_path)
        file_utils.ensure_dir_exists(file=path)

        shutil.copyfile(source_fn, path)
        return True
def main():
    # init environment
    config = xt_config.get_merged_config()
    file_utils.ensure_dir_exists(TEST_DIR)

    with DirChange(TEST_DIR):
        tester = StorageProviderTests()

        tester.test_impl("xtsandboxstorage")
        tester.test_impl("filestorage")
    
    file_utils.ensure_dir_deleted(TEST_DIR)
    return tester._assert_count
def zip_up_filenames(fn_zip, filenames, compress=True, remove_prefix_len=None):
    fn_zip = os.path.expanduser(fn_zip)
    file_utils.ensure_dir_exists(file=fn_zip)

    compression = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
    #console.print("compression=", compression)

    with zipfile.ZipFile(fn_zip, "w", compression=compression) as zip:
        # writing each file one by one
        for fn in filenames:
            #console.print("zipping fn: " + fn)
            fn_dest = fn[remove_prefix_len:] if remove_prefix_len else fn
            zip.write(fn, arcname=fn_dest)
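
A usage sketch for zip_up_filenames; the file names are illustrative. Passing remove_prefix_len trims the common source prefix so entries are stored with relative archive names:

# hypothetical call: archive two files, stripping the leading "snapshot/" prefix
fns = ["snapshot/train.py", "snapshot/utils/io.py"]
zip_up_filenames("~/.xt/tmp/code.zip", fns, compress=True,
                 remove_prefix_len=len("snapshot/"))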
Example #13
    def export_run_storage_blobs(self, workspace, run_id, temp_store_path):
        # copy each storage file
        file_utils.ensure_dir_exists(temp_store_path)

        #fs = self.store.job_files(job_id, use_blobs=True)
        self.download("**",
                      temp_store_path,
                      share=None,
                      workspace=workspace,
                      experiment=None,
                      job=None,
                      run=run_id,
                      feedback=False,
                      snapshot=True,
                      show_output=False)
    def append_blob(self,
                    container,
                    blob_path,
                    text,
                    append_with_rewrite=False):
        '''
        we ignore the *append_with_rewrite* request here, since it is an azure limitation workaround
        and not needed in a file-system provider.
        '''
        path = self._make_path(container, blob_path)
        file_utils.ensure_dir_exists(file=path)

        with open(path, "at") as outfile:
            outfile.write(text)
        return True
    def download_file(self,
                      fn,
                      dest_fn,
                      progress_callback=None,
                      use_snapshot=False):
        container, path, wc_target = self._get_container_path_target(fn)
        #console.print("container=", container, ", path=", path)

        # ensure blob exists ourselves so we can issue a friendly error
        if not self.store.provider.does_blob_exist(container, path):
            errors.store_error("Blob not found: container={}, path={}".format(
                container, path))

        # ensure the directory of the dest_fn exists
        file_utils.ensure_dir_exists(file=dest_fn)

        if use_snapshot:
            # create temp. snapshot
            if progress_callback:
                progress_callback(status="creating-snapshot")
            props = self.store.provider.snapshot_blob(container, path)
            snapshot_id = props.snapshot

            # download the snapshot
            if progress_callback:
                progress_callback(status="downloading-snapshot")
            text = self.store.provider.get_blob_to_path(
                container,
                path,
                dest_fn,
                snapshot=snapshot_id,
                progress_callback=progress_callback)

            # delete the snapshot
            if progress_callback:
                progress_callback(status="deleting-snapshot")
            self.store.provider.delete_blob(container,
                                            path,
                                            snapshot=snapshot_id)

            if progress_callback:
                progress_callback(status="deleted-snapshot")
        else:
            # normal download
            text = self.store.provider.get_blob_to_path(
                container, path, dest_fn, progress_callback=progress_callback)

        return text
Example #16
    def test_cmd(self, i, cmd, logs_dir, fake):
        self.cmd_count += 1        

        if not fake:
            cmd = cmd.replace("--fake-submit=True", "--fake-submit=False")

        print("-------------------------------------")
        print("runTests: testing (# {}, errors: {}/{}): {}".format(i, self.file_compare_errors, self.file_count, cmd))
        #console.set_level("diagnostics")

        file_utils.ensure_dir_exists(logs_dir)

        xt_cmds.main(cmd)

        if self.compare:
            self.compare_submit_logs(logs_dir)
Example #17
def set_secret(name, value):
    name = correct_name(name)
    console.diag("set_secret: name={}, value={}".format(name, value))

    file_utils.ensure_dir_exists(file=FN_SECRETS)

    secrets = {}

    # read existing secrets, if any
    if os.path.exists(FN_SECRETS):
        text = file_utils.read_text_file(FN_SECRETS)
        secrets = json.loads(text)

    secrets[name] = value

    # write updated secrets
    text = json.dumps(secrets)
    file_utils.write_text_file(FN_SECRETS, text)
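
A round-trip sketch built only from the helpers already shown above; the secret name and value are illustrative:

# hypothetical round trip: store a secret, then read the secrets file back
set_secret("storage-key", "abc123")

text = file_utils.read_text_file(FN_SECRETS)
secrets = json.loads(text)
assert secrets[correct_name("storage-key")] == "abc123"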
def get_default_config_path():
    '''
    always call this function to find the "default_config.yaml" file.
    calling this ensures that the file has been copied from its package location.
    '''
    res_dir = get_resource_dir()
    fn = os.path.join(res_dir, constants.FN_DEFAULT_CONFIG)

    if not os.path.exists(fn):
        # copy it from its helpers dir in the installed package (or dev directory)
        file_utils.ensure_dir_exists(res_dir)
        fn_source = os.path.join(file_utils.get_xtlib_dir(), "helpers",
                                 constants.FN_DEFAULT_CONFIG)
        shutil.copyfile(fn_source, fn)

        # make file readonly
        file_utils.make_readonly(fn)

    return fn
    def _start_xt_cache_server(self):

        import subprocess
        DETACHED_PROCESS = 0x00000008
        CREATE_NO_WINDOW = 0x08000000

        # launch in visible window for debugging
        MAKE_SERVER_VISIBLE = False

        xtlib_dir = os.path.dirname(__file__)
        fn_script = "{}/cache_server.py".format(xtlib_dir)
        fn_log = os.path.expanduser("~/.xt/tmp/cache_server.log")
        file_utils.ensure_dir_exists(file=fn_log)

        if MAKE_SERVER_VISIBLE:
            #subprocess.Popen(parts, cwd=".", creationflags=DETACHED_PROCESS)
            cmd = "start python " + fn_script

            os.system(cmd)
        elif pc_utils.is_windows():
            # run detached, hidden for WINDOWS
            parts = ["cmd", "/c", "python", fn_script]
            flags = CREATE_NO_WINDOW

            with open(fn_log, 'w') as output:
                subprocess.Popen(parts,
                                 cwd=".",
                                 creationflags=flags,
                                 stdout=output,
                                 stderr=subprocess.STDOUT)
        else:
            # run detached, hidden for LINUX
            parts = ["python", fn_script]

            with open(fn_log, 'w') as output:
                subprocess.Popen(parts,
                                 cwd=".",
                                 stdout=output,
                                 stderr=subprocess.STDOUT)

        # give it time to start up and begin receiving commands
        time.sleep(2)
    def init_dirs(self, args):
        # set mnt_output_dir (using environment variable setting from xt)
        mnt_output_dir = os.getenv("XT_OUTPUT_MNT", "output")
        mnt_output_dir = os.path.expanduser(mnt_output_dir)
        file_utils.ensure_dir_exists(mnt_output_dir)
        print("writing mnt_output to: " + mnt_output_dir)

        # set local_output_dir (using environment variable setting from xt)
        local_output_dir = "output"
        file_utils.ensure_dir_exists(local_output_dir)
        print("writing local_output to: " + local_output_dir)

        # set data_dir (allowing override by environment variable)
        data_dir = os.getenv("XT_DATA_DIR", args.data)
        data_dir = os.path.expanduser(data_dir)
        file_utils.ensure_dir_exists(data_dir)
        print("getting data from: " + data_dir)

        fn_test = data_dir + "/MNIST/processed/test.pt"
        exists = os.path.exists(fn_test)
        print("fn_test={}, exists={}".format(fn_test, exists))

        fn_train = data_dir + "/MNIST/processed/training.pt"
        exists = os.path.exists(fn_train)
        print("fn_train={}, exists={}".format(fn_train, exists))

        if args.download_only:
            print("miniMnist (ensuring data is downloaded)")
            self.get_dataset(data_dir, True, True)
            self.get_dataset(data_dir, False, True)

        return mnt_output_dir, local_output_dir, data_dir
Example #21
    def generate_help(self, dest_dir):
        file_utils.ensure_dir_exists(dest_dir)

        cmds = qfe.get_commands()
        count = 0

        for cmd in cmds:
            if cmd["hidden"]:
                continue

            cmd_name = cmd["name"].replace(" ", "_")

            fn = "{}/{}.rst".format(dest_dir, cmd_name)
            text = self.generate_help_cmd(cmd)

            # write text to .RST file
            with open(fn, "wt") as outfile:
                outfile.write(text)

            count += 1

        console.print("{} files generated to: {}".format(count, dest_dir))
Example #22
def write_script_file(script_lines, fn, for_windows):
    '''
    args:
        - script_lines: a list of strings (NOT newline terminated)
        - fn: path of file to create
        - for_windows: if True, lines will be written to end with CR + NEWLINE

    return: 
        - the updated filename (with "~" expanded)
    '''
    fn = os.path.expanduser(fn)
    file_utils.ensure_dir_exists(file=fn)

    # set the newline joiner according to the target OS
    newline = "\r\n" if for_windows else "\n"
    text = newline.join(script_lines)

    # specify newline="" here to prevent open() from messing with our newlines
    with open(fn, "wt", newline="") as outfile:
        if not for_windows:
            # remove any rogue CR characters
            text = text.replace("\r", "")
        outfile.write(text)

    if not for_windows:
        # ensure no CR characters are found
        with open(fn, "rb") as infile:
            byte_buff = infile.read()
            if 13 in byte_buff:
                console.print("WARNING: write_script_file failed to remove all CR chars")

    #console.print("for_windows=", for_windows, "newline=", newline, ", script_lines=", script_lines) 
    # test_text = file_utils.read_text_file(fn)
    # console.print("test_text=", test_text)
    
    return fn   
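
A usage sketch for write_script_file; the script lines and target paths are illustrative. The same lines produce CRLF endings for a Windows script and LF-only endings for a Linux one:

lines = ["echo preparing run", "python train.py --epochs=5"]

# hypothetical calls for each target OS
fn_bat = write_script_file(lines, "~/.xt/tmp/run.bat", for_windows=True)
fn_sh = write_script_file(lines, "~/.xt/tmp/run.sh", for_windows=False)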
Example #23
    def __init__(self, wildcard_path):
        # path = '.'
        # wildcard = "*.tfevents.*"

        wildcard_path = os.path.expanduser(wildcard_path)
        wildcard_path = wildcard_path.replace("\\", "/")

        if "*" in wildcard_path:
            path = os.path.dirname(wildcard_path)
            wildcard = os.path.basename(wildcard_path)
        else:
            path = wildcard_path
            wildcard = None

        path = file_utils.fix_slashes(path)
        #console.print("WatchWorker: path={}, wildcard={}".format(path, wildcard))

        # in case the program will create the dir but it hasn't been created yet
        file_utils.ensure_dir_exists(path)

        self.event_handler = MyHandler()
        self.observer = Observer()
        #console.print("WATCHING: " + path)
        self.observer.schedule(self.event_handler, path, recursive=True)
    def _download_files(self, container, path, wc_target, dest_folder):
        #console.print("ws_name=", ws_name, ", ws_wildcard=", ws_wildcard)
        files_copied = []

        names = self._list_wild_blobs(container,
                                      path,
                                      wc_target,
                                      include_folder_names=True)
        console.diag("_download_files: names=", names)

        blob_dir = path
        bd_index = 1 + len(blob_dir)  # add one for the trailing slash
        #console.print("blob_dir=", blob_dir, ", bd_index=", bd_index)

        for bn in names:
            base_bn = bn[bd_index:]
            dest_fn = dest_folder + "/" + base_bn
            console.detail("_download_files: bn=", bn, ", dest_fn=", dest_fn)

            file_utils.ensure_dir_exists(file=dest_fn)
            self.provider.get_blob_to_path(container, bn, dest_fn)
            files_copied.append(dest_fn)

        return files_copied
Example #25
    def adjust_job_for_controller_run(self, job_id, job_runs, cmds, using_hp,
                                      experiment, service_type, snapshot_dir,
                                      search_style, args):
        ''' 
        submit direct job:
            - backend.commands: some internal prep cmds and the run command(s) specified by the user
            - backend.env_var: these are set to pass a small subset of the context for the runs
            - backend.source_files: the user's source files, the run's context file

        submit controller job:
            - backend.commands: some internal prep cmds and a command line to run the XT controller
            - backend.env_var: not used
            - backend.source_files: the user's source files, controller MULTI_RUN_CONTEXT file, controller script file
        '''

        # for EACH NODE, collect and adjust runs
        context_by_nodes = {}

        for i, box_runs in enumerate(job_runs):

            node_context = self.create_context_file_core(box_runs,
                                                         i,
                                                         job_id,
                                                         using_hp,
                                                         app_info=None,
                                                         exper_name=experiment,
                                                         args=args)

            node_id = "node" + str(i)
            context_by_nodes[node_id] = node_context

            new_box_runs = self.adjust_box_runs_for_controller(i, box_runs)
            job_runs[i] = new_box_runs

        # write 1st file to SNAPSHOT (MRC file)
        fn_context = snapshot_dir + "/" + constants.FN_MULTI_RUN_CONTEXT
        file_utils.ensure_dir_exists(file=fn_context)

        mrc_data = {
            "context_by_nodes": context_by_nodes,
            "cmds": cmds,
            "search_style": search_style
        }
        text = json.dumps(mrc_data, indent=4)
        with open(fn_context, "wt") as tfile:
            tfile.write(text)

        # also write the MRC file to the job store (to support wrapup of runs after job is cancelled)
        self.store.create_job_file(job_id, constants.FN_MULTI_RUN_CONTEXT,
                                   text)

        # write 2nd file to SNAPSHOT
        fn_script = snapshot_dir + "/" + constants.PY_RUN_CONTROLLER
        is_aml = (service_type == "aml")

        with open(fn_script, "wt") as outfile:
            external_controller_port = constants.CONTROLLER_PORT

            text = ""
            text += "import sys\n"
            text += "sys.path.insert(0, '.')    # support for --xtlib-upload \n"
            text += "from xtlib.controller import run\n"
            text += "run(multi_run_context_fn='{}', port={}, is_aml={})\n".format(
                constants.FN_MULTI_RUN_CONTEXT, external_controller_port,
                is_aml)

            outfile.write(text)

        utils.copy_to_submit_logs(args, fn_context)
        utils.copy_to_submit_logs(args, fn_script)

        return [fn_context, fn_script]
Example #26
    def download(self,
                 store_path,
                 local_path,
                 share,
                 workspace,
                 experiment,
                 job,
                 run,
                 feedback,
                 snapshot,
                 show_output=True):

        use_blobs = True
        use_multi = True  # default until we test if store_path exists as a file/blob
        download_count = 0

        fs = self.create_file_accessor(use_blobs, share, workspace, experiment,
                                       job, run)

        # test for existence of store_path as a blob/file
        if "*" not in store_path and "?" not in store_path:
            if fs.does_file_exist(store_path):
                use_multi = False

        if local_path:
            # expand ~/ in front of the local path
            local_path = os.path.expanduser(local_path)
        else:
            # path not specified for local
            if use_multi:
                local_path = "."
            else:
                local_path = "./" + os.path.basename(store_path)

        uri = fs.get_uri(store_path)

        # for a store folder, default to a recursive download
        if use_multi and "*" not in store_path and "?" not in store_path:
            store_path += "/**"

        use_snapshot = snapshot

        feedback_progress = FeedbackProgress(feedback, show_output)
        progress_callback = feedback_progress.progress if feedback else None

        if use_multi:
            # download MULTI blobs/files

            what = "blobs" if use_blobs else "files"
            single_what = what[0:-1]

            if show_output:
                console.print("collecting {} names from: {}...".format(
                    single_what, uri),
                              end="")

            _, blob_names = fs.get_filenames(store_path, full_paths=False)

            if show_output:
                console.print()

            if len(blob_names) == 0:
                console.print("no matching {} found in: {}".format(what, uri))
                return 0
            elif len(blob_names) == 1:
                what = "blob" if use_blobs else "file"

            if show_output:
                console.print("\ndownloading {} {}...:".format(
                    len(blob_names), what))

            file_utils.ensure_dir_exists(local_path)
            max_name_len = max(
                [len(local_path + "/" + name) for name in blob_names])
            name_width = 1 + max_name_len
            #console.print("max_name_len=", max_name_len, ", name_width=", name_width)

            for f, bn in enumerate(blob_names):
                dest_fn = file_utils.fix_slashes(local_path + "/" + bn)

                if show_output:
                    file_msg = "file {}/{}".format(1 + f, len(blob_names))
                    console.print("  {2:}: {1:<{0:}} ".format(
                        name_width, dest_fn + ":", file_msg),
                                  end="",
                                  flush=True)

                feedback_progress.start()
                full_bn = uri + "/" + bn if uri else bn
                fs.download_file(full_bn,
                                 dest_fn,
                                 progress_callback=progress_callback,
                                 use_snapshot=use_snapshot)
                feedback_progress.end()

                download_count += 1
        else:
            # download SINGLE blobs/files
            what = "blob" if use_blobs else "file"

            if not fs.does_file_exist(store_path):
                errors.store_error("{} not found: {}".format(what, uri))

            local_path = file_utils.fix_slashes(local_path)

            if show_output:
                console.print("\nfrom {}, downloading {}:".format(uri, what))
                console.print("  {}:    ".format(local_path),
                              end="",
                              flush=True)

            feedback_progress.start()
            fs.download_file(store_path,
                             local_path,
                             progress_callback=progress_callback,
                             use_snapshot=use_snapshot)
            feedback_progress.end()

            download_count += 1

        return download_count
    def __init__(self, storage_creds):
        self.path = os.path.expanduser(storage_creds["path"])
        self.retry = None

        # create directory, if needed
        file_utils.ensure_dir_exists(self.path)
Example #28
def main():
    started = time.time()

    #print("args=", sys.argv)
    args = parse_args()

    run, model, device, train_loader, test_loader, train_writer, test_writer =\
        init_stuff(args)

    start_epoch = 1

    # log hyperparameters to xt
    hp_dict = {
        "seed": args.seed,
        "batch-size": args.batch_size,
        "epochs": args.epochs,
        "lr": args.lr,
        "momentum": args.momentum,
        "channels1": args.channels1,
        "channels2": args.channels2,
        "kernel_size": args.kernel_size,
        "mlp-units": args.mlp_units,
        "weight-decay": args.weight_decay,
        "optimizer": args.optimizer,
        "mid-conv": args.mid_conv,
        "gpu": args.gpu,
        "parallel": args.parallel,
        "distributed": args.distributed
    }

    if run:
        run.log_hparams(hp_dict)

    # console.print hyperparameters
    print("hyperparameters:", hp_dict)
    print()

    # see if we are resuming a preempted run
    if run and run.resume_name:
        print("resuming from run=", run.resume_name)
        dd = run.get_checkpoint(fn_checkpoint)
        if dd and dd["epoch"]:
            model.load_state_dict(torch.load(fn_checkpoint))
            start_epoch = 1 + dd["epoch"]

    if args.optimizer == "sgd":
        #print("using SGD optimizer")
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        #print("using Adam optimizer")
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.distributed:
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

        # Broadcast parameters from rank 0 to all other processes.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    checkpoint_freq = 0
    checkpoint_units = ""
    last_checkpoint = time.time()
    checkpoint_count = 0

    # force a ML app error to kill the app
    #x = foo/bar

    # parse checkpoint arg
    #print("args.checkpoint=", args.checkpoint, ", type(args.checkpoint)", type(args.checkpoint))

    if False:  # args.checkpoint:
        if isinstance(args.checkpoint, (int, float)):
            checkpoint_freq = int(args.checkpoint)
            checkpoint_units = "epochs"
        elif isinstance(args.checkpoint, str):
            parts = args.checkpoint.split(' ')
            if len(parts) == 2:
                checkpoint_freq, checkpoint_units = parts
                checkpoint_freq = float(checkpoint_freq)
                checkpoint_units = checkpoint_units.strip().lower()
            else:
                checkpoint_freq = float(args.checkpoint)
                checkpoint_units = "epochs"

    model_dir = os.getenv("XT_MODEL_DIR", "models/miniMnist")
    fn_model = model_dir + "/mnist_cnn.pt"

    if args.eval_model:
        # load model and evaluate it
        print("loading existing MODEL and evaluating it, fn=", fn_model)

        model.load_state_dict(torch.load(fn_model))

        train_test_loop(run,
                        model,
                        device,
                        train_loader,
                        test_loader,
                        optimizer,
                        start_epoch,
                        checkpoint_freq,
                        train_writer,
                        test_writer,
                        test_only=True,
                        args=args)
    else:
        train_test_loop(run,
                        model,
                        device,
                        train_loader,
                        test_loader,
                        optimizer,
                        start_epoch,
                        checkpoint_freq,
                        train_writer,
                        test_writer,
                        test_only=False,
                        args=args)

    if args.save_model:
        file_utils.ensure_dir_exists(model_dir)
        save_model(model, fn_model)

    if args.clear_checkpoint_at_end:
        if checkpoint_freq and run and run.store:
            run.clear_checkpoint()

    # console.print speed test at end
    # started = time.time()

    # text = ""
    # for i in range(100):
    #     text += "$$: this is to test # {} out console.print speed output.  it seems to be much faster on the CONTROLLER console that on the ATTACHED console.  ".format(i+1)

    # print(text)

    # elapsed = time.time() - started
    # print("console.print test ended (elapsed: {:2f} secs)".format(elapsed))

    if train_writer:
        train_writer.close()
        test_writer.close()

    if run:
        # ensure we log end of run for AML
        run.close()

    elapsed = time.time() - started
    print("\n--- miniMnist elapsed: {:.0f} secs ---".format(elapsed))
    def run(self):

        print("args=", sys.argv)
        self.args = parse_cmdline_args()
        args = self.args

        fn_runset = "runset.yaml"
        if os.path.exists(fn_runset):
            self.apply_runset_file(args, fn_runset)

        model, device, mnt_output_dir, local_output_dir = self.init_stuff()

        start_epoch = 1
        run = self.run

        if args.raise_error:
            #errors.internal_error("Raising an intentional error")
            # try a different type of error
            abc.foo = 1

        # log hyperparameters to xt (build hp_dict unconditionally so it can
        # also be printed below when run is None)
        hp_dict = {
            "seed": args.seed,
            "batch-size": args.batch_size,
            "epochs": args.epochs,
            "lr": args.lr,
            "momentum": args.momentum,
            "channels1": args.channels1,
            "channels2": args.channels2,
            "kernel_size": args.kernel_size,
            "mlp-units": args.mlp_units,
            "weight-decay": args.weight_decay,
            "optimizer": args.optimizer,
            "mid-conv": args.mid_conv,
            "gpu": args.gpu,
            "log-interval": args.log_interval
        }

        if run:
            run.log_hparams(hp_dict)

        if args.cuda:
            # if on linux, show GPU info
            if os.name != "nt":
                os.system("nvidia-smi")

        # print hyperparameters
        print("hyperparameters:", hp_dict)
        print()

        # see if we are resuming a preempted run
        if run and run.resume_name:
            print("resuming from run=", run.resume_name)
            dd = run.get_checkpoint(fn_checkpoint)
            if dd and dd["epoch"]:
                model.load_state_dict(torch.load(fn_checkpoint))
                start_epoch = 1 + dd["epoch"]

        if args.optimizer == "sgd":
            #print("using SGD optimizer")
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
        else:
            #print("using Adam optimizer")
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.weight_decay)

        if args.distributed:
            optimizer = hvd.DistributedOptimizer(
                optimizer, named_parameters=model.named_parameters())

            # Broadcast parameters from rank 0 to all other processes.
            hvd.broadcast_parameters(model.state_dict(), root_rank=0)

        checkpoint_freq = 0
        checkpoint_units = ""
        last_checkpoint = time.time()
        checkpoint_count = 0

        # force a ML app error to kill the app
        #x = foo/bar

        # parse checkpoint arg
        #print("args.checkpoint=", args.checkpoint, ", type(args.checkpoint)", type(args.checkpoint))

        if False:  # args.checkpoint:
            if isinstance(args.checkpoint, (int, float)):
                checkpoint_freq = int(args.checkpoint)
                checkpoint_units = "epochs"
            elif isinstance(args.checkpoint, str):
                parts = args.checkpoint.split(' ')
                if len(parts) == 2:
                    checkpoint_freq, checkpoint_units = parts
                    checkpoint_freq = float(checkpoint_freq)
                    checkpoint_units = checkpoint_units.strip().lower()
                else:
                    checkpoint_freq = float(args.checkpoint)
                    checkpoint_units = "epochs"

        model_dir = os.getenv("XT_MODEL_DIR", "models/miniMnist")
        fn_model = model_dir + "/mnist_cnn.pt"
        self.fn_text_log = mnt_output_dir + "/text_log.txt"

        if args.eval_model:
            # load model and evaluate it
            print("loading existing MODEL and evaluating it, fn=", fn_model)
            exists = os.path.exists(fn_model)
            print("model exists=", exists)

            model.load_state_dict(torch.load(fn_model))
            print("model loaded!")

            # just test model
            self.test_model_and_log_metrics(run,
                                            model,
                                            device,
                                            epoch=1,
                                            args=args)
        else:
            self.train_test_loop(run,
                                 model,
                                 device,
                                 optimizer,
                                 1,
                                 checkpoint_freq,
                                 args=args)

        if args.save_model:
            file_utils.ensure_dir_exists(model_dir)
            self.save_model(model, fn_model)

        # always save a copy of the model in the AFTER FILES
        self.save_model(model, "output/mnist_cnn.pt")

        if args.clear_checkpoint_at_end:
            if checkpoint_freq and run and run.store:
                run.clear_checkpoint()

        # create a file to be captured in OUTPUT FILES
        fn_app_log = os.path.join(local_output_dir, "miniMnist_log.txt")
        with open(fn_app_log, "wt") as outfile:
            outfile.write("This is a log for miniMnist app\n")
            outfile.write("miniMnist app completed\n")

        # create a file to be ignored in OUTPUT FILES
        fn_app_log = os.path.join(local_output_dir, "test.junk")
        with open(fn_app_log, "wt") as outfile:
            outfile.write(
                "This is a file that should be omitted from AFTER upload\n")
            outfile.write("end of junk file\n")

        if run:
            # ensure we close all logging
            run.close()