def _dump_pipeline_file(self, stage):
    """Add or update the entry for ``stage`` in the pipeline file.

    The file at ``self.path`` is parsed when it already exists; otherwise
    an empty file is created first. The serialized stage is then either
    inserted under the top-level ``stages`` key (new stage) or handed to
    ``apply_diff`` together with the existing entry (known stage), after
    which the whole document is written back and the file is registered
    with the SCM.

    Args:
        stage: Stage to persist; its ``name`` is the key under ``stages``.
    """
    if self.exists():
        with open(self.path) as fd:
            contents = parse_yaml_for_update(fd.read(), self.path)
    else:
        logger.info("Creating '%s'", self.relpath)
        # Touch the file so it exists before anything is dumped into it.
        open(self.path, "w+").close()
        contents = {}

    # Guarantee a "stages" mapping without discarding an existing one.
    contents["stages"] = contents.get("stages", {})
    serialized = serialize.to_pipeline_file(stage)

    already_present = stage.name in contents["stages"]
    logger.info(
        "%s stage '%s' in '%s'",
        "Modifying" if already_present else "Adding",
        stage.name,
        self.relpath,
    )

    if not already_present:
        contents["stages"].update(serialized)
    else:
        previous = contents["stages"][stage.name]
        # Carry the existing "meta" section over to the new entry before
        # the two are diffed.
        if "meta" in previous:
            serialized[stage.name]["meta"] = previous["meta"]
        # NOTE(review): apply_diff presumably updates `previous` in place
        # from the new data -- confirm against its definition.
        apply_diff(serialized[stage.name], previous)

    dump_yaml(self.path, contents)
    self.repo.scm.track_file(self.relpath)
def _dump_pipeline_file(self, stage):
    """Add or update the entry for ``stage`` in the pipeline file.

    ``self._check_if_parametrized`` is called on the stage first. The
    document is then edited through the ``modify_yaml`` context manager:
    a new stage is inserted under the top-level ``stages`` key, while an
    existing one is handed to ``apply_diff`` together with its current
    entry. Finally the file is registered with the SCM.

    Args:
        stage: Stage to persist; its ``name`` is the key under ``stages``.
    """
    self._check_if_parametrized(stage)
    serialized = serialize.to_pipeline_file(stage)

    with modify_yaml(self.path, fs=self.repo.fs) as contents:
        if not contents:
            # An empty document is reported as a newly created file.
            logger.info("Creating '%s'", self.relpath)

        # Guarantee a "stages" mapping without discarding an existing one.
        contents["stages"] = contents.get("stages", {})

        already_present = stage.name in contents["stages"]
        logger.info(
            "%s stage '%s' in '%s'",
            "Modifying" if already_present else "Adding",
            stage.name,
            self.relpath,
        )

        if not already_present:
            contents["stages"].update(serialized)
        else:
            previous = contents["stages"][stage.name]
            # NOTE(review): apply_diff presumably updates `previous` in
            # place from the new data -- confirm against its definition.
            apply_diff(serialized[stage.name], previous)

    self.repo.scm.track_file(self.relpath)
def init(
    repo: "Repo",
    name: str = None,
    type: str = "default",  # pylint: disable=redefined-builtin
    defaults: Dict[str, str] = None,
    overrides: Dict[str, str] = None,
    interactive: bool = False,
    force: bool = False,
    stream: Optional[TextIO] = None,
) -> "Stage":
    """Create a stage from template values and write it to ``dvc.yaml``.

    Args:
        repo: Repository in which the stage is created.
        name: Stage name; falls back to ``type`` when empty.
        type: Template kind; ``"live"`` switches on live handling.
        defaults: Baseline template values. Copied, never mutated.
        overrides: Values that take precedence over ``defaults``. Copied,
            never mutated.
        interactive: Prompt for values and ask for confirmation before
            writing.
        force: Forwarded to the stage-exists check and to stage creation.
        stream: Forwarded to the interactive prompts.

    Returns:
        The created stage.

    Raises:
        DvcException: If the user declines the confirmation prompt.
    """
    from dvc.dvcfile import make_dvcfile

    dvcfile = make_dvcfile(repo, "dvc.yaml")
    stage_name = name or type
    _check_stage_exists(dvcfile, stage_name, force=force)

    # Work on private copies so the caller's dictionaries stay untouched.
    base = {} if not defaults else defaults.copy()
    provided = {} if not overrides else overrides.copy()

    is_live = type == "live"
    if interactive:
        base = init_interactive(
            stage_name,
            validator=validate_prompts,
            defaults=base,
            live=is_live,
            provided=provided,
            stream=stream,
        )
    elif is_live:
        # suppress `metrics`/`plots` if live is selected, unless
        # it is also provided via overrides/cli.
        # This makes output to be a checkpoint as well.
        for suppressed in ("metrics", "plots"):
            base.pop(suppressed, None)
    else:
        # suppress live otherwise
        base.pop("live", None)

    context: Dict[str, str] = {**base, **provided}
    assert "cmd" in context

    params = context.get("params")
    params_kv = [loadd_params(params)] if params else []

    live = context.get("live")
    # A truthy live value turns the models output into a checkpoint.
    models_key = "checkpoints" if live else "outs"
    stage = repo.stage.create(
        name=stage_name,
        cmd=context["cmd"],
        deps=compact([context.get("code"), context.get("data")]),
        params=params_kv,
        metrics_no_cache=compact([context.get("metrics")]),
        plots_no_cache=compact([context.get("plots")]),
        live=live,
        force=force,
        **{models_key: compact([context.get("models")])},
    )

    if interactive:
        # Preview the YAML that is about to be written.
        ui.error_write(Rule(style="green"), styled=True)
        preview = dumps_yaml(to_pipeline_file(cast(PipelineStage, stage)))
        ui.error_write(
            Syntax(preview, "yaml", theme="ansi_dark"), styled=True
        )

    from dvc.ui.prompt import Confirm

    if interactive and not Confirm.ask(
        "Do you want to add the above contents to dvc.yaml?",
        console=ui.error_console,
        default=True,
        stream=stream,
    ):
        raise DvcException("Aborting ...")

    with _disable_logging(), repo.scm_context(autostage=True, quiet=True):
        stage.dump(update_lock=False)
        stage.ignore_outs()
        if params:
            repo.scm_context.track_file(params)
    return stage
def init(
    repo: "Repo",
    name: str = None,
    type: str = "default",  # pylint: disable=redefined-builtin
    defaults: Dict[str, str] = None,
    overrides: Dict[str, str] = None,
    interactive: bool = False,
    force: bool = False,
) -> "Stage":
    """Create a stage from template values and write it to ``dvc.yaml``.

    Args:
        repo: Repository in which the stage is created.
        name: Stage name; falls back to ``type`` when empty.
        type: Template kind; ``"live"`` switches on live handling.
        defaults: Baseline template values. Copied, never mutated.
        overrides: Values that take precedence over ``defaults``. Copied,
            never mutated.
        interactive: Prompt for values and ask for confirmation before
            writing.
        force: Forwarded to the stage-exists check and to stage creation.

    Returns:
        The created stage.

    Raises:
        DvcException: If the user declines the confirmation prompt.
        AssertionError: If no ``cmd`` ends up in the merged values.
    """
    from dvc.dvcfile import make_dvcfile

    dvcfile = make_dvcfile(repo, "dvc.yaml")
    name = name or type

    _check_stage_exists(dvcfile, name, force=force)

    # Copy the inputs: keys are popped from `defaults` below, and doing
    # that on the caller's own dictionary would silently alter their
    # template for any later use.
    defaults = defaults.copy() if defaults else {}
    overrides = overrides.copy() if overrides else {}

    with_live = type == "live"
    if interactive:
        defaults = init_interactive(
            name,
            defaults=defaults,
            live=with_live,
            provided=overrides,
            show_tree=True,
        )
    else:
        # Pop with a fallback so that a missing key (e.g. `defaults` not
        # given at all, or a partial template) is not a KeyError.
        if with_live:
            # suppress `metrics`/`params` if live is selected, unless
            # it is also provided via overrides/cli.
            # This makes output to be a checkpoint as well.
            defaults.pop("metrics", None)
            defaults.pop("params", None)
        else:
            defaults.pop("live", None)  # suppress live otherwise

    context: Dict[str, str] = {**defaults, **overrides}
    assert "cmd" in context

    params_kv = []
    if context.get("params"):
        from dvc.utils.serialize import LOADERS

        path = context["params"]
        assert isinstance(path, str)
        # The loader is chosen by file extension; every top-level key of
        # the loaded file is tracked as a parameter.
        _, ext = os.path.splitext(path)
        params_kv = [{path: list(LOADERS[ext](path))}]

    # A truthy live value turns the models output into a checkpoint.
    checkpoint_out = bool(context.get("live"))
    models = context.get("models")
    stage = repo.stage.create(
        name=name,
        cmd=context["cmd"],
        deps=compact([context.get("code"), context.get("data")]),
        params=params_kv,
        metrics_no_cache=compact([context.get("metrics")]),
        plots_no_cache=compact([context.get("plots")]),
        live=context.get("live"),
        force=force,
        **{"checkpoints" if checkpoint_out else "outs": compact([models])},
    )

    if interactive:
        # Preview the YAML that is about to be written.
        # NOTE(review): the rule goes through ui.write while the YAML goes
        # through ui.error_write -- confirm the split is intended.
        ui.write(Rule(style="green"), styled=True)
        _yaml = dumps_yaml(to_pipeline_file(cast(PipelineStage, stage)))
        syn = Syntax(_yaml, "yaml", theme="ansi_dark")
        ui.error_write(syn, styled=True)

    if not interactive or ui.confirm(
        "Do you want to add the above contents to dvc.yaml?"
    ):
        with _disable_logging():
            stage.dump(update_lock=False)
            stage.ignore_outs()
    else:
        raise DvcException("Aborting ...")
    return stage