Ejemplo n.º 1
0
def restore(ctx, run, no_git, branch, project, entity):
    """Restore the code state, config and docker environment recorded for a run.

    When git is usable, the run's commit is checked out (optionally on a
    ``wandb/<run>`` branch) and the run's saved diff is applied on top of it.
    The run's config is then loaded into a ``Config`` rooted at ``wandb`` and
    persisted. If the run metadata names a docker image, the ``docker``
    command is invoked with it.

    Args:
        ctx: Click context, used to invoke the ``docker`` command.
        run: Run identifier. ``entity/project:run`` and ``project:run`` are
            split here; a value with more than one ``/`` has its first
            component taken as the entity. The remainder is handed to
            ``api.parse_slug``.
        no_git: When true, no git checkout or patching is attempted.
        branch: When true, check out onto a ``wandb/<run>`` branch rather
            than a detached HEAD.
        project: Project to use unless ``run`` embeds one.
        entity: Entity to use unless ``run`` embeds one.

    Returns:
        The tuple ``(commit, json_config, patch_content, repo, metadata)``.
        ``commit`` is ``None`` when ``no_git`` is set, and is the upstream
        commit when the original commit was not found locally.

    Raises:
        ClickException: If git is not enabled although the run recorded a
            repository, or if neither the original commit nor any upstream
            commit can be found in the local repository.
        subprocess.CalledProcessError: If ``git fetch --all`` fails.
    """
    if ":" in run:
        if "/" in run:
            entity, rest = run.split("/", 1)
        else:
            rest = run
        project, run = rest.split(":", 1)
    elif run.count("/") > 1:
        entity, run = run.split("/", 1)

    project, run = api.parse_slug(run, project=project)
    commit, json_config, patch_content, metadata = api.run_config(
        project, run=run, entity=entity)
    # `or {}` also covers metadata that stores an explicit null for "git".
    repo = (metadata.get("git") or {}).get("repo")
    image = metadata.get("docker")
    restore_message = """`wandb restore` needs to be run from the same git repository as the original run.
Run `git clone %s` and restore from there or pass the --no-git flag.""" % repo
    if no_git:
        commit = None
    elif not api.git.enabled:
        if repo:
            raise ClickException(restore_message)
        elif image:
            wandb.termlog(
                "Original run has no git history.  Just restoring config and docker"
            )

    if commit and api.git.enabled:
        subprocess.check_call(['git', 'fetch', '--all'])
        try:
            api.git.repo.commit(commit)
        except ValueError:
            wandb.termlog("Couldn't find original commit: {}".format(commit))
            # Fall back to the first upstream commit (encoded in the name of
            # an `upstream_diff_<commit>.patch` file) that exists locally.
            commit = None
            patch_file = None
            files = api.download_urls(project, run=run, entity=entity)
            for filename in files:
                if not (filename.startswith('upstream_diff_')
                        and filename.endswith('.patch')):
                    continue
                candidate = filename[len('upstream_diff_'):-len('.patch')]
                try:
                    api.git.repo.commit(candidate)
                except ValueError:
                    continue
                commit = candidate
                patch_file = filename
                break

            if commit:
                wandb.termlog(
                    "Falling back to upstream commit: {}".format(commit))
                patch_path, _ = api.download_write_file(files[patch_file])
            else:
                raise ClickException(restore_message)
        else:
            if patch_content:
                patch_path = os.path.join(wandb.wandb_dir(), 'diff.patch')
                with open(patch_path, "w") as f:
                    f.write(patch_content)
            else:
                patch_path = None

        branch_name = "wandb/%s" % run
        if branch and branch_name not in api.git.repo.branches:
            api.git.repo.git.checkout(commit, b=branch_name)
            wandb.termlog("Created branch %s" %
                          click.style(branch_name, bold=True))
        elif branch:
            wandb.termlog(
                "Using existing branch, run `git branch -D %s` from master for a clean checkout"
                % branch_name)
            api.git.repo.git.checkout(branch_name)
        else:
            wandb.termlog("Checking out %s in detached mode" % commit)
            api.git.repo.git.checkout(commit)

        if patch_path:
            # we apply the patch from the repository root so git doesn't exclude
            # things outside the current directory
            root = api.git.root
            patch_rel_path = os.path.relpath(patch_path, start=root)
            # --reject is necessary or else this fails any time a binary file
            # occurs in the diff
            # we use .call() instead of .check_call() for the same reason
            # TODO(adrian): this means there is no error checking here
            subprocess.call(['git', 'apply', '--reject', patch_rel_path],
                            cwd=root)
            wandb.termlog("Applied patch")

    # TODO: we should likely respect WANDB_DIR here.
    util.mkdir_exists_ok("wandb")
    config = Config(run_dir="wandb")
    config.load_json(json_config)
    config.persist()
    wandb.termlog("Restored config variables to %s" % config._config_path())
    if image:
        # Read both fields defensively: a run without a recorded program must
        # take the "just restoring environment" path, not raise KeyError.
        program = metadata.get("program")
        args = metadata.get("args")
        if program and not program.startswith("<") and args is not None:
            # TODO: we may not want to default to python here.
            runner = util.find_runner(program) or ["python"]
            command = runner + [program] + args
            cmd = " ".join(command)
        else:
            wandb.termlog(
                "Couldn't find original command, just restoring environment")
            cmd = None
        wandb.termlog("Docker image found, attempting to start")
        ctx.invoke(docker, docker_run_args=[image], cmd=cmd)

    return commit, json_config, patch_content, repo, metadata
Ejemplo n.º 2
0
def restore(run, branch, project, entity):
    """Restore the code state and config recorded for a run.

    If the run has a commit, it is checked out (optionally on a
    ``wandb/<run>`` branch) and the run's saved diff is applied on top of it.
    The run's config is then loaded into a ``Config`` and persisted.

    Args:
        run: Run identifier, resolved together with ``project`` by
            ``api.parse_slug``.
        branch: When true, check out onto a ``wandb/<run>`` branch rather
            than a detached HEAD.
        project: Project passed to ``api.parse_slug`` as the default.
        entity: Entity the run is looked up under.

    Raises:
        ClickException: If neither the original commit nor any upstream
            commit can be found in the local repository.
        subprocess.CalledProcessError: If ``git fetch --all`` fails.
    """
    project, run = api.parse_slug(run, project=project)
    commit, json_config, patch_content = api.run_config(project,
                                                        run=run,
                                                        entity=entity)

    if commit:
        # Fetch only when there is a commit to look for, so a run without one
        # can still have its config restored outside a git repository.
        subprocess.check_call(['git', 'fetch', '--all'])
        try:
            api.git.repo.commit(commit)
        except ValueError:
            click.echo("Couldn't find original commit: {}".format(commit))
            # Fall back to the first upstream commit (encoded in the name of
            # an `upstream_diff_<commit>.patch` file) that exists locally.
            commit = None
            patch_file = None
            files = api.download_urls(project, run=run, entity=entity)
            for filename in files:
                if not (filename.startswith('upstream_diff_')
                        and filename.endswith('.patch')):
                    continue
                candidate = filename[len('upstream_diff_'):-len('.patch')]
                try:
                    api.git.repo.commit(candidate)
                except ValueError:
                    continue
                commit = candidate
                patch_file = filename
                break

            if commit:
                click.echo(
                    "Falling back to upstream commit: {}".format(commit))
                patch_path, _ = api.download_write_file(files[patch_file])
            else:
                raise ClickException(
                    "Can't find commit from which to restore code")
        else:
            if patch_content:
                patch_path = os.path.join(wandb.wandb_dir(), 'diff.patch')
                with open(patch_path, "w") as f:
                    f.write(patch_content)
            else:
                patch_path = None

        branch_name = "wandb/%s" % run
        if branch and branch_name not in api.git.repo.branches:
            api.git.repo.git.checkout(commit, b=branch_name)
            click.echo("Created branch %s" %
                       click.style(branch_name, bold=True))
        elif branch:
            click.secho(
                "Using existing branch, run `git branch -D %s` from master for a clean checkout"
                % branch_name,
                fg="red")
            api.git.repo.git.checkout(branch_name)
        else:
            click.secho("Checking out %s in detached mode" % commit)
            api.git.repo.git.checkout(commit)

        if patch_path:
            # we apply the patch from the repository root so git doesn't exclude
            # things outside the current directory
            root = api.git.root
            patch_rel_path = os.path.relpath(patch_path, start=root)
            # --reject is necessary or else this fails any time a binary file
            # occurs in the diff
            # we use .call() instead of .check_call() for the same reason
            # TODO(adrian): this means there is no error checking here
            subprocess.call(['git', 'apply', '--reject', patch_rel_path],
                            cwd=root)
            click.echo("Applied patch")

    config = Config()
    config.load_json(json_config)
    config.persist()
    click.echo("Restored config variables")