Example #1
0
def validate_prompts(repo: "Repo", key: str,
                     value: str) -> Union[Any, Tuple[Any, str]]:
    """Validate a single interactive prompt answer.

    Args:
        repo: repository used to resolve a ``params`` file path.
        key: name of the prompt being answered (e.g. ``params``, ``code``,
            ``data``).
        value: the user's answer.

    Returns:
        ``value`` unchanged when nothing needs reporting, or a
        ``(value, message)`` tuple when the answer names a path that does not
        exist yet (the message says it will be created).

    Raises:
        InvalidResponse: if a ``params`` answer points to a directory.
    """
    from dvc.ui.prompt import InvalidResponse

    def _will_create(kind: str) -> Tuple[Any, str]:
        # Pair the accepted answer with a notice that ``kind`` will be created.
        note = "[yellow]'{0}' does not exist, the {1} will be created.[/]"
        return value, note.format(value, kind)

    if key in ("code", "data"):
        if os.path.exists(value):
            return value
        # ``is_file`` decides whether the missing path is reported as a file
        # or as a directory.
        return _will_create("file" if is_file(value) else "directory")

    if key != "params":
        return value

    from dvc.dependency.param import (
        MissingParamsFile,
        ParamsDependency,
        ParamsIsADirectoryError,
    )

    assert isinstance(value, str)
    try:
        ParamsDependency(None, value, repo=repo).validate_filepath()
    except MissingParamsFile:
        return _will_create("file")
    except ParamsIsADirectoryError:
        raise InvalidResponse(
            f"[prompt.invalid]'{value}' is a directory. "
            "Please retry with an existing parameters file.")
    return value
Example #2
0
def _get(stage, p, info):
    """Build the dependency object for path/URL ``p`` of ``stage``.

    Resolution order:

    1. ``remote://<name>/...``: look up the remote named by the URL's netloc
       with ``get_cloud_fs``. Use the ``DEP_MAP`` class registered for that
       fs's scheme, passing the fs as ``fs=``.
    2. A truthy ``RepoDependency.PARAM_REPO`` entry in ``info`` builds a
       ``RepoDependency``. The entry is popped from ``info``.
    3. A truthy ``ParamsDependency.PARAM_PARAMS`` entry in ``info`` builds a
       ``ParamsDependency``. The entry is popped from ``info``.
    4. Otherwise the class comes from ``DEP_MAP`` by URL scheme, falling back
       to ``LocalDependency``.

    Note: ``info`` is mutated by steps 2 and 3.
    """
    # ``p`` may be empty/None (e.g. for a params dependency), so only parse
    # it when present.
    parsed = urlparse(p) if p else None
    if parsed and parsed.scheme == "remote":
        fs = get_cloud_fs(stage.repo, name=parsed.netloc)
        return DEP_MAP[fs.scheme](stage, p, info, fs=fs)

    if info and info.get(RepoDependency.PARAM_REPO):
        repo = info.pop(RepoDependency.PARAM_REPO)
        return RepoDependency(repo, stage, p, info)

    if info and info.get(ParamsDependency.PARAM_PARAMS):
        params = info.pop(ParamsDependency.PARAM_PARAMS)
        return ParamsDependency(stage, p, params)

    if parsed is None:
        # Nothing to parse, so there is no scheme to dispatch on. This used
        # to dereference ``parsed.scheme`` on None and raise AttributeError.
        # Validation of the empty ``p`` is left to the dependency class.
        return LocalDependency(stage, p, info)

    dep_cls = DEP_MAP.get(parsed.scheme, LocalDependency)
    return dep_cls(stage, p, info)
Example #3
0
def _get(stage, p, info):
    """Build the dependency object for path/URL ``p`` (tree-based variant).

    A ``remote://<name>/...`` URL uses the tree returned by
    ``get_cloud_tree`` for that remote and the ``DEP_MAP`` class for the
    tree's scheme. Otherwise a repo or params entry in ``info`` selects
    ``RepoDependency`` / ``ParamsDependency``; that entry is popped from
    ``info``. Failing that, the first class in ``DEPS`` whose ``supported(p)``
    is truthy is used, with ``LocalDependency`` as the fallback.
    """
    if p:
        url = urlparse(p)
        if url.scheme == "remote":
            remote_tree = get_cloud_tree(stage.repo, name=url.netloc)
            dep_factory = DEP_MAP[remote_tree.scheme]
            return dep_factory(stage, p, info, tree=remote_tree)

    if info:
        if info.get(RepoDependency.PARAM_REPO):
            repo_spec = info.pop(RepoDependency.PARAM_REPO)
            return RepoDependency(repo_spec, stage, p, info)
        if info.get(ParamsDependency.PARAM_PARAMS):
            wanted = info.pop(ParamsDependency.PARAM_PARAMS)
            return ParamsDependency(stage, p, wanted)

    # Probe DEPS in order and stop at the first class that accepts ``p``.
    candidates = (dep for dep in DEPS if dep.supported(p))
    return next(candidates, LocalDependency)(stage, p, info)
Example #4
0
def loads_params(stage, s_list):
    """Create one ``ParamsDependency`` per params file referenced in ``s_list``.

    ``_merge_params`` folds ``s_list`` into a mapping of file path to the
    parameters requested from it. Dependencies are returned in that mapping's
    iteration order.
    """
    merged_by_path = _merge_params(s_list)
    deps = [
        ParamsDependency(stage, params_file, wanted)
        for params_file, wanted in merged_by_path.items()
    ]
    return deps
Example #5
0
def init(
    repo: "Repo",
    name: str = "train",
    type: str = "default",  # pylint: disable=redefined-builtin
    defaults: Optional[Dict[str, str]] = None,
    overrides: Optional[Dict[str, str]] = None,
    interactive: bool = False,
    force: bool = False,
    stream: Optional[TextIO] = None,
) -> Tuple[PipelineStage, List["Dependency"], List[str]]:
    """Create a pipeline stage in ``dvc.yaml`` from a template context.

    The stage context is ``defaults`` overlaid with ``overrides`` (overrides
    win). In interactive mode the defaults first go through
    ``init_interactive``, with ``validate_prompts`` bound to ``repo`` as the
    answer validator. Otherwise ``live`` and ``metrics``/``plots`` are made
    mutually exclusive: if ``live`` is overridden, the default
    ``metrics``/``plots`` are dropped; if not, the default ``live`` is dropped.

    Args:
        repo: repository the stage is created in.
        name: stage name.
        type: ``"checkpoint"`` records ``models`` as checkpoint outputs; any
            other value records them as regular outputs.
        defaults: default context values. Copied, never mutated.
        overrides: user-supplied context values. Copied, never mutated.
        interactive: prompt for values via ``init_interactive``.
        force: forwarded to ``_check_stage_exists`` and ``repo.stage.create``.
        stream: forwarded to ``init_interactive``.

    Returns:
        ``(stage, initialized_deps, initialized_out_dirs)``. The last two are
        the results of ``init_deps`` and ``init_out_dirs``.

    Raises:
        DvcException: if the ``params`` entry points to a directory.
    """
    from dvc.dvcfile import make_dvcfile

    dvcfile = make_dvcfile(repo, "dvc.yaml")
    _check_stage_exists(dvcfile, name, force=force)

    # Work on copies so the caller's dicts are never mutated.
    defaults = defaults.copy() if defaults else {}
    overrides = overrides.copy() if overrides else {}

    if interactive:
        defaults = init_interactive(
            validator=partial(validate_prompts, repo),
            defaults=defaults,
            provided=overrides,
            stream=stream,
        )
    else:
        if "live" in overrides:
            # suppress `metrics`/`plots` if live is selected.
            defaults.pop("metrics", None)
            defaults.pop("plots", None)
        else:
            defaults.pop("live", None)  # suppress live otherwise

    context: Dict[str, str] = {**defaults, **overrides}
    # NOTE(review): asserts are stripped under ``python -O``. A missing cmd
    # would then surface as a KeyError at ``context["cmd"]`` below.
    assert "cmd" in context

    params = context.get("params")
    if params:
        from dvc.dependency.param import (
            MissingParamsFile,
            ParamsDependency,
            ParamsIsADirectoryError,
        )

        try:
            ParamsDependency(None, params, repo=repo).validate_filepath()
        except ParamsIsADirectoryError as exc:
            # Deliberately not ``raise ... from exc``: leaving ``__cause__``
            # unset keeps the original error out of the displayed message.
            raise DvcException(f"{exc}.")  # swallow cause for display
        except MissingParamsFile:
            # A missing params file is acceptable at this point.
            pass

    models = context.get("models")
    # A live path contributes a sibling ``<live>.json`` metrics file and a
    # ``<live>/scalars`` plots entry. ``live`` is popped so it does not
    # linger in the context.
    live_path = context.pop("live", None)
    live_metrics = f"{live_path}.json" if live_path else None
    live_plots = os.path.join(live_path, "scalars") if live_path else None

    stage = repo.stage.create(
        name=name,
        cmd=context["cmd"],
        deps=compact([context.get("code"),
                      context.get("data")]),
        params=[{
            params: None
        }] if params else None,
        metrics_no_cache=compact([context.get("metrics"), live_metrics]),
        plots_no_cache=compact([context.get("plots"), live_plots]),
        force=force,
        # ``models`` go under ``checkpoints`` for checkpoint stages and under
        # ``outs`` otherwise.
        **{
            "checkpoints" if type == "checkpoint" else "outs":
            compact([models])
        },
    )

    with _disable_logging(), repo.scm_context(autostage=True, quiet=True):
        # Write dvc.yaml without touching the lock file.
        stage.dump(update_lock=False)
        initialized_out_dirs = init_out_dirs(stage)
        stage.ignore_outs()
        initialized_deps = init_deps(stage)
        if params:
            repo.scm_context.track_file(params)

    assert isinstance(stage, PipelineStage)
    return stage, initialized_deps, initialized_out_dirs