Ejemplo n.º 1
0
    def _dump_pipeline_file(self, stage):
        """Add or update ``stage`` under the ``stages`` key of ``self.path``.

        If the file already exists, its parsed contents are used as the
        base; otherwise an empty file is created first. An existing entry
        for ``stage.name`` is updated in place via ``apply_diff`` (keeping
        its ``meta`` value, if any); a new entry is merged into the
        mapping. The result is dumped back to ``self.path`` and the file
        is handed to ``self.repo.scm.track_file``.
        """
        if self.exists():
            with open(self.path) as fobj:
                contents = parse_yaml_for_update(fobj.read(), self.path)
        else:
            contents = {}
            logger.info("Creating '%s'", self.relpath)
            # Touch the file so that it exists before anything else runs.
            with open(self.path, "w+"):
                pass

        stages = contents.get("stages", {})
        contents["stages"] = stages
        serialized = serialize.to_pipeline_file(stage)

        name = stage.name
        is_update = name in stages
        logger.info(
            "%s stage '%s' in '%s'",
            "Modifying" if is_update else "Adding",
            name,
            self.relpath,
        )

        if not is_update:
            stages.update(serialized)
        else:
            previous = stages[name]
            # Carry the existing ``meta`` over so the diff does not drop it.
            if "meta" in previous:
                serialized[name]["meta"] = previous["meta"]
            apply_diff(serialized[name], previous)

        dump_yaml(self.path, contents)
        self.repo.scm.track_file(self.relpath)
Ejemplo n.º 2
0
    def _dump_pipeline_file(self, stage):
        """Add or update ``stage`` under the ``stages`` key of ``self.path``.

        The stage is first passed to ``self._check_if_parametrized`` and
        serialized. The file contents are then edited inside the
        ``modify_yaml`` context (presumably written back when the context
        exits -- confirm against ``modify_yaml``). An existing entry for
        ``stage.name`` is updated in place via ``apply_diff``; a new entry
        is merged into the mapping. Finally the file is handed to
        ``self.repo.scm.track_file``.
        """
        self._check_if_parametrized(stage)
        serialized = serialize.to_pipeline_file(stage)

        with modify_yaml(self.path, fs=self.repo.fs) as contents:
            # Empty contents are reported as a newly created file.
            if not contents:
                logger.info("Creating '%s'", self.relpath)

            stages = contents.get("stages", {})
            contents["stages"] = stages

            name = stage.name
            is_update = name in stages
            logger.info(
                "%s stage '%s' in '%s'",
                "Modifying" if is_update else "Adding",
                name,
                self.relpath,
            )

            if not is_update:
                stages.update(serialized)
            else:
                apply_diff(serialized[name], stages[name])

        self.repo.scm.track_file(self.relpath)
Ejemplo n.º 3
0
def init(
    repo: "Repo",
    name: str = None,
    type: str = "default",  # pylint: disable=redefined-builtin
    defaults: Dict[str, str] = None,
    overrides: Dict[str, str] = None,
    interactive: bool = False,
    force: bool = False,
    stream: Optional[TextIO] = None,
) -> "Stage":
    """Create a stage from ``defaults`` and ``overrides`` and dump it.

    Args:
        repo: Repository used to build the ``dvc.yaml`` dvcfile, create the
            stage and provide the SCM context.
        name: Stage name; falls back to ``type`` when empty.
        type: Stage flavour. ``"live"`` switches on live handling.
        defaults: Baseline values. Copied, never modified in place.
        overrides: Values that take precedence over ``defaults``. Copied.
        interactive: When true, values are collected through
            ``init_interactive``, the resulting YAML is shown and a
            confirmation is requested before dumping.
        force: Forwarded to ``_check_stage_exists`` and
            ``repo.stage.create``.
        stream: Forwarded to ``init_interactive`` and ``Confirm.ask``.

    Returns:
        The created stage.

    Raises:
        DvcException: If the interactive confirmation is declined.
    """
    from dvc.dvcfile import make_dvcfile

    name = name or type
    _check_stage_exists(make_dvcfile(repo, "dvc.yaml"), name, force=force)

    base = defaults.copy() if defaults else {}
    provided = overrides.copy() if overrides else {}
    live_selected = type == "live"

    if interactive:
        base = init_interactive(
            name,
            validator=validate_prompts,
            defaults=base,
            live=live_selected,
            provided=provided,
            stream=stream,
        )
    elif live_selected:
        # With live selected, drop the default `metrics`/`plots` unless they
        # are supplied through overrides. This also turns the models output
        # into a checkpoint (see the keyword chosen below).
        for key in ("metrics", "plots"):
            base.pop(key, None)
    else:
        # Live is not selected: drop the default `live` entry.
        base.pop("live", None)

    context: Dict[str, str] = {**base, **provided}
    assert "cmd" in context

    params = context.get("params")
    params_kv = [loadd_params(params)] if params else []

    live = context.get("live")
    # A configured live output makes the models a checkpoint output.
    models_kwarg = "checkpoints" if live else "outs"
    stage = repo.stage.create(
        name=name,
        cmd=context["cmd"],
        deps=compact([context.get("code"), context.get("data")]),
        params=params_kv,
        metrics_no_cache=compact([context.get("metrics")]),
        plots_no_cache=compact([context.get("plots")]),
        live=live,
        force=force,
        **{models_kwarg: compact([context.get("models")])},
    )

    if interactive:
        # Preview the YAML that is about to be written.
        ui.error_write(Rule(style="green"), styled=True)
        rendered = dumps_yaml(to_pipeline_file(cast(PipelineStage, stage)))
        highlighted = Syntax(rendered, "yaml", theme="ansi_dark")
        ui.error_write(highlighted, styled=True)

    from dvc.ui.prompt import Confirm

    # Only interactive runs ask for confirmation; a refusal aborts.
    if interactive and not Confirm.ask(
        "Do you want to add the above contents to dvc.yaml?",
        console=ui.error_console,
        default=True,
        stream=stream,
    ):
        raise DvcException("Aborting ...")

    with _disable_logging(), repo.scm_context(autostage=True, quiet=True):
        stage.dump(update_lock=False)
        stage.ignore_outs()
        if params:
            repo.scm_context.track_file(params)
    return stage
Ejemplo n.º 4
0
def init(
    repo: "Repo",
    name: str = None,
    type: str = "default",  # pylint: disable=redefined-builtin
    defaults: Dict[str, str] = None,
    overrides: Dict[str, str] = None,
    interactive: bool = False,
    force: bool = False,
) -> "Stage":
    """Create a stage from ``defaults`` and ``overrides`` and dump it.

    Args:
        repo: Repository used to build the ``dvc.yaml`` dvcfile and to
            create the stage.
        name: Stage name; falls back to ``type`` when empty.
        type: Stage flavour. ``"live"`` switches on live handling.
        defaults: Baseline values. The caller's dict is copied and is never
            modified.
        overrides: Values that take precedence over ``defaults``.
        interactive: When true, values are collected through
            ``init_interactive``, the resulting YAML is shown and a
            confirmation is requested before dumping.
        force: Forwarded to ``_check_stage_exists`` and
            ``repo.stage.create``.

    Returns:
        The created stage.

    Raises:
        DvcException: If the interactive confirmation is declined.
    """
    from dvc.dvcfile import make_dvcfile

    dvcfile = make_dvcfile(repo, "dvc.yaml")
    name = name or type

    _check_stage_exists(dvcfile, name, force=force)

    # Work on a private copy: keys are popped below, and that must not leak
    # into the dict owned by the caller.
    defaults = defaults.copy() if defaults else {}
    overrides = overrides or {}

    with_live = type == "live"
    if interactive:
        defaults = init_interactive(
            name,
            defaults=defaults,
            live=with_live,
            provided=overrides,
            show_tree=True,
        )
    else:
        # Pop with a fallback so that missing keys (for example when no
        # defaults were supplied at all) do not raise KeyError.
        if with_live:
            # suppress `metrics`/`params` if live is selected, unless
            # it is also provided via overrides/cli.
            # This makes output to be a checkpoint as well.
            defaults.pop("metrics", None)
            defaults.pop("params", None)
        else:
            defaults.pop("live", None)  # suppress live otherwise

    context: Dict[str, str] = {**defaults, **overrides}
    assert "cmd" in context

    params_kv = []
    if context.get("params"):
        from dvc.utils.serialize import LOADERS

        path = context["params"]
        assert isinstance(path, str)
        # The loader is selected by file extension; the keys of the loaded
        # params file become the tracked params.
        _, ext = os.path.splitext(path)
        params_kv = [{path: list(LOADERS[ext](path))}]

    # A configured live output makes the models a checkpoint output.
    checkpoint_out = bool(context.get("live"))
    models = context.get("models")
    stage = repo.stage.create(
        name=name,
        cmd=context["cmd"],
        deps=compact([context.get("code"),
                      context.get("data")]),
        params=params_kv,
        metrics_no_cache=compact([context.get("metrics")]),
        plots_no_cache=compact([context.get("plots")]),
        live=context.get("live"),
        force=force,
        **{"checkpoints" if checkpoint_out else "outs": compact([models])},
    )

    if interactive:
        # Preview the YAML that is about to be written.
        # NOTE(review): the rule goes through `ui.write` while the YAML goes
        # through `ui.error_write` -- confirm the split is intended.
        ui.write(Rule(style="green"), styled=True)
        _yaml = dumps_yaml(to_pipeline_file(cast(PipelineStage, stage)))
        syn = Syntax(_yaml, "yaml", theme="ansi_dark")
        ui.error_write(syn, styled=True)

    if not interactive or ui.confirm(
            "Do you want to add the above contents to dvc.yaml?"):
        with _disable_logging():
            stage.dump(update_lock=False)
        stage.ignore_outs()
    else:
        raise DvcException("Aborting ...")
    return stage