def _get_pipeline_catalog_from_kedro14(env):
    """Load the project pipeline and data catalog using Kedro 0.14 hooks.

    Resolves the ``create_pipeline``, ``get_config`` and ``create_catalog``
    entries through ``get_project_context``. Configuration is read relative
    to the current working directory.

    Args:
        env: Configuration environment forwarded to ``get_config``.

    Returns:
        A ``(pipeline, catalog)`` tuple.

    Raises:
        KedroCliError: Raised with ``ERROR_PROJECT_ROOT`` when any step raises
            ``ImportError`` or ``KeyError``. The original error is chained as
            the cause.
    """
    try:
        pipeline = get_project_context("create_pipeline")()
        get_config = get_project_context("get_config")
        conf = get_config(str(Path.cwd()), env)
        create_catalog = get_project_context("create_catalog")
        catalog = create_catalog(config=conf)
        return pipeline, catalog
    except (ImportError, KeyError) as exc:
        # Chain explicitly so the underlying import/lookup failure is shown
        # as the direct cause (pylint W0707, raise-missing-from).
        raise KedroCliError(ERROR_PROJECT_ROOT) from exc
def _call_viz(
    host=None,
    port=None,
    browser=None,
    load_file=None,
    save_file=None,
    pipeline_name=None,
    env=None,
    project_path=None,
):
    """Build the visualisation data, then either save it or serve it.

    When ``load_file`` is given, the data is loaded from that file. Otherwise
    it is built from the Kedro project: through the project ``context`` on
    Kedro >= 0.15, or through the legacy 0.14 hooks. The result is stored in
    the module-level ``_DATA``. When the data comes from a project, the
    catalog is also stored in ``_CATALOG``.

    If ``save_file`` is given, ``_DATA`` is written there as JSON. Otherwise
    ``app`` is run on ``host``/``port``, optionally opening a browser first.

    Args:
        host: Host the server binds to.
        port: Port the server listens on. It is formatted with ``{:d}`` when a
            browser is opened, so it must be an int in that case.
        browser: Open a browser tab. This only happens for local hosts.
        load_file: Path of previously saved data to use instead of the project.
        save_file: Path to write the data to as JSON instead of serving.
        pipeline_name: Pipeline to visualise. Only supported on Kedro >= 0.15.
        env: Kedro configuration environment.
        project_path: Optional project path passed to ``get_project_context``.

    Raises:
        KedroCliError: If the project context cannot be loaded, or if
            ``pipeline_name`` is used on Kedro 0.14.
    """
    global _DATA  # pylint: disable=global-statement,invalid-name
    global _CATALOG  # pylint: disable=global-statement

    if load_file:
        # Remove all handlers for root logger
        root_logger = logging.getLogger()
        root_logger.handlers = []

        _DATA = _load_from_file(load_file)
    else:
        if KEDRO_VERSION.match(">=0.15.0"):
            # pylint: disable=import-outside-toplevel
            if KEDRO_VERSION.match(">=0.16.0"):
                from kedro.framework.context import KedroContextError
            else:
                from kedro.context import (  # pylint: disable=no-name-in-module,import-error
                    KedroContextError,
                )

            try:
                if project_path is not None:
                    context = get_project_context(
                        "context", project_path=project_path, env=env
                    )
                else:
                    context = get_project_context("context", env=env)
                pipelines = _get_pipelines_from_context(context, pipeline_name)
            except KedroContextError as exc:
                # Keep the original context error as the explicit cause.
                raise KedroCliError(ERROR_PROJECT_ROOT) from exc
            _CATALOG = context.catalog
        else:
            # Kedro 0.14.*
            if pipeline_name:
                raise KedroCliError(ERROR_PIPELINE_FLAG_NOT_SUPPORTED)
            pipelines, _CATALOG = _get_pipeline_catalog_from_kedro14(env)

        _DATA = format_pipelines_data(pipelines)

    if save_file:
        Path(save_file).write_text(json.dumps(_DATA, indent=4, sort_keys=True))
    else:
        is_localhost = host in ("127.0.0.1", "localhost", "0.0.0.0")
        if browser and is_localhost:
            # "0.0.0.0" means "bind on all interfaces"; it is not a valid
            # destination, and some browsers/platforms refuse to connect to it.
            # Point the browser at loopback instead. The server still binds
            # to ``host``.
            browser_host = "127.0.0.1" if host == "0.0.0.0" else host
            webbrowser.open_new("http://{}:{:d}/".format(browser_host, port))
        app.run(host=host, port=port)
def _get_pipeline_catalog_from_kedro14(
    env,
) -> Tuple[Dict[str, "Pipeline"], "DataCatalog"]:
    """Load the project pipeline and data catalog using Kedro 0.14 hooks.

    Resolves the ``create_pipeline``, ``get_config`` and ``create_catalog``
    entries through ``get_project_context``. Configuration is read relative
    to the current working directory. The single 0.14 pipeline is wrapped in
    a dict under ``_DEFAULT_KEY``.

    Args:
        env: Configuration environment forwarded to ``get_config``.

    Returns:
        A ``({_DEFAULT_KEY: pipeline}, catalog)`` tuple.

    Raises:
        KedroCliError: Raised with ``ERROR_PROJECT_ROOT`` when any step raises
            ``ImportError`` or ``KeyError``. The original error is chained as
            the cause.
    """
    try:
        pipeline = get_project_context("create_pipeline")()
        get_config = get_project_context("get_config")
        conf = get_config(str(Path.cwd()), env)
        create_catalog = get_project_context("create_catalog")
        catalog = create_catalog(config=conf)
        return {_DEFAULT_KEY: pipeline}, catalog
    except (ImportError, KeyError) as exc:
        # Chain explicitly so the underlying import/lookup failure is shown
        # as the direct cause (pylint W0707, raise-missing-from).
        raise KedroCliError(ERROR_PROJECT_ROOT) from exc
def grpc_serve(
    context: Any = None,
    host: str = "[::]",
    port: int = 50051,
    max_workers: int = 10,
    wait_term: bool = True,
):
    """
    Start the Kedro gRPC server

    :param context: Kedro Project Context; obtained from
        ``get_project_context()`` when ``None``
    :param host: host for running the grpc server
    :param port: Port to run the gRPC server on; ``0`` asks gRPC to pick a
        free port, and the port actually bound is logged
    :param max_workers: Max number of workers
    :param wait_term: Wait for termination
    :raises KedroGrpcServerException: Failing to start gRPC Server
    """
    try:
        if context is None:
            context = get_project_context()
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
        add_KedroServicer_to_server(KedroServer(context), server)
        # add_insecure_port returns the port actually bound. Some grpcio
        # versions report a bind failure by returning 0 rather than raising,
        # so treat 0 as an error instead of starting a server that is not
        # listening anywhere.
        bound_port = server.add_insecure_port(f"{host}:{port}")
        if not bound_port:
            raise RuntimeError(f"Failed to bind gRPC server to {host}:{port}")
        server.start()
        logging.info("Kedro gRPC Server started on %s", bound_port)
        if wait_term:  # pragma: no cover
            server.wait_for_termination()
    except Exception as exc:
        # Log with traceback, then convert to the project exception while
        # keeping the original as the explicit cause.
        logging.exception(exc)
        raise KedroGrpcServerException("Failed to start Kedro gRPC Server") from exc
def test_verbose(self):
    """The ``verbose`` entry of the project context is falsy."""
    verbose_flag = get_project_context("verbose")
    assert not verbose_flag
def test_project_path(self):
    """``project_path`` is still returned, with a DeprecationWarning."""
    name = "project_path"
    with warns(DeprecationWarning, match=self._deprecation_msg(name)):
        value = get_project_context(name)
        assert value == "dummy_path"
def test_template_version(self):
    """``template_version`` is still returned, with a DeprecationWarning."""
    name = "template_version"
    with warns(DeprecationWarning, match=self._deprecation_msg(name)):
        value = get_project_context(name)
        assert value == "dummy_version"
def test_create_pipeline(self):
    """``create_pipeline`` warns as deprecated but still yields a usable callable."""
    name = "create_pipeline"
    with warns(DeprecationWarning, match=self._deprecation_msg(name)):
        make_pipeline = get_project_context(name)
        assert make_pipeline() == "pipeline"
def test_create_catalog(self):
    """``create_catalog`` warns as deprecated but still yields a usable callable."""
    name = "create_catalog"
    with warns(DeprecationWarning, match=self._deprecation_msg(name)):
        make_catalog = get_project_context(name)
        assert make_catalog("config") == "catalog"
def test_get_config(self, tmp_path):
    """``get_config`` warns as deprecated but still yields a usable callable."""
    name = "get_config"
    with warns(DeprecationWarning, match=self._deprecation_msg(name)):
        load_config = get_project_context(name)
        assert load_config(tmp_path) == "config_loader"
def test_context(self):
    """The ``context`` entry is a ``DummyContext`` instance."""
    ctx = get_project_context("context")
    assert isinstance(ctx, DummyContext)
def test_get_context_with_project_path(self, tmpdir, mocked_load_context):
    """An explicit ``project_path`` is what ``mocked_load_context`` receives."""
    project_dir = tmpdir.mkdir("dummy_project")
    ctx = get_project_context("context", project_path=project_dir)
    mocked_load_context.assert_called_once_with(project_dir)
    assert isinstance(ctx, DummyContext)
def test_get_context_without_project_path(self, mocked_load_context):
    """Without ``project_path``, ``mocked_load_context`` receives the current directory."""
    ctx = get_project_context("context")
    cwd = Path.cwd()
    mocked_load_context.assert_called_once_with(cwd)
    assert isinstance(ctx, DummyContext)