Example #1
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        By the time this function is invoked, the local_execute function should have unwrapped the Promises and Flyte
        literal wrappers so that the kwargs we are working with here are now Python native literal values. This
        function is also expected to return Python native literal values.

        Since the user code within a dynamic task constitutes a workflow, we have to first compile the workflow, and
        then execute that workflow.

        When running for real in production, the task would stop after the compilation step, and then create a file
        representing that newly generated workflow, instead of executing it.
        """
        ctx = FlyteContextManager.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            updated_exec_state = ctx.execution_state.with_params(
                mode=ExecutionState.Mode.TASK_EXECUTION)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(updated_exec_state)):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return exception_scopes.user_entry_point(task_function)(
                    **kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            return self.compile_into_workflow(ctx, task_function, **kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
            return exception_scopes.user_entry_point(task_function)(**kwargs)

        raise ValueError(
            f"Invalid execution provided, execution state: {ctx.execution_state}"
        )
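
For context, a minimal sketch (not taken from the code above) of the kind of user code that gets routed through dynamic_execute: a @dynamic task whose body builds a small workflow at runtime. The task names here are hypothetical.

import typing

from flytekit import dynamic, task

@task
def square(x: int) -> int:
    return x * x

@dynamic
def fan_out(n: int) -> typing.List[int]:
    # Each call to square() yields a Promise; compiling this body produces a workflow.
    return [square(x=i) for i in range(n)]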
Example #2
def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]):
    """
    This is a decorator used for testing.
    """
    if (not isinstance(target, PythonTask)
            and not isinstance(target, WorkflowBase)
            and not isinstance(target, ReferenceEntity)):
        raise Exception(
            "Can only use mocks on tasks/workflows declared in Python.")

    logger.info(
        "When using this patch function on Flyte entities, please be aware weird issues may arise if also "
        "using mock.patch on internal Flyte classes like PythonFunctionWorkflow. See "
        "https://github.com/flyteorg/flyte/issues/854 for more information")

    def wrapper(test_fn):
        def new_test(*args, **kwargs):
            logger.warning(f"Invoking mock method for target: '{target.name}'")
            m = MagicMock()
            saved = target.execute
            target.execute = m
            results = test_fn(m, *args, **kwargs)
            target.execute = saved
            return results

        return new_test

    return wrapper
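
A minimal sketch of how this patch decorator is typically used in a test, assuming a task t1 and a workflow wf defined with flytekit's @task/@workflow decorators (the names are placeholders):

from flytekit import task, workflow
from flytekit.testing import patch

@task
def t1(a: int) -> int:
    return a + 1

@workflow
def wf(a: int) -> int:
    return t1(a=a)

@patch(t1)
def test_wf(mock_t1):
    # The MagicMock replaces t1.execute for the duration of the test.
    mock_t1.return_value = 100
    assert wf(a=3) == 100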
Example #3
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        By the time this function is invoked, the _local_execute function should have unwrapped the Promises and Flyte
        literal wrappers so that the kwargs we are working with here are now Python native literal values. This
        function is also expected to return Python native literal values.

        Since the user code within a dynamic task constitutes a workflow, we have to first compile the workflow, and
        then execute that workflow.

        When running for real in production, the task would stop after the compilation step, and then create a file
        representing that newly generated workflow, instead of executing it.
        """
        ctx = FlyteContextManager.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            updated_exec_state = ctx.execution_state.with_params(
                mode=ExecutionState.Mode.TASK_EXECUTION)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(updated_exec_state)):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return task_function(**kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            is_fast_execution = bool(
                ctx.execution_state and ctx.execution_state.additional_context
                and ctx.execution_state.additional_context.get(
                    "dynamic_addl_distro"))
            if is_fast_execution:
                ctx = ctx.with_serialization_settings(
                    SerializationSettings.new_builder(
                    ).with_fast_serialization_settings(
                        FastSerializationSettings(enabled=True)).build())

            return self.compile_into_workflow(ctx, task_function, **kwargs)
Example #4
def get_protocol(url: str):
    # copy from fsspec https://github.com/fsspec/filesystem_spec/blob/fe09da6942ad043622212927df7442c104fe7932/fsspec/utils.py#L387-L391
    parts = re.split(r"(\:\:|\://)", url, 1)
    if len(parts) > 1:
        return parts[0]
    logger.info("Setting protocol to file")
    return "file"
Example #5
def check_call(cmd_args, **kwargs):
    if not isinstance(cmd_args, list):
        cmd_args = _schlex.split(cmd_args)

    # Jupyter notebooks hijack I/O and thus we cannot dump directly to stdout.
    with _tempfile.TemporaryFile() as std_out:
        with _tempfile.TemporaryFile() as std_err:
            ret_code = _subprocess.Popen(cmd_args,
                                         stdout=std_out,
                                         stderr=std_err,
                                         **kwargs).wait()

            # Dump sub-process' std out into current std out
            std_out.seek(0)
            logger.info("Output of command '{}':\n{}\n".format(
                cmd_args, std_out.read()))

            if ret_code != 0:
                std_err.seek(0)
                err_str = std_err.read()
                logger.error("Error from command '{}':\n{}\n".format(
                    cmd_args, err_str))

                raise Exception(
                    "Called process exited with error code: {}.  Stderr dump:\n\n{}"
                    .format(ret_code, err_str))

    return 0
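
A short usage sketch; the commands are arbitrary placeholders:

check_call("echo hello")            # a string is tokenized with shlex first
check_call(["ls", "-l", "/tmp"])    # a list is passed to Popen unchanged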
Example #6
def _register():
    logger.info(
        "Registering fsspec known implementations and overriding all default implementations for persistence."
    )
    DataPersistencePlugins.register_plugin("/", FSSpecPersistence, force=True)
    for k, v in known_implementations.items():
        DataPersistencePlugins.register_plugin(f"{k}://",
                                               FSSpecPersistence,
                                               force=True)
Example #7
    def __exit__(self, exc_type, exc_val, exc_tb):
        end_wall_time = _time.perf_counter()
        end_process_time = _time.process_time()
        logger.info(
            "Exiting timed context: {} [Wall Time: {}s, Process Time: {}s]".format(
                self._context_statement,
                end_wall_time - self._start_wall_time,
                end_process_time - self._start_process_time,
            ))
Example #8
File: deck.py  Project: flyteorg/flytekit
def _output_deck(task_name: str, new_user_params: ExecutionParameters):
    ctx = FlyteContext.current_context()
    if ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
        output_dir = ctx.execution_state.engine_dir
    else:
        output_dir = ctx.file_access.get_random_local_directory()
    deck_path = os.path.join(output_dir, DECK_FILE_NAME)
    with open(deck_path, "w") as f:
        f.write(_get_deck(new_user_params))
    logger.info(
        f"{task_name} task creates flyte deck html to file://{deck_path}")
Example #9
    def execute(self, **kwargs) -> typing.Any:
        """
        Executes the given script by substituting the inputs and outputs and extracts the outputs from the filesystem
        """
        logger.info(f"Running shell script as type {self.task_type}")
        if self.script_file:
            with open(self.script_file) as f:
                self._script = f.read()

        outputs: typing.Dict[str, str] = {}
        if self._output_locs:
            for v in self._output_locs:
                outputs[v.var] = self._interpolizer.interpolate(v.location,
                                                                inputs=kwargs)

        if os.name == "nt":
            self._script = self._script.lstrip().rstrip().replace("\n", "&&")

        if "env" in kwargs and isinstance(kwargs["env"], dict):
            kwargs["export_env"] = self.make_export_string_from_env_dict(
                kwargs["env"])

        gen_script = self._interpolizer.interpolate(self._script,
                                                    inputs=kwargs,
                                                    outputs=outputs)
        if self._debug:
            print("\n==============================================\n")
            print(gen_script)
            print("\n==============================================\n")

        try:
            subprocess.check_call(gen_script, shell=True)
        except subprocess.CalledProcessError as e:
            files = os.listdir(".")
            fstr = "\n-".join(files)
            logger.error(
                f"Failed to Execute Script, return-code {e.returncode} \n"
                f"StdErr: {e.stderr}\n"
                f"StdOut: {e.stdout}\n"
                f" Current directory contents: .\n-{fstr}")
            raise

        final_outputs = []
        for v in self._output_locs:
            if issubclass(v.var_type, FlyteFile):
                final_outputs.append(FlyteFile(outputs[v.var]))
            if issubclass(v.var_type, FlyteDirectory):
                final_outputs.append(FlyteDirectory(outputs[v.var]))
        if len(final_outputs) == 1:
            return final_outputs[0]
        if len(final_outputs) > 1:
            return tuple(final_outputs)
        return None
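
A hedged sketch of how a shell task whose outputs are interpolated like the ones above might be declared, following flytekit's ShellTask/OutputLocation pattern; the script and names are illustrative only:

from flytekit import kwtypes
from flytekit.extras.tasks.shell import OutputLocation, ShellTask
from flytekit.types.file import FlyteFile

t = ShellTask(
    name="gzip_file",
    debug=True,
    script="""
    set -e
    gzip -c {inputs.src} > {outputs.archive}
    """,
    inputs=kwtypes(src=FlyteFile),
    output_locs=[OutputLocation(var="archive", var_type=FlyteFile, location="{inputs.src}.gz")],
)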
Example #10
    def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask:
        logger.info(f"Task template loader args: {loader_args}")
        ctx = FlyteContext.current_context()
        task_template_local_path = os.path.join(
            ctx.execution_state.working_dir, "task_template.pb")
        ctx.file_access.get_data(loader_args[0], task_template_local_path)
        task_template_proto = common_utils.load_proto_from_file(
            _tasks_pb2.TaskTemplate, task_template_local_path)
        task_template_model = _task_model.TaskTemplate.from_flyte_idl(
            task_template_proto)

        executor_class = load_object_from_module(loader_args[1])
        return ExecutableTemplateShimTask(task_template_model, executor_class)
Example #11
def _finder(handler_map, df_type: Type, protocol: str, format: str):
    try:
        return handler_map[df_type][protocol][format]
    except KeyError:
        try:
            hh = handler_map[df_type][protocol][""]
            logger.info(
                f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}"
                f" format {format}, using default instead.")
            return hh
        except KeyError:
            ...
    raise ValueError(
        f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}"
    )
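
A tiny illustration (made-up data, with keys passed as plain strings despite the Type annotation) of the two-stage lookup above: an exact (type, protocol, format) match is tried first, then the protocol's default entry registered under the empty format:

handlers = {
    "MyDF": {
        "s3": {"parquet": "s3-parquet-handler", "": "s3-default-handler"},
    },
}
assert _finder(handlers, "MyDF", "s3", "parquet") == "s3-parquet-handler"
# Falls back to the "" entry and logs that the default handler is being used.
assert _finder(handlers, "MyDF", "s3", "csv") == "s3-default-handler"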
Example #12
    def unwrap_literal_map_and_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[VoidPromise, _literal_models.LiteralMap,
               _dynamic_job.DynamicJobSpec]:
        """
        Please see the implementation of the dispatch_execute function in the real task.
        """

        # Invoked before the task is executed
        # Translate the input literals to Python native
        native_inputs = TypeEngine.literal_map_to_kwargs(
            ctx, input_literal_map, self.python_interface.inputs)

        logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
        try:
            native_outputs = self.execute(**native_inputs)
        except Exception as e:
            logger.exception(f"Exception when executing {e}")
            raise e
        logger.info(
            f"Task executed successfully in user level, outputs: {native_outputs}"
        )

        expected_output_names = list(self.python_interface.outputs.keys())
        if len(expected_output_names) == 1:
            native_outputs_as_map = {expected_output_names[0]: native_outputs}
        elif len(expected_output_names) == 0:
            native_outputs_as_map = {}
        else:
            native_outputs_as_map = {
                expected_output_names[i]: native_outputs[i]
                for i, _ in enumerate(native_outputs)
            }

        # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
        # built into the IDL that all the values of a literal map are of the same type.
        literals = {}
        for k, v in native_outputs_as_map.items():
            literal_type = self.interface.outputs[k].type
            py_type = self.python_interface.outputs[k]
            if isinstance(v, tuple):
                raise AssertionError(
                    f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}"
                )
            literals[k] = TypeEngine.to_literal(ctx, v, py_type, literal_type)
        outputs_literal_map = _literal_models.LiteralMap(literals=literals)
        # After the execute has been successfully completed
        return outputs_literal_map
Example #13
    def _local_execute(
            self, ctx: FlyteContext,
            **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
        """
        Performs local execution of a workflow. kwargs are expected to be Promises for the most part (unless,
        someone has hardcoded in my_wf(input_1=5) or something).
        :param ctx: The FlyteContext
        :param kwargs: parameters for the workflow itself
        """
        logger.info(
            f"Executing Workflow {self._name}, ctx{ctx.execution_state.Mode}")

        # This is done to support the invariant that Workflow local executions always work with Promise objects
        # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
        for k, v in kwargs.items():
            if not isinstance(v, Promise):
                t = self._native_interface.inputs[k]
                kwargs[k] = Promise(var=k,
                                    val=TypeEngine.to_literal(
                                        ctx, v, t,
                                        self.interface.inputs[k].type))

        function_outputs = self.execute(**kwargs)
        if (isinstance(function_outputs, VoidPromise)
                or function_outputs is None
                or len(self.python_interface.outputs) == 0):
            # The reason this is here is because a workflow function may return a task that doesn't return anything
            #   def wf():
            #       return t1()
            # or it may not return at all
            #   def wf():
            #       t1()
            # In the former case we get the task's VoidPromise, in the latter we get None
            return VoidPromise(self.name)

        # TODO: Can we refactor the task code to be similar to what's in this function?
        promises = _workflow_fn_outputs_to_promise(
            ctx, self._native_interface.outputs, self.interface.outputs,
            function_outputs)
        # TODO: With the native interface, create_task_output should be able to derive the typed interface, and it
        #   should be able to do the conversion of the output of the execute() call directly.
        return create_task_output(promises, self._native_interface)
Example #14
    def pre_execute(self,
                    user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Pre-execute for Sagemaker will automatically add the distributed context to the execution params, only
        if the number of execution instances is > 1. Otherwise this is considered to be a single node execution
        """
        if self._is_distributed():
            logger.info("Distributed context detected!")
            exec_state = FlyteContextManager.current_context().execution_state
            if exec_state and exec_state.mode == ExecutionState.Mode.TASK_EXECUTION:
                """
                This mode indicates we are actually in a remote execute environment (within sagemaker in this case)
                """
                dist_ctx = DistributedTrainingContext.from_env()
            else:
                dist_ctx = DistributedTrainingContext.local_execute()
            return user_params.builder().add_attr(
                "DISTRIBUTED_TRAINING_CONTEXT", dist_ctx).build()

        return user_params
Example #15
def register(
    registrable_entities: typing.List[RegistrableEntity],
    project: str,
    domain: str,
    version: str,
    client: SynchronousFlyteClient,
):
    # The incoming registrable entities are already in base protobuf form, not model form, so we use the
    # raw client's methods instead of the friendly client's methods by calling super
    for admin_entity in registrable_entities:
        try:
            if isinstance(admin_entity, _idl_admin_TaskSpec):
                ident, task_spec = hydrate_registration_parameters(
                    identifier_pb2.TASK, project, domain, version,
                    admin_entity)
                logger.debug(f"Creating task {ident}")
                super(SynchronousFlyteClient, client).create_task(
                    TaskCreateRequest(id=ident, spec=task_spec))
            elif isinstance(admin_entity, _idl_admin_WorkflowSpec):
                ident, wf_spec = hydrate_registration_parameters(
                    identifier_pb2.WORKFLOW, project, domain, version,
                    admin_entity)
                logger.debug(f"Creating workflow {ident}")
                super(SynchronousFlyteClient, client).create_workflow(
                    WorkflowCreateRequest(id=ident, spec=wf_spec))
            elif isinstance(admin_entity, _idl_admin_LaunchPlan):
                ident, admin_lp = hydrate_registration_parameters(
                    identifier_pb2.LAUNCH_PLAN, project, domain, version,
                    admin_entity)
                logger.debug(f"Creating launch plan {ident}")
                super(SynchronousFlyteClient, client).create_launch_plan(
                    LaunchPlanCreateRequest(id=ident, spec=admin_lp.spec))
            else:
                raise AssertionError(
                    f"Unknown entity of type {type(admin_entity)}")
        except FlyteEntityAlreadyExistsException:
            logger.info(f"{admin_entity} already exists")
        except Exception as e:
            logger.info(
                f"Failed to register entity {admin_entity} with error {e}")
            raise e
Example #16
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        By the time this function is invoked, the _local_execute function should have unwrapped the Promises and Flyte
        literal wrappers so that the kwargs we are working with here are now Python native literal values. This
        function is also expected to return Python native literal values.

        Since the user code within a dynamic task constitutes a workflow, we have to first compile the workflow, and
        then execute that workflow.

        When running for real in production, the task would stop after the compilation step, and then create a file
        representing that newly generated workflow, instead of executing it.
        """
        ctx = FlyteContext.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            with ctx.new_execution_context(ExecutionState.Mode.TASK_EXECUTION):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return task_function(**kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            return self.compile_into_workflow(ctx, task_function, **kwargs)
Example #17
    def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
        """
        In the case of distributed execution, we check the should_persist_predicate in the configuration to determine
        if the output should be persisted. This is because in distributed training, multiple nodes may produce partial
        outputs and only the user process knows the output that should be generated. They can control the choice using
        the predicate.

        To control if output is generated across every execution, we override the post_execute method and sometimes
        return a None
        """
        if self._is_distributed():
            logger.info("Distributed context detected!")
            dctx = flytekit.current_context().distributed_training_context
            if not self.task_config.should_persist_output(dctx):
                logger.info(
                    "output persistence predicate not met, Flytekit will ignore outputs"
                )
                raise IgnoreOutputs(
                    f"Distributed context - Persistence predicate not met. Ignoring outputs - {dctx}"
                )
        return rval
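
A hedged sketch of what such a persistence predicate could look like, assuming the distributed training context exposes current_host and hosts (the field names are assumptions here):

def persist_only_on_first_host(dctx) -> bool:
    # Hypothetical predicate: only the first worker keeps its outputs;
    # every other worker is skipped via the IgnoreOutputs path in post_execute above.
    return dctx.current_host == dctx.hosts[0]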
Example #18
File: task.py  Project: flyteorg/flytekit
    def execute(self, **kwargs) -> Any:
        """
        TODO: Figure out how to share FlyteContext ExecutionParameters with the notebook kernel (as notebook kernel
             is executed in a separate python process)
        For Spark, the notebooks today need to use the new_session or just getOrCreate session and get a handle to the
        singleton
        """
        logger.info(
            f"Hijacking the call for task-type {self.task_type}, to call notebook."
        )
        # Execute Notebook via Papermill.
        pm.execute_notebook(self._notebook_path,
                            self.output_notebook_path,
                            parameters=kwargs)  # type: ignore

        outputs = self.extract_outputs(self.output_notebook_path)
        self.render_nb_html(self.output_notebook_path,
                            self.rendered_output_path)

        m = {}
        if outputs:
            m = outputs.literals
        output_list = []
        for k, type_v in self.python_interface.outputs.items():
            if k == self._IMPLICIT_OP_NOTEBOOK:
                output_list.append(self.output_notebook_path)
            elif k == self._IMPLICIT_RENDERED_NOTEBOOK:
                output_list.append(self.rendered_output_path)
            elif k in m:
                v = TypeEngine.to_python_value(
                    ctx=FlyteContext.current_context(),
                    lv=m[k],
                    expected_python_type=type_v)
                output_list.append(v)
            else:
                raise RuntimeError(
                    f"Expected output {k} of type {v} not found in the notebook outputs"
                )

        return tuple(output_list)
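
A hedged sketch of declaring a papermill-backed notebook task that this execute method would drive; the paths and input/output names are placeholders, and the two implicit outputs (the executed and rendered notebooks) follow the plugin's documented pattern:

from flytekit import kwtypes
from flytekitplugins.papermill import NotebookTask

nb_task = NotebookTask(
    name="my_notebook_task",
    notebook_path="./analysis.ipynb",
    inputs=kwtypes(n=int),
    outputs=kwtypes(mean=float),
)
# Calling nb_task returns the declared outputs plus the executed and rendered notebooks.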
Example #19
File: file.py  Project: flyteorg/flytekit
def get_config_file(c: typing.Union[str, ConfigFile, None]) -> typing.Optional[ConfigFile]:
    """
    Checks if the given argument is a file path or a ConfigFile and returns a loaded ConfigFile, else returns None
    """
    if c is None:
        # See if there's a config file in the current directory where Python is being run from
        current_location_config = Path("flytekit.config")
        if current_location_config.exists():
            logger.info(f"Using configuration from Python process root {current_location_config.absolute()}")
            return ConfigFile(str(current_location_config.absolute()))

        # If not, see if there's a config in the user's home directory
        home_dir_config = Path(Path.home(), ".flyte", "config")  # _default_config_file_name in main.py
        if home_dir_config.exists():
            logger.info(f"Using configuration from home directory {home_dir_config.absolute()}")
            return ConfigFile(str(home_dir_config.absolute()))

        # If not, see if the env var that flytectl sandbox tells the user to set is set,
        # or see if there's something in the default home directory location
        flytectl_path = Path(Path.home(), ".flyte", "config.yaml")
        flytectl_path_from_env = getenv(FLYTECTL_CONFIG_ENV_VAR, None)

        if flytectl_path_from_env:
            flytectl_path = Path(flytectl_path_from_env)
        if flytectl_path.exists():
            logger.info(f"Using flytectl/YAML config {flytectl_path.absolute()}")
            return ConfigFile(str(flytectl_path.absolute()))

        # If not, then return None and let caller handle
        return None
    if isinstance(c, str):
        logger.debug(f"Using specified config file at {c}")
        return ConfigFile(c)
    return c
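
Hypothetical calls showing the three accepted argument types (the paths are placeholders):

cfg = get_config_file(None)                    # search the default locations described above
cfg = get_config_file("/etc/flyte/config")     # an explicit path is wrapped in a ConfigFile
cfg = get_config_file(ConfigFile("/etc/flyte/config"))  # an existing ConfigFile is returned as-is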
Example #20
def _update_cmd_config_and_execute(s3_cfg: S3Config, cmd: List[str]):
    env = _os.environ.copy()

    if s3_cfg.enable_debug:
        cmd.insert(1, "--debug")

    if s3_cfg.endpoint is not None:
        cmd.insert(1, s3_cfg.endpoint)
        cmd.insert(1, "--endpoint-url")

    if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ:
        if s3_cfg.access_key_id:
            env[S3_ACCESS_KEY_ID_ENV_NAME] = s3_cfg.access_key_id

    if S3_SECRET_ACCESS_KEY_ENV_NAME not in os.environ:
        if s3_cfg.secret_access_key:
            env[S3_SECRET_ACCESS_KEY_ENV_NAME] = s3_cfg.secret_access_key

    retry = 0
    while True:
        try:
            try:
                return subprocess.check_call(cmd, env=env)
            except Exception as e:
                if retry > 0:
                    logger.info(
                        f"AWS command failed with error {e}, command: {cmd}, retry {retry}"
                    )

            logger.debug(
                f"Appending anonymous flag and retrying command {cmd}")
            anonymous_cmd = cmd[:]  # strings only, so this is deep enough
            anonymous_cmd.insert(1, S3_ANONYMOUS_FLAG)
            return subprocess.check_call(anonymous_cmd, env=env)

        except Exception as e:
            logger.error(
                f"Exception when trying to execute {cmd}, reason: {str(e)}")
            retry += 1
            if retry > s3_cfg.retries:
                raise
            secs = s3_cfg.backoff
            logger.info(
                f"Sleeping before retrying again, after {secs.total_seconds()} seconds"
            )
            time.sleep(secs.total_seconds())
            logger.info("Retrying again")
Example #21
"""
Flytekit PyTorch
=========================================
.. currentmodule:: flytekit.extras.pytorch

.. autosummary::
   :template: custom.rst
   :toctree: generated/

    PyTorchCheckpoint
"""
from flytekit.loggers import logger

try:
    from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer
    from .native import PyTorchModuleTransformer, PyTorchTensorTransformer
except ImportError:
    logger.info(
        "We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed."
    )
Example #22
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        This function is largely similar to the base PythonTask, with the exception that we have to infer the Python
        interface before executing. Also, we refer to ``self.task_template`` rather than just ``self`` similar to task
        classes that derive from the base ``PythonTask``.
        """
        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with FlyteContextManager.with_context(
                ctx.with_execution_state(
                    ctx.execution_state.with_params(
                        user_space_params=new_user_params))) as exec_ctx:
            # Added: Have to reverse the Python interface from the task template Flyte interface
            # See docstring for more details.
            guessed_python_input_types = TypeEngine.guess_python_types(
                self.task_template.interface.inputs)
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, guessed_python_input_types)

            logger.info(
                f"Invoking FlyteTask executor {self.task_template.id.name} with inputs: {native_inputs}"
            )
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing {e}")
                raise e

            logger.debug("Task executed successfully in user level")
            # Let's run the post_execute method. This may result in an IgnoreOutputs exception, which is
            # bubbled up to be handled at the callee layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs,
                          _literal_models.LiteralMap) or isinstance(
                              native_outputs, _dynamic_job.DynamicJobSpec):
                return native_outputs

            expected_output_names = list(
                self.task_template.interface.outputs.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                # Deleted some stuff
                native_outputs_as_map = {
                    expected_output_names[0]: native_outputs
                }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                native_outputs_as_map = {
                    expected_output_names[i]: native_outputs[i]
                    for i, _ in enumerate(native_outputs)
                }

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self.task_template.interface.outputs[k].type
                py_type = type(v)

                if isinstance(v, tuple):
                    raise AssertionError(
                        f"Output({k}) in task{self.task_template.id.name} received a tuple {v}, instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    raise AssertionError(
                        f"failed to convert return value for var {k}") from e

            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
            # After the execute has been successfully completed
            return outputs_literal_map
Example #23
    def local_execute(self, ctx: FlyteContext,
                      **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
        """
        This function is used only in the local execution path and is responsible for calling dispatch execute.
        Use this function when calling a task with native values (or Promises containing Flyte literals derived from
        Python native values).
        """
        # Unwrap the kwargs values. After this, we essentially have a LiteralMap
        # The reason why we need to do this is because the inputs during local execute can be of 2 types
        #  - Promises or native constants
        #  Promises are essentially inputs from previous task executions
        #  native constants are just bound to this specific task (default values for a task input)
        #  Also along with promises and constants, there could be dictionary or list of promises or constants
        kwargs = translate_inputs_to_literals(
            ctx,
            incoming_values=kwargs,
            flyte_interface_types=self.interface.inputs,  # type: ignore
            native_types=self.get_input_types(),
        )
        input_literal_map = _literal_models.LiteralMap(literals=kwargs)

        # if metadata.cache is set, check memoized version
        if self.metadata.cache:
            # TODO: how to get a nice `native_inputs` here?
            logger.info(
                f"Checking cache for task named {self.name}, cache version {self.metadata.cache_version} "
                f"and inputs: {input_literal_map}")
            outputs_literal_map = LocalTaskCache.get(
                self.name, self.metadata.cache_version, input_literal_map)
            # The cache returns None iff the key does not exist in the cache
            if outputs_literal_map is None:
                logger.info("Cache miss, task will be executed now")
                outputs_literal_map = self.dispatch_execute(
                    ctx, input_literal_map)
                # TODO: need `native_inputs`
                LocalTaskCache.set(self.name, self.metadata.cache_version,
                                   input_literal_map, outputs_literal_map)
                logger.info(
                    f"Cache set for task named {self.name}, cache version {self.metadata.cache_version} "
                    f"and inputs: {input_literal_map}")
            else:
                logger.info("Cache hit")
        else:
            es = ctx.execution_state
            b = es.user_space_params.with_task_sandbox()
            ctx = ctx.current_context().with_execution_state(
                es.with_params(user_space_params=b.build())).build()
            outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
        outputs_literals = outputs_literal_map.literals

        # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
        #    location, otherwise we dont really need to right? The higher level execute could just handle literalMap
        # After running, we again have to wrap the outputs, if any, back into Promise objects
        output_names = list(self.interface.outputs.keys())  # type: ignore
        if len(output_names) != len(outputs_literals):
            # Length check, clean up exception
            raise AssertionError(
                f"Length difference {len(output_names)} {len(outputs_literals)}"
            )

        # Tasks that don't return anything still return a VoidPromise
        if len(output_names) == 0:
            return VoidPromise(self.name)

        vals = [Promise(var, outputs_literals[var]) for var in output_names]
        return create_task_output(vals, self.python_interface)
Example #24
    def execute(self, **kwargs) -> Any:
        context = ge.data_context.DataContext(
            self._context_root_dir)  # type: ignore

        if len(self.python_interface.inputs.keys()) != 1:
            raise TypeError(
                "Expected one input argument to validate the dataset")

        dataset_key = list(self.python_interface.inputs.keys())[0]
        dataset = kwargs[dataset_key]
        datatype = self.python_interface.inputs[dataset_key]

        if not issubclass(datatype, (FlyteFile, FlyteSchema, str)):
            raise TypeError(
                "'dataset' has to have FlyteFile/FlyteSchema/str datatype")

        # determine the type of data connector
        selected_datasource = list(
            filter(lambda x: x["name"] == self._datasource_name,
                   context.list_datasources()))

        if not selected_datasource:
            raise ValueError("Datasource doesn't exist!")

        data_connector_class_lookup = {
            data_connector_name: data_connector_class["class_name"]
            for data_connector_name, data_connector_class in
            selected_datasource[0]["data_connectors"].items()
        }

        specified_data_connector_class = data_connector_class_lookup[
            self._data_connector_name]

        is_runtime = False
        if specified_data_connector_class == "RuntimeDataConnector":
            is_runtime = True
            if not self._data_asset_name:
                raise ValueError(
                    "data_asset_name has to be given in a RuntimeBatchRequest")

        # FlyteFile
        if issubclass(datatype, FlyteFile):
            dataset = self._flyte_file(dataset)

        # FlyteSchema
        # convert schema to parquet file
        if issubclass(datatype, FlyteSchema) and not is_runtime:
            dataset = self._flyte_schema(dataset)

        # minimalistic batch request
        final_batch_request = {
            "data_asset_name":
            self._data_asset_name if is_runtime else dataset,
            "datasource_name": self._datasource_name,
            "data_connector_name": self._data_connector_name,
        }

        # Great Expectations' RuntimeBatchRequest
        if self._batch_request_config and (
                self._batch_request_config.runtime_parameters or is_runtime):
            final_batch_request.update({
                "runtime_parameters":
                self._batch_request_config.runtime_parameters
                if self._batch_request_config.runtime_parameters else {},
                "batch_identifiers":
                self._batch_request_config.batch_identifiers,
                "batch_spec_passthrough":
                self._batch_request_config.batch_spec_passthrough,
            })

            if is_runtime and issubclass(datatype, str):
                final_batch_request["runtime_parameters"]["query"] = dataset
            elif is_runtime and issubclass(datatype, FlyteSchema):
                final_batch_request["runtime_parameters"][
                    "batch_data"] = dataset.open().all()
            else:
                raise AssertionError(
                    "Can only use runtime_parameters for query(str)/schema data"
                )

        # Great Expectations' BatchRequest
        elif self._batch_request_config:
            final_batch_request.update({
                "data_connector_query":
                self._batch_request_config.data_connector_query,
                "batch_spec_passthrough":
                self._batch_request_config.batch_spec_passthrough,
            })

        if self._checkpoint_params:
            checkpoint = SimpleCheckpoint(
                f"_tmp_checkpoint_{self._expectation_suite_name}",
                context,
                **self._checkpoint_params,
            )
        else:
            checkpoint = SimpleCheckpoint(
                f"_tmp_checkpoint_{self._expectation_suite_name}",
                context,
            )

        # identify every run uniquely
        run_id = RunIdentifier(
            **{
                "run_name": self._datasource_name + "_run",
                "run_time": datetime.datetime.utcnow(),
            })

        checkpoint_result = checkpoint.run(
            run_id=run_id,
            validations=[{
                "batch_request":
                final_batch_request,
                "expectation_suite_name":
                self._expectation_suite_name,
            }],
        )
        final_result = convert_to_json_serializable(
            checkpoint_result.list_validation_results())[0]

        result_string = ""
        if final_result["success"] is False:
            for every_result in final_result["results"]:
                if every_result["success"] is False:
                    result_string += (
                        every_result["expectation_config"]["kwargs"]["column"]
                        + " -> " +
                        every_result["expectation_config"]["expectation_type"]
                        + "\n")

            # raise a Great Expectations' exception
            raise ValidationError(
                "Validation failed!\nCOLUMN\t\tFAILED EXPECTATION\n" +
                result_string)

        logger.info("Validation succeeded!")

        return final_result
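
A hedged sketch of how a task like the one above might be constructed, loosely following the flytekitplugins-great_expectations API; the datasource, suite, and connector names are placeholders:

from flytekit import kwtypes
from flytekitplugins.great_expectations import GreatExpectationsTask

ge_task = GreatExpectationsTask(
    name="validate_orders",
    datasource_name="data",
    inputs=kwtypes(dataset=str),
    expectation_suite_name="orders.demo",
    data_connector_name="data_example_data_connector",
)
# ge_task(dataset="orders.csv") runs the checkpoint and raises ValidationError on failure.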
Example #25
    def __init__(
        self,
        name: str,
        task_config: T,
        task_type="python-task",
        container_image: Optional[str] = None,
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        environment: Optional[Dict[str, str]] = None,
        task_resolver: Optional[TaskResolverMixin] = None,
        secret_requests: Optional[List[Secret]] = None,
        **kwargs,
    ):
        """
        :param name: unique name for the task, usually the function's module and name.
        :param task_config: Configuration object for Task. Should be a unique type for that specific Task
        :param task_type: String task type to be associated with this Task
        :param container_image: String FQN for the image.
        :param requests: custom resource request settings.
        :param limits: custom resource limit settings.
        :param environment: Environment variables you want the task to have when run.
        :param task_resolver: Custom resolver - will pick up the default resolver if empty, or the resolver set
          in the compilation context if one is set.
        :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets
            will be mounted based on the configuration in the Secret and available through the SecretManager,
            using the name of the secret as the group. Ideally the secret keys should also be semi-descriptive.
            The key values will be available at runtime, if the backend is configured to provide secrets and if
            secrets are available in the configured secrets store. Possible options for secret stores are
            `Vault <https://www.vaultproject.io/>`_, `Confidant <https://lyft.github.io/confidant/>`_,
            `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`_,
            `AWS Parameter Store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`_, etc.
        """
        sec_ctx = None
        if secret_requests:
            for s in secret_requests:
                if not isinstance(s, Secret):
                    raise AssertionError(
                        f"Secret {s} should be of type flytekit.Secret, received {type(s)}"
                    )
            sec_ctx = SecurityContext(secrets=secret_requests)
        super().__init__(
            task_type=task_type,
            name=name,
            task_config=task_config,
            security_ctx=sec_ctx,
            **kwargs,
        )
        self._container_image = container_image
        # TODO(katrogan): Implement resource overrides
        self._resources = ResourceSpec(
            requests=requests if requests else Resources(),
            limits=limits if limits else Resources())
        self._environment = environment

        compilation_state = FlyteContext.current_context().compilation_state
        if compilation_state and compilation_state.task_resolver:
            if task_resolver:
                logger.info(
                    f"Not using the passed in task resolver {task_resolver} because one found in compilation context"
                )
            self._task_resolver = compilation_state.task_resolver
            if self._task_resolver.task_name(self) is not None:
                self._name = self._task_resolver.task_name(self)
        else:
            self._task_resolver = task_resolver or default_task_resolver
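
A short sketch of passing secret_requests when defining a task; the group/key names are placeholders and assume the backend has a matching secret configured in its secrets store:

import flytekit
from flytekit import Secret, task

@task(secret_requests=[Secret(group="user-info", key="user_secret")])
def read_secret() -> str:
    return flytekit.current_context().secrets.get("user-info", "user_secret")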
Example #26
    def __enter__(self):
        logger.info("Entering timed context: {}".format(
            self._context_statement))
        self._start_wall_time = _time.perf_counter()
        self._start_process_time = _time.process_time()
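
Together with the __exit__ method from example #7, this forms a timing context manager. A hedged reconstruction of the owning class (the class name is an assumption) and its usage:

import time as _time

from flytekit.loggers import logger

class _TimedContext(object):  # hypothetical name for the class these methods belong to
    def __init__(self, context_statement):
        self._context_statement = context_statement

    def __enter__(self):
        logger.info("Entering timed context: {}".format(self._context_statement))
        self._start_wall_time = _time.perf_counter()
        self._start_process_time = _time.process_time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        end_wall_time = _time.perf_counter()
        end_process_time = _time.process_time()
        logger.info(
            "Exiting timed context: {} [Wall Time: {}s, Process Time: {}s]".format(
                self._context_statement,
                end_wall_time - self._start_wall_time,
                end_process_time - self._start_process_time,
            ))

# Usage: the wall time and process time spent inside the block are logged on exit.
with _TimedContext("copying inputs"):
    _time.sleep(0.1)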
Example #27
def get_protocol(path: typing.Optional[str] = None):
    if path:
        return DataPersistencePlugins.get_protocol(path)
    logger.info("Setting protocol to file")
    return "file"
Example #28
"""

from flytekit.configuration.internal import LocalSDK
from flytekit.loggers import logger

from .basic_dfs import (
    ArrowToParquetEncodingHandler,
    PandasToParquetEncodingHandler,
    ParquetToArrowDecodingHandler,
    ParquetToPandasDecodingHandler,
)
from .structured_dataset import (
    StructuredDataset,
    StructuredDatasetDecoder,
    StructuredDatasetEncoder,
    StructuredDatasetTransformerEngine,
)

try:
    from .bigquery import (
        ArrowToBQEncodingHandlers,
        BQToArrowDecodingHandler,
        BQToPandasDecodingHandler,
        PandasToBQEncodingHandlers,
    )
except ImportError:
    logger.info(
        "We won't register bigquery handler for structured dataset because "
        "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery"
    )
Example #29
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        This method translates Flyte's Type system based input values and invokes the actual call to the executor
        This method is also invoked during runtime.

        * ``VoidPromise`` is returned in the case when the task itself declares no outputs.
        * ``LiteralMap`` is returned when the task declares one or more outputs. Individual outputs
          may be None
        * ``DynamicJobSpec`` is returned when a dynamic workflow is executed
        """

        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with ctx.new_execution_context(
                mode=ctx.execution_state.mode,
                execution_params=new_user_params,
                working_dir=ctx.execution_state.working_dir,
        ) as exec_ctx:
            # TODO We could support default values here too - but not part of the plan right now
            # Translate the input literals to Python native
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, self.python_interface.inputs)

            # TODO: Logger should auto inject the current context information to indicate if the task is running within
            #   a workflow or a subworkflow etc
            logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing {e}")
                raise e

            logger.info(
                f"Task executed successfully in user level, outputs: {native_outputs}"
            )
            # Let's run the post_execute method. This may result in an IgnoreOutputs exception, which is
            # bubbled up to be handled at the callee layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs,
                          _literal_models.LiteralMap) or isinstance(
                              native_outputs, _dynamic_job.DynamicJobSpec):
                return native_outputs

            expected_output_names = list(self._outputs_interface.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                if self.python_interface.output_tuple_name and isinstance(
                        native_outputs, tuple):
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs[0]
                    }
                else:
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs
                    }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                native_outputs_as_map = {
                    expected_output_names[i]: native_outputs[i]
                    for i, _ in enumerate(native_outputs)
                }

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self._outputs_interface[k].type
                py_type = self.get_type_for_output_var(k, v)

                if isinstance(v, tuple):
                    raise AssertionError(
                        f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    raise AssertionError(
                        f"failed to convert return value for var {k}") from e

            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
            # After the execute has been successfully completed
            return outputs_literal_map
Example #30
    def to_python_value(
        self,
        ctx: FlyteContext,
        lv: Literal,
        expected_python_type: Type[GreatExpectationsType],
    ) -> GreatExpectationsType:
        if not (lv and lv.scalar and
                ((lv.scalar.primitive and lv.scalar.primitive.string_value)
                 or lv.scalar.schema or lv.scalar.blob
                 or lv.scalar.structured_dataset)):
            raise AssertionError(
                "Can only validate a literal string/FlyteFile/FlyteSchema value"
            )

        # fetch the configuration
        type_conf = GreatExpectationsTypeTransformer.get_config(
            expected_python_type)
        conf_dict = type_conf[1].to_dict()  # type: ignore

        ge_conf = GreatExpectationsFlyteConfig(**conf_dict)

        # fetch the data context
        context = ge.data_context.DataContext(
            ge_conf.context_root_dir)  # type: ignore

        # determine the type of data connector
        selected_datasource = list(
            filter(lambda x: x["name"] == ge_conf.datasource_name,
                   context.list_datasources()))

        if not selected_datasource:
            raise ValueError("Datasource doesn't exist!")

        data_connector_class_lookup = {
            data_connector_name: data_connector_class["class_name"]
            for data_connector_name, data_connector_class in
            selected_datasource[0]["data_connectors"].items()
        }

        specified_data_connector_class = data_connector_class_lookup[
            ge_conf.data_connector_name]

        is_runtime = False
        if specified_data_connector_class == "RuntimeDataConnector":
            is_runtime = True
            if not ge_conf.data_asset_name:
                raise ValueError(
                    "data_asset_name has to be given in a RuntimeBatchRequest")

        # file path for FlyteSchema and FlyteFile
        temp_dataset = ""

        # return value
        return_dataset = ""

        # FlyteSchema
        if lv.scalar.schema or lv.scalar.structured_dataset:
            return_dataset, temp_dataset = self._flyte_schema(
                is_runtime=is_runtime,
                ctx=ctx,
                ge_conf=ge_conf,
                lv=lv,
                expected_python_type=type_conf[0])

        # FlyteFile
        if lv.scalar.blob:
            return_dataset, temp_dataset = self._flyte_file(
                ctx=ctx,
                ge_conf=ge_conf,
                lv=lv,
                expected_python_type=type_conf[0])

        if lv.scalar.primitive:
            dataset = return_dataset = lv.scalar.primitive.string_value
        else:
            dataset = temp_dataset

        batch_request_conf = ge_conf.batch_request_config

        # minimalistic batch request
        final_batch_request = {
            "data_asset_name":
            ge_conf.data_asset_name if is_runtime else dataset,
            "datasource_name": ge_conf.datasource_name,
            "data_connector_name": ge_conf.data_connector_name,
        }

        # Great Expectations' RuntimeBatchRequest
        if batch_request_conf and (batch_request_conf["runtime_parameters"]
                                   or is_runtime):
            final_batch_request.update({
                "runtime_parameters":
                batch_request_conf["runtime_parameters"]
                if batch_request_conf["runtime_parameters"] else {},
                "batch_identifiers":
                batch_request_conf["batch_identifiers"],
                "batch_spec_passthrough":
                batch_request_conf["batch_spec_passthrough"],
            })

            if is_runtime and lv.scalar.primitive:
                final_batch_request["runtime_parameters"]["query"] = dataset
            elif is_runtime and (lv.scalar.schema
                                 or lv.scalar.structured_dataset):
                final_batch_request["runtime_parameters"][
                    "batch_data"] = return_dataset.open().all()
            else:
                raise AssertionError(
                    "Can only use runtime_parameters for query(str)/schema data"
                )

        # Great Expectations' BatchRequest
        elif batch_request_conf:
            final_batch_request.update({
                "data_connector_query":
                batch_request_conf["data_connector_query"],
                "batch_spec_passthrough":
                batch_request_conf["batch_spec_passthrough"],
            })

        if ge_conf.checkpoint_params:
            checkpoint = SimpleCheckpoint(
                f"_tmp_checkpoint_{ge_conf.expectation_suite_name}",
                context,
                **ge_conf.checkpoint_params,
            )
        else:
            checkpoint = SimpleCheckpoint(
                f"_tmp_checkpoint_{ge_conf.expectation_suite_name}", context)

        # identify every run uniquely
        run_id = RunIdentifier(
            **{
                "run_name": ge_conf.datasource_name + "_run",
                "run_time": datetime.datetime.utcnow(),
            })

        checkpoint_result = checkpoint.run(
            run_id=run_id,
            validations=[{
                "batch_request":
                final_batch_request,
                "expectation_suite_name":
                ge_conf.expectation_suite_name,
            }],
        )
        final_result = convert_to_json_serializable(
            checkpoint_result.list_validation_results())[0]

        result_string = ""
        if final_result["success"] is False:
            for every_result in final_result["results"]:
                if every_result["success"] is False:
                    result_string += (
                        every_result["expectation_config"]["kwargs"]["column"]
                        + " -> " +
                        every_result["expectation_config"]["expectation_type"]
                        + "\n")

            # raise a Great Expectations' exception
            raise ValidationError(
                "Validation failed!\nCOLUMN\t\tFAILED EXPECTATION\n" +
                result_string)

        logger.info("Validation succeeded!")

        return typing.cast(GreatExpectationsType, return_dataset)