Beispiel #1
0
def diff(repo, *args, a_rev=None, b_rev=None, param_deps=False, **kwargs):
    """Compare metrics and params between two experiment revisions.

    Args:
        repo: repository object exposing ``scm``.
        a_rev: "old" revision; when falsy, ``fix_exp_head(scm, "HEAD")``
            is used instead.
        b_rev: "new" revision; when falsy, the ``"workspace"`` is used.
        param_deps: forwarded to ``_collect_experiment_commit``.
        **kwargs: only ``all`` is consumed; when true, unchanged entries
            are requested from ``_diff`` via ``with_unchanged``.

    Returns:
        An empty dict when the repository has no commits, otherwise a dict
        with the keys ``"metrics"`` and ``"params"`` holding the
        respective ``_diff`` results.
    """
    from dvc.repo.experiments.show import _collect_experiment_commit
    from dvc.scm import resolve_rev

    scm = repo.scm
    if scm.no_commits:
        return {}

    # Pick the revision for the "old" side of the comparison.
    if a_rev:
        old_target = resolve_rev(scm, fix_exp_head(scm, a_rev))
    else:
        old_target = fix_exp_head(scm, "HEAD")
    old = _collect_experiment_commit(repo, old_target, param_deps=param_deps)

    # Pick the revision for the "new" side of the comparison.
    if b_rev:
        new_target = resolve_rev(scm, fix_exp_head(scm, b_rev))
    else:
        new_target = "workspace"
    new = _collect_experiment_commit(repo, new_target, param_deps=param_deps)

    with_unchanged = kwargs.pop("all", False)

    old_data = old.get("data", {})
    new_data = new.get("data", {})

    result = {}
    for key in ["metrics", "params"]:
        result[key] = _diff(
            format_dict(old_data.get(key, {})),
            format_dict(new_data.get(key, {})),
            with_unchanged=with_unchanged,
        )
    return result
Beispiel #2
0
def _collect_experiment_branch(res,
                               repo,
                               branch,
                               baseline,
                               onerror: Optional[Callable] = None,
                               **kwargs):
    """Collect every commit of an experiment branch into ``res``.

    The branch is resolved to its tip revision and walked via
    ``repo.scm.branch_revs(tip, baseline)``. When the walk yields more than
    one revision, each entry is tagged with ``"checkpoint_tip"`` (the tip
    revision) and linked to the next revision in the walk through
    ``"checkpoint_parent"``; the last one is linked to ``baseline``.

    Args:
        res: ordered mapping of rev -> ``{"data": ...}``, updated in place.
        repo: repository object exposing ``scm``.
        branch: experiment branch/ref to resolve.
        baseline: baseline revision passed to ``branch_revs``.
        onerror: forwarded to ``_collect_experiment_commit``.
        **kwargs: forwarded to ``_collect_experiment_commit``.

    Returns:
        The same ``res`` mapping that was passed in.
    """
    from dvc.scm import resolve_rev

    tip_rev = resolve_rev(repo.scm, branch)
    branch_revs = list(repo.scm.branch_revs(tip_rev, baseline))
    has_checkpoints = len(branch_revs) > 1

    child: Optional[str] = None
    for rev in branch_revs:
        collected = _collect_experiment_commit(repo,
                                               rev,
                                               onerror=onerror,
                                               **kwargs)
        if has_checkpoints:
            if child:
                # The previously visited revision descends from this one.
                res[child]["data"]["checkpoint_parent"] = rev
            if rev in res:
                # Already collected: only tag it and move it to the end.
                res[rev]["data"].update({"checkpoint_tip": tip_rev})
                res.move_to_end(rev)
            else:
                data = {"checkpoint_tip": tip_rev}
                data.update(collected["data"])
                res[rev] = {"data": data}
        else:
            data = collected["data"]
            if rev not in res:
                res[rev] = {"data": data}
        child = rev

    if has_checkpoints:
        # The last visited revision hangs directly off the baseline.
        res[child]["data"]["checkpoint_parent"] = baseline
    return res
Beispiel #3
0
def test_subdir(tmp_dir, scm, dvc, workspace):
    """Run an experiment from inside a subdirectory and check its outputs."""
    subdir = tmp_dir / "dir"
    for fname, contents in (("copy.py", COPY_SCRIPT),
                            ("params.yaml", "foo: 1")):
        subdir.gen(fname, contents)

    with subdir.chdir():
        dvc.run(
            cmd="python copy.py params.yaml metrics.yaml",
            metrics_no_cache=["metrics.yaml"],
            params=["foo"],
            name="copy-file",
            no_exec=True,
        )
        tracked = [
            subdir / fname
            for fname in ("dvc.yaml", "copy.py", "params.yaml")
        ]
        scm.add(tracked)
        scm.commit("init")

        results = dvc.experiments.run(PIPELINE_FILE,
                                      params=["foo=2"],
                                      tmp_dir=not workspace)
        assert results

    exp = first(results)
    ref_info = first(exp_refs_by_rev(scm, exp))

    # Outputs must be present in the experiment commit under the subdir.
    exp_fs = scm.get_fs(exp)
    for fname in ["metrics.yaml", "dvc.lock"]:
        assert exp_fs.exists(subdir / fname)
    metrics_path = subdir / "metrics.yaml"
    with exp_fs.open(metrics_path, mode="r", encoding="utf-8") as fobj:
        contents = fobj.read()
    assert contents.strip() == "foo: 2"

    assert dvc.experiments.get_exact_name(exp) == ref_info.name
    assert resolve_rev(scm, ref_info.name) == exp
Beispiel #4
0
def ls(repo, *args, rev=None, git_remote=None, all_=False, **kwargs):
    """List experiment names grouped by baseline revision.

    Args:
        repo: repository object exposing ``scm``.
        rev: baseline revision to list experiments for. When falsy and
            ``all_`` is false, the current ``scm.get_rev()`` is used.
        git_remote: when given, experiments are looked up on that Git
            remote instead of locally.
        all_: when true (and no ``rev`` is given), list experiments for
            every baseline.

    Returns:
        ``defaultdict(list)`` mapping a baseline revision to experiment
        names.

    Raises:
        RevError: if ``rev`` cannot be resolved, unless ``git_remote`` is
            set and ``rev`` already looks like a SHA.
    """
    from scmrepo.git import Git

    from dvc.scm import RevError, resolve_rev

    scm = repo.scm

    if rev:
        try:
            rev = resolve_rev(scm, rev)
        except RevError:
            # A SHA on a remote may simply not have been fetched yet, so
            # only that combination is tolerated.
            if not git_remote or not Git.is_sha(rev):
                raise
    elif not all_:
        rev = scm.get_rev()

    results = defaultdict(list)

    if rev:
        refs = (remote_exp_refs_by_baseline(scm, git_remote, rev)
                if git_remote else exp_refs_by_baseline(scm, rev))
        for ref_info in refs:
            results[rev].append(ref_info.name)
    elif all_:
        refs = (remote_exp_refs(scm, git_remote)
                if git_remote else exp_refs(scm))
        for ref_info in refs:
            results[ref_info.baseline_sha].append(ref_info.name)

    return results
Beispiel #5
0
    def reproduce_one(
        self,
        queue: bool = False,
        tmp_dir: bool = False,
        checkpoint_resume: Optional[str] = None,
        reset: bool = False,
        **kwargs,
    ):
        """Reproduce and checkout a single experiment.

        Args:
            queue: only stash the experiment for later execution and return
                its stash revision in a one-element list.
            tmp_dir: run with ``TempDirExecutorManager`` instead of
                ``WorkspaceExecutorManager``.
            checkpoint_resume: revision to resume from; it is resolved and
                checked against the current baseline. When falsy, the
                value of ``self._workspace_resume_rev()`` is used.
            reset: reset checkpoints before running. Forced to ``True``
                when queueing without ``checkpoint_resume``.
            **kwargs: forwarded to ``self.new``.

        Raises:
            DvcException: if ``checkpoint_resume`` fails the baseline check.
        """
        if queue and not checkpoint_resume:
            reset = True
        if reset:
            self.reset_checkpoints()

        # Only an in-workspace run needs the Git index to be clean.
        if not queue and not tmp_dir:
            staged, _, _ = self.scm.status()
            if staged:
                logger.warning(
                    "Your workspace contains staged Git changes which will be "
                    "unstaged before running this experiment.")
                self.scm.reset()

        if not checkpoint_resume:
            checkpoint_resume = self._workspace_resume_rev()
        else:
            from dvc.scm import resolve_rev

            resume_rev = resolve_rev(self.scm, checkpoint_resume)
            try:
                self.check_baseline(resume_rev)
            except BaselineMismatchError as exc:
                raise DvcException(
                    f"Cannot resume from '{checkpoint_resume}' as it is not "
                    "derived from your current workspace.") from exc
            checkpoint_resume = resume_rev

        stash_rev = self.new(checkpoint_resume=checkpoint_resume,
                             reset=reset,
                             **kwargs)
        if queue:
            logger.info("Queued experiment '%s' for future execution.",
                        stash_rev[:7])
            return [stash_rev]

        # ``queue`` is falsy from here on, so only ``tmp_dir`` picks the
        # executor manager.
        manager_cls: Type = (TempDirExecutorManager
                             if tmp_dir else WorkspaceExecutorManager)
        results = self._reproduce_revs(
            revs=[stash_rev],
            keep_stash=False,
            manager_cls=manager_cls,
        )
        if first(results) is not None:
            self._log_reproduced(results, tmp_dir=tmp_dir)
        return results
Beispiel #6
0
def branch(repo, exp_rev, branch_name, *args, **kwargs):
    """Create a Git branch pointing at an existing experiment.

    Args:
        repo: repository object exposing ``scm``.
        exp_rev: experiment name or revision to branch from.
        branch_name: name of the Git branch to create (without the
            ``refs/heads/`` prefix).

    Raises:
        InvalidArgumentError: if ``exp_rev`` cannot be resolved, if it
            matches several experiments and none of them has the current
            revision as its baseline, or if the branch already exists.
        InvalidExpRevError: if no experiment ref matches ``exp_rev``.
    """
    from dvc.scm import resolve_rev

    try:
        rev = resolve_rev(repo.scm, exp_rev)
    except RevError as exc:
        # Chain the original error so the resolution failure stays visible
        # in the traceback.
        raise InvalidArgumentError(exp_rev) from exc
    ref_info = None

    ref_infos = list(exp_refs_by_rev(repo.scm, rev))
    if len(ref_infos) == 1:
        ref_info = ref_infos[0]
    elif len(ref_infos) > 1:
        # Several experiments share this rev: prefer the one whose baseline
        # is the currently checked out revision.
        current_rev = repo.scm.get_rev()
        for info in ref_infos:
            if info.baseline_sha == current_rev:
                ref_info = info
                break
        if not ref_info:
            msg = [
                f"Ambiguous experiment name '{exp_rev}' can refer to "
                "multiple experiments. To create a branch use a full "
                "experiment ref:",
                "",
            ]
            msg.extend(str(info) for info in ref_infos)
            raise InvalidArgumentError("\n".join(msg))

    if not ref_info:
        raise InvalidExpRevError(exp_rev)

    branch_ref = f"refs/heads/{branch_name}"
    if repo.scm.get_ref(branch_ref):
        raise InvalidArgumentError(
            f"Git branch '{branch_name}' already exists."
        )

    target = repo.scm.get_ref(str(ref_info))
    repo.scm.set_ref(
        branch_ref,
        target,
        message=f"dvc: Created from experiment '{ref_info.name}'",
    )
    fmt = (
        "Git branch '%s' has been created from experiment '%s'.\n"
        "To switch to the new branch run:\n\n"
        "\tgit checkout %s"
    )
    logger.info(fmt, branch_name, ref_info.name, branch_name)
Beispiel #7
0
def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
    """Return the stash index of a queued experiment, if there is one.

    The argument is first compared against the ``name`` of every entry in
    ``repo.experiments.stash_revs``. If no name matches, it is resolved as
    a Git revision and looked up as a key of ``stash_revs``.

    Returns:
        The ``index`` of the matching entry, or ``None`` when the revision
        cannot be resolved or is not a stashed experiment.
    """
    stash_revs = repo.experiments.stash_revs

    # Match by experiment name first.
    for entry in stash_revs.values():
        if entry.name == ref_or_rev:
            return entry.index

    from dvc.scm import resolve_rev

    # Fall back to treating the argument as a revision.
    try:
        rev = resolve_rev(repo.scm, ref_or_rev)
    except RevError:
        return None
    return stash_revs[rev].index if rev in stash_revs else None
Beispiel #8
0
    def _get_baseline(self, rev):
        """Return the baseline revision recorded for ``rev``, or ``None``.

        ``rev`` is resolved first. A stashed experiment yields the
        ``baseline_rev`` of its stash entry; otherwise the first experiment
        ref pointing at the revision yields its ``baseline_sha``.
        """
        from dvc.scm import resolve_rev

        resolved = resolve_rev(self.scm, rev)

        if resolved in self.stash_revs:
            stash_entry = self.stash_revs.get(resolved)
            return stash_entry.baseline_rev if stash_entry else None

        exp_ref = first(exp_refs_by_rev(self.scm, resolved))
        return exp_ref.baseline_sha if exp_ref else None
Beispiel #9
0
Datei: apply.py Projekt: jear/dvc
def apply(repo: "Repo", rev: str, force: bool = True, **kwargs):
    """Apply the changes of an experiment to the current workspace.

    Args:
        repo: repository object exposing ``scm``, ``experiments`` and
            ``tmp_dir``.
        rev: experiment name or revision to apply.
        force: forwarded to ``_apply_workspace``.
        **kwargs: forwarded to the DVC checkout that runs after the merge.

    Raises:
        InvalidExpRevError: if ``rev`` cannot be resolved, fails the
            baseline check, or is neither a stashed experiment nor on an
            experiment branch.
        GitMergeError: if merging the experiment commit fails.
    """
    from scmrepo.exceptions import SCMError as _SCMError

    from dvc.repo.checkout import checkout as dvc_checkout
    from dvc.scm import GitMergeError, RevError, resolve_rev

    exps = repo.experiments

    try:
        exp_rev = resolve_rev(repo.scm, rev)
        exps.check_baseline(exp_rev)
    except (RevError, BaselineMismatchError) as exc:
        raise InvalidExpRevError(rev) from exc

    stash_rev = exp_rev in exps.stash_revs
    if not stash_rev and not exps.get_branch_by_rev(exp_rev,
                                                    allow_multiple=True):
        raise InvalidExpRevError(exp_rev)

    # NOTE: we don't use scmrepo's stash_workspace() here since we need
    # finer control over the merge behavior when we unstash everything
    with _apply_workspace(repo, rev, force):
        try:
            repo.scm.merge(exp_rev, commit=False, squash=True)
        except _SCMError as exc:
            # Chain the low-level SCM error so its traceback is preserved.
            raise GitMergeError(str(exc), scm=repo.scm) from exc

    repo.scm.reset()

    if stash_rev:
        # Remove the packed executor args file from the tmp dir if present.
        args_path = os.path.join(repo.tmp_dir, BaseExecutor.PACKED_ARGS_FILE)
        if os.path.exists(args_path):
            remove(args_path)

    dvc_checkout(repo, **kwargs)

    repo.scm.set_ref(EXEC_APPLY, exp_rev)
    logger.info(
        "Changes for experiment '%s' have been applied to your current "
        "workspace.",
        rev,
    )
Beispiel #10
0
def test_new_simple(tmp_dir, scm, dvc, exp_stage, mocker, name, workspace):
    """Run one experiment and check its ref, metrics and resolved name."""
    baseline = scm.get_rev()
    tmp_dir.gen("params.yaml", "foo: 2")

    new_spy = mocker.spy(dvc.experiments, "new")
    results = dvc.experiments.run(exp_stage.addressing,
                                  name=name,
                                  tmp_dir=not workspace)
    exp = first(results)
    ref_info = first(exp_refs_by_rev(scm, exp))
    assert ref_info
    assert ref_info.baseline_sha == baseline

    new_spy.assert_called_once()

    # The metrics file inside the experiment commit reflects the new param.
    exp_fs = scm.get_fs(exp)
    metrics_path = tmp_dir / "metrics.yaml"
    with exp_fs.open(metrics_path, mode="r", encoding="utf-8") as fobj:
        committed = fobj.read()
    assert committed.strip() == "foo: 2"

    if workspace:
        assert metrics_path.read_text().strip() == "foo: 2"

    exp_name = name or ref_info.name
    assert dvc.experiments.get_exact_name(exp) == exp_name
    assert resolve_rev(scm, exp_name) == exp
Beispiel #11
0
def apply(repo, rev, force=True, **kwargs):
    """Apply the changes of an experiment to the current workspace.

    A dirty workspace (including untracked files) is stashed first, the
    experiment commit is merged without committing, and the stash is then
    re-applied on top of the merge result.

    Args:
        repo: repository object exposing ``scm``, ``experiments`` and
            ``tmp_dir``.
        rev: experiment name or revision to apply.
        force: when re-applying the stashed workspace conflicts with the
            experiment, keep the experiment's changes if true; otherwise
            revert the merge, restore the workspace and raise.
        **kwargs: forwarded to the DVC checkout that runs after the merge.

    Raises:
        InvalidExpRevError: if ``rev`` cannot be resolved, fails the
            baseline check, or is neither a stashed experiment nor on an
            experiment branch.
        SCMError: if merging the experiment commit fails.
        ApplyConflictError: if re-applying the stashed workspace fails, or
            conflicts while ``force`` is false.
    """
    from dvc.repo.checkout import checkout as dvc_checkout
    from dvc.scm import RevError, SCMError, resolve_rev
    from dvc.scm.exceptions import MergeConflictError
    from dvc.scm.exceptions import SCMError as _SCMError

    exps = repo.experiments

    try:
        exp_rev = resolve_rev(repo.scm, rev)
        exps.check_baseline(exp_rev)
    except (RevError, BaselineMismatchError) as exc:
        raise InvalidExpRevError(rev) from exc

    stash_rev = exp_rev in exps.stash_revs
    if not stash_rev and not exps.get_branch_by_rev(exp_rev,
                                                    allow_multiple=True):
        raise InvalidExpRevError(exp_rev)

    # Note that we don't use stash_workspace() here since we need finer control
    # over the merge behavior when we unstash everything
    if repo.scm.is_dirty(untracked_files=True):
        logger.debug("Stashing workspace")
        workspace = repo.scm.stash.push(include_untracked=True)
    else:
        workspace = None

    try:
        repo.scm.merge(exp_rev, commit=False)
    except _SCMError as exc:
        # Chain the low-level SCM error so its traceback is preserved.
        raise SCMError(str(exc)) from exc

    if workspace:
        try:
            repo.scm.stash.apply(workspace)
        except MergeConflictError as exc:
            # Applied experiment conflicts with user's workspace changes
            if force:
                # prefer applied experiment changes over prior stashed changes
                repo.scm.checkout_index(ours=True)
            else:
                # revert applied changes and restore user's workspace
                repo.scm.reset(hard=True)
                repo.scm.stash.pop()
                raise ApplyConflictError(rev) from exc
        except _SCMError as exc:
            raise ApplyConflictError(rev) from exc
        repo.scm.stash.drop()
    repo.scm.reset()

    if stash_rev:
        # Remove the packed executor args file from the tmp dir if present.
        args_path = os.path.join(repo.tmp_dir, BaseExecutor.PACKED_ARGS_FILE)
        if os.path.exists(args_path):
            remove(args_path)

    dvc_checkout(repo, **kwargs)

    repo.scm.set_ref(EXEC_APPLY, exp_rev)
    logger.info(
        "Changes for experiment '%s' have been applied to your current "
        "workspace.",
        rev,
    )
Beispiel #12
0
def show(
    repo,
    all_branches=False,
    all_tags=False,
    revs=None,
    all_commits=False,
    sha_only=False,
    num=1,
    param_deps=False,
    onerror: Optional[Callable] = None,
):
    """Collect experiment data grouped by baseline revision.

    Args:
        repo: repository object exposing ``scm``, ``brancher`` and
            ``experiments``.
        all_branches: forwarded to ``repo.brancher``.
        all_tags: forwarded to ``repo.brancher``.
        revs: explicit revisions to use as baselines. When ``None``, up to
            ``num`` commits starting at ``HEAD`` are used.
        all_commits: forwarded to ``repo.brancher``.
        sha_only: forwarded to the experiment collectors.
        num: number of ``HEAD~n`` commits to try when ``revs`` is ``None``.
        param_deps: forwarded to the experiment collectors.
        onerror: error callback forwarded to the collectors; defaults to
            ``onerror_collect``.

    Returns:
        A ``defaultdict`` mapping each baseline revision to an
        ``OrderedDict`` that holds a ``"baseline"`` entry plus entries for
        the experiments and queued (stashed) experiments collected for it.

    Raises:
        InvalidArgumentError: if ``num`` is smaller than 1.
    """
    if onerror is None:
        onerror = onerror_collect

    res: Dict[str, Dict] = defaultdict(OrderedDict)

    if num < 1:
        raise InvalidArgumentError(f"Invalid number of commits '{num}'")

    if revs is None:
        from dvc.scm import RevError, resolve_rev

        # Walk back from HEAD, stopping at the first commit that cannot be
        # resolved.
        revs = []
        for n in range(num):
            try:
                head = fix_exp_head(repo.scm, f"HEAD~{n}")
                assert head
                revs.append(resolve_rev(repo.scm, head))
            except RevError:
                break

    # Ordered set of revisions yielded by the brancher. Note that
    # ``sha_only=True`` is always passed here, independent of the
    # ``sha_only`` argument of this function.
    revs = OrderedDict((rev, None) for rev in repo.brancher(
        revs=revs,
        all_branches=all_branches,
        all_tags=all_tags,
        all_commits=all_commits,
        sha_only=True,
    ))

    running = repo.experiments.get_running_exps()

    for rev in revs:
        res[rev]["baseline"] = _collect_experiment_commit(
            repo,
            rev,
            sha_only=sha_only,
            param_deps=param_deps,
            running=running,
            onerror=onerror,
        )

        # The workspace only gets a baseline entry; no experiment refs are
        # looked up for it.
        if rev == "workspace":
            continue

        # Gather all experiment refs under this baseline, newest commit first.
        ref_info = ExpRefInfo(baseline_sha=rev)
        commits = [(ref, repo.scm.resolve_commit(ref))
                   for ref in repo.scm.iter_refs(base=str(ref_info))]
        for exp_ref, _ in sorted(commits,
                                 key=lambda x: x[1].commit_time,
                                 reverse=True):
            ref_info = ExpRefInfo.from_ref(exp_ref)
            assert ref_info.baseline_sha == rev
            _collect_experiment_branch(
                res[rev],
                repo,
                exp_ref,
                rev,
                sha_only=sha_only,
                param_deps=param_deps,
                running=running,
                onerror=onerror,
            )
        # collect queued (not yet reproduced) experiments
        # NOTE(review): this loop sits inside the per-rev loop and checks
        # every stash entry against all of ``revs``, so queued experiments
        # are re-collected on each non-workspace iteration. It looks
        # redundant, but hoisting it could change the insertion order of
        # ``res[...]`` -- confirm before changing.
        for stash_rev, entry in repo.experiments.stash_revs.items():
            if entry.baseline_rev in revs:
                # Skip stash entries whose running info has a truthy "last".
                if stash_rev not in running or not running[stash_rev].get(
                        "last"):
                    experiment = _collect_experiment_commit(
                        repo,
                        stash_rev,
                        sha_only=sha_only,
                        stash=stash_rev not in running,
                        param_deps=param_deps,
                        running=running,
                        onerror=onerror,
                    )
                    res[entry.baseline_rev][stash_rev] = experiment
    return res