Example 1
 def guess_python_type(cls, flyte_type: LiteralType) -> type:
     for _, transformer in cls._REGISTRY.items():
         try:
             return transformer.guess_python_type(flyte_type)
         except ValueError:
             logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")
     raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}")
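
A minimal usage sketch (the import path and the assert assume flytekit's simple scalar transformers; the round trip holds for basic types like int):

from flytekit.core.type_engine import TypeEngine

# Forward: Python type -> Flyte LiteralType; reverse: the registry scan shown above
lt = TypeEngine.to_literal_type(int)
py = TypeEngine.guess_python_type(lt)
assert py is int  # expected for simple scalar types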
Example 2
    def add_workflow_output(
        self, output_name: str, p: Union[Promise, List[Promise], Dict[str, Promise]], python_type: Optional[Type] = None
    ):
        """
        Add an output with the given name from the given node output.
        """
        if output_name in self._python_interface.outputs:
            raise FlyteValidationException(f"Output {output_name} already exists in workflow {self.name}")

        if python_type is None:
            if type(p) is list or type(p) is dict:
                raise FlyteValidationException(
                    f"If specifying a list or dict of Promises, you must specify the python_type for {output_name},"
                    f" starting with the container type (e.g. List[int])"
                )
            python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var]
            logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}")

        flyte_type = TypeEngine.to_literal_type(python_type=python_type)

        ctx = FlyteContext.current_context()
        if ctx.compilation_state is not None:
            raise Exception("Cannot add a workflow output while compilation is already in progress")
        with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
            b = binding_from_python_std(
                ctx, output_name, expected_literal_type=flyte_type, t_value=p, t_value_type=python_type
            )
            self._output_bindings.append(b)
            self._python_interface = self._python_interface.with_outputs(extra_outputs={output_name: python_type})
            self._interface = transform_interface_to_typed_interface(self._python_interface)
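
A hedged sketch of how this method is reached through flytekit's imperative workflow API (the task and names are illustrative, not from the source above):

from flytekit import Workflow, task

@task
def t1() -> int:
    return 42

wf = Workflow(name="example.imperative.wf")
node = wf.add_entity(t1)
# python_type is inferred from the Promise here; for a List/Dict of Promises it must be passed explicitly
wf.add_workflow_output("wf_out", node.outputs["o0"])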
Example 3
    def find_lhs(self) -> str:
        if self._lhs is not None:
            return self._lhs

        if self._instantiated_in is None or self._instantiated_in == "":
            raise _system_exceptions.FlyteSystemException(f"Object {self} does not have _instantiated_in set")

        logger.debug(f"Looking for LHS for {self} from {self._instantiated_in}")
        m = _importlib.import_module(self._instantiated_in)
        for k in dir(m):
            try:
                if getattr(m, k) is self:
                    logger.debug(f"Found LHS for {self}, {k}")
                    self._lhs = k
                    return k
            except ValueError as err:
                # Empty pandas dataframes behave weirdly here such that calling `m.df` raises:
                # ValueError: The truth value of a {type(self).__name__} is ambiguous. Use a.empty, a.bool(), a.item(),
                #   a.any() or a.all()
                # Since dataframes aren't registrable entities to begin with we swallow any errors they raise and
                # continue looping through m.
                logger.warning(f"Caught ValueError {err} while attempting to auto-assign name")

        logger.error(f"Could not find LHS for {self} in {self._instantiated_in}")
        raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}")
Example 4
def get_config_file(c: typing.Union[str, ConfigFile, None]) -> typing.Optional[ConfigFile]:
    """
    Checks whether the given argument is a file path or a ConfigFile; returns a loaded ConfigFile, else None.
    """
    if c is None:
        # See if there's a config file in the current directory where Python is being run from
        current_location_config = Path("flytekit.config")
        if current_location_config.exists():
            logger.info(f"Using configuration from Python process root {current_location_config.absolute()}")
            return ConfigFile(str(current_location_config.absolute()))

        # If not, see if there's a config in the user's home directory
        home_dir_config = Path(Path.home(), ".flyte", "config")  # _default_config_file_name in main.py
        if home_dir_config.exists():
            logger.info(f"Using configuration from home directory {home_dir_config.absolute()}")
            return ConfigFile(str(home_dir_config.absolute()))

        # If not, see if the env var that flytectl sandbox tells the user to set is set,
        # or see if there's something in the default home directory location
        flytectl_path = Path(Path.home(), ".flyte", "config.yaml")
        flytectl_path_from_env = getenv(FLYTECTL_CONFIG_ENV_VAR, None)

        if flytectl_path_from_env:
            flytectl_path = Path(flytectl_path_from_env)
        if flytectl_path.exists():
            logger.info(f"Using flytectl/YAML config {flytectl_path.absolute()}")
            return ConfigFile(str(flytectl_path.absolute()))

        # If not, then return None and let caller handle
        return None
    if isinstance(c, str):
        logger.debug(f"Using specified config file at {c}")
        return ConfigFile(c)
    return c
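
The lookup order above, summarized as call sites (paths are illustrative):

cfg = get_config_file(None)
# 1. ./flytekit.config                      (current working directory)
# 2. ~/.flyte/config                        (home directory)
# 3. the env var named by FLYTECTL_CONFIG_ENV_VAR if set, else ~/.flyte/config.yaml
cfg = get_config_file("/tmp/my.config")     # a string path is loaded directly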
Example 5
    def register(cls,
                 h: Handlers,
                 default_for_type: Optional[bool] = True,
                 override: Optional[bool] = False):
        """
        Call this with any handler to register it with this dataframe meta-transformer

        The string "://" should not be present in any handler's protocol, so we don't check for it.
        """
        lowest_level = cls._handler_finder(h)
        if h.supported_format in lowest_level and not override:
            raise ValueError(
                f"Already registered a handler for {(h.python_type, h.protocol, h.supported_format)}"
            )
        lowest_level[h.supported_format] = h
        logger.debug(
            f"Registered {h} as handler for {h.python_type}, protocol {h.protocol}, fmt {h.supported_format}"
        )

        if default_for_type:
            # TODO: Add logging, think about better ux, maybe default False and warn if doesn't exist.
            cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
            cls.DEFAULT_PROTOCOLS[h.python_type] = h.protocol

        # Register with the type engine as well
        # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as
        # long as the older Pandas/FlyteSchema transformer do not also specify the override
        engine = StructuredDatasetTransformerEngine()
        TypeEngine.register_additional_type(engine,
                                            h.python_type,
                                            override=True)
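
A hedged registration sketch; MyParquetHandler is a hypothetical handler exposing the python_type, protocol, and supported_format attributes that register() reads:

# First registration for this (type, protocol, format) triple succeeds
StructuredDatasetTransformerEngine.register(MyParquetHandler(), default_for_type=True)
# Re-registering the same triple raises ValueError unless override=True
StructuredDatasetTransformerEngine.register(MyParquetHandler(), override=True)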
Example 6
    def __call__(self, *args, **kwargs):
        # When a Task is called, aka __call__'ed, there are three things we may do:
        #  a. Plain execution Mode - just run the execute function. If not overridden, we should raise an exception
        #  b. Compilation Mode - this happens when the function is called as part of a workflow (potentially
        #     dynamic task). Produce promise objects and create a node.
        #  c. Workflow Execution Mode - when a workflow is being run locally. Even though workflows are functions
        #     and everything should be able to be passed through naturally, we'll want to wrap output values of the
        #     function into objects, so that potential .with_cpu or other ancillary functions can be attached to do
        #     nothing. Subsequent tasks will have to know how to unwrap these. If by chance a non-Flyte task uses a
        #     task output as an input, things probably will fail pretty obviously.
        #     Since this is a reference entity, it still needs to be mocked otherwise an exception will be raised.
        if len(args) > 0:
            raise _user_exceptions.FlyteAssertion(
                f"Cannot call reference entity with args - detected {len(args)} positional args {args}"
            )

        ctx = FlyteContext.current_context()
        if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
            return self.compile(ctx, *args, **kwargs)
        elif (
            ctx.execution_state is not None
            and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
        ):
            if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
                return
            return self._local_execute(ctx, **kwargs)
        else:
            logger.debug("Reference entity - running raw execute")
            return self.execute(**kwargs)
Example 7
 def _get_from_yaml(self, c: YamlConfigEntry) -> typing.Any:
     keys = c.switch.split(".")  # flytectl switches are dot delimited
     d = self.yaml_config
     try:
         for k in keys:
             d = d[k]
         return d
     except KeyError:
         logger.debug(f"Switch {c.switch} could not be found in yaml config")
         return None
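
The dot-delimited traversal in isolation (plain dicts, no flytekit involved):

yaml_config = {"admin": {"endpoint": "dns:///localhost:30080"}}
d = yaml_config
for k in "admin.endpoint".split("."):  # flytectl switches are dot delimited
    d = d[k]
assert d == "dns:///localhost:30080"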
Example 8
 def pop_context() -> FlyteContext:
     context_list = flyte_context_Var.get()
     ctx = context_list.pop()
     flyte_context_Var.set(context_list)
     t = "\t"
     logger.debug(
         f"{t * ctx.level}[{len(flyte_context_Var.get()) + 1}] Popping context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}"
     )
     if len(flyte_context_Var.get()) == 0:
          raise AssertionError(f"Illegal context state! Popped the last remaining context: {ctx}")
     return ctx
Example 9
 def exists(self, path: str) -> bool:
     try:
         fs = self.get_filesystem(path)
         return fs.exists(path)
     except OSError as oe:
         logger.debug(f"Error in exists checking {path} {oe}")
         fs = self.get_anonymous_filesystem(path)
         if fs is not None:
              logger.debug("S3 source detected, attempting anonymous S3 exists check")
              return fs.exists(path)
          raise
Example 10
def read_file_if_exists(filename: typing.Optional[str], encoding=None) -> typing.Optional[str]:
    """
    Reads the contents of the file if passed a path. Otherwise, returns None.

    :param filename: The file path to load
    :param encoding: The encoding to use when reading the file.
    :return: The contents of the file as a string or None.
    """
    if not filename:
        return None

    filename = pathlib.Path(filename)
    logger.debug(f"Reading file contents from [{filename}] with current directory [{os.getcwd()}].")
    return filename.read_text(encoding=encoding)
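
A small usage sketch; note that despite the name, a non-empty path that does not exist raises FileNotFoundError from read_text rather than returning None:

assert read_file_if_exists(None) is None          # falsy filename short-circuits
text = read_file_if_exists("notes.txt")           # raises if notes.txt is missing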
Example 11
 def push_context(
         ctx: FlyteContext,
         f: Optional[traceback.FrameSummary] = None) -> FlyteContext:
     if not f:
         f = FlyteContextManager.get_origin_stackframe(limit=2)
     ctx.set_stackframe(f)
     context_list = flyte_context_Var.get()
     context_list.append(ctx)
     flyte_context_Var.set(context_list)
     t = "\t"
     logger.debug(
         f"{t * ctx.level}[{len(flyte_context_Var.get())}] Pushing context - {'compile' if ctx.compilation_state else 'execute'}, branch[{ctx.in_a_condition}], {ctx.get_origin_stackframe_repr()}"
     )
     return ctx
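
A hedged sketch of pairing push_context with pop_context from Example 8; the builder-style construction is assumed, and with_context in the other examples wraps exactly this push/pop pattern:

ctx = FlyteContextManager.current_context()
child = ctx.with_compilation_state(CompilationState(prefix="")).build()  # assumed builder API
FlyteContextManager.push_context(child)
try:
    ...  # code that should observe `child` as the current context
finally:
    FlyteContextManager.pop_context()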
Example 12
 def get(self, from_path: str, to_path: str, recursive: bool = False):
     if recursive:
         raise user.FlyteAssertion("Reading data recursively from HTTP endpoint is not currently supported.")
     rsp = requests.get(from_path)
     if rsp.status_code != self._HTTP_OK:
          raise user.FlyteValueException(
              rsp.status_code,
              f"Request for data @ {from_path} failed. Expected status code {self._HTTP_OK}",
          )
     head, _ = os.path.split(to_path)
     if head and head.startswith("/"):
         logger.debug(f"HttpPersistence creating {head} so that parent dirs exist")
         pathlib.Path(head).mkdir(parents=True, exist_ok=True)
     with open(to_path, "wb") as writer:
         writer.write(rsp.content)
Example 13
 def get(self, from_path: str, to_path: str, recursive: bool = False):
     fs = self.get_filesystem(from_path)
     if recursive:
         from_path, to_path = self.recursive_paths(from_path, to_path)
     try:
         return fs.get(from_path, to_path, recursive=recursive)
      except OSError as oe:
          logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
          fs = self.get_anonymous_filesystem(from_path)
          if fs is not None:
              logger.debug("S3 source detected, attempting anonymous S3 access")
              return fs.get(from_path, to_path, recursive=recursive)
          raise
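
The anonymous-access fallback shared by exists() in Example 9 and get() above distills to this shape (all three arguments are hypothetical callables standing in for the filesystem calls):

def with_anonymous_fallback(primary, anonymous_available, anonymous):
    # primary, anonymous_available, anonymous are hypothetical placeholders
    try:
        return primary()
    except OSError:
        if anonymous_available():
            return anonymous()
        raise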
Example 14
def _update_cmd_config_and_execute(s3_cfg: S3Config, cmd: List[str]):
    env = os.environ.copy()

    if s3_cfg.enable_debug:
        cmd.insert(1, "--debug")

    if s3_cfg.endpoint is not None:
        cmd.insert(1, s3_cfg.endpoint)
        cmd.insert(1, "--endpoint-url")

    if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ:
        if s3_cfg.access_key_id:
            env[S3_ACCESS_KEY_ID_ENV_NAME] = s3_cfg.access_key_id

    if S3_SECRET_ACCESS_KEY_ENV_NAME not in os.environ:
        if s3_cfg.secret_access_key:
            env[S3_SECRET_ACCESS_KEY_ENV_NAME] = s3_cfg.secret_access_key

    retry = 0
    while True:
        try:
            try:
                return subprocess.check_call(cmd, env=env)
            except Exception as e:
                if retry > 0:
                    logger.info(
                        f"AWS command failed with error {e}, command: {cmd}, retry {retry}"
                    )

            logger.debug(
                f"Appending anonymous flag and retrying command {cmd}")
            anonymous_cmd = cmd[:]  # strings only, so this is deep enough
            anonymous_cmd.insert(1, S3_ANONYMOUS_FLAG)
            return subprocess.check_call(anonymous_cmd, env=env)

        except Exception as e:
            logger.error(
                f"Exception when trying to execute {cmd}, reason: {str(e)}")
            retry += 1
            if retry > s3_cfg.retries:
                raise
            secs = s3_cfg.backoff
            logger.info(
                f"Sleeping before retrying again, after {secs.total_seconds()} seconds"
            )
            time.sleep(secs.total_seconds())
            logger.info("Retrying again")
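
The control flow above, distilled; run_cmd, run_anonymous, max_retries, and backoff_seconds are hypothetical placeholders. A failed authenticated call falls through to one anonymous attempt, and only a failed anonymous attempt counts against the retry budget:

import time

def run_with_anonymous_retry(run_cmd, run_anonymous, max_retries, backoff_seconds):
    # All four parameters are hypothetical stand-ins for the AWS CLI specifics above
    retry = 0
    while True:
        try:
            try:
                return run_cmd()        # authenticated attempt; returns on success
            except Exception:
                pass                    # swallow and fall through to the anonymous attempt
            return run_anonymous()      # anonymous attempt; returns on success
        except Exception:
            retry += 1
            if retry > max_retries:
                raise
            time.sleep(backoff_seconds)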
Example 15
    def unwrap_literal_map_and_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> _literal_models.LiteralMap:
        """
        Please see the implementation of the dispatch_execute function in the real task.
        """

        # Invoked before the task is executed
        # Translate the input literals to Python native
        native_inputs = TypeEngine.literal_map_to_kwargs(
            ctx, input_literal_map, self.python_interface.inputs)

        logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
        try:
            native_outputs = self.execute(**native_inputs)
        except Exception as e:
            logger.exception(f"Exception when executing {self.name}: {e}")
            raise
        logger.debug("Task executed successfully at user level")

        expected_output_names = list(self.python_interface.outputs.keys())
        if len(expected_output_names) == 1:
            native_outputs_as_map = {expected_output_names[0]: native_outputs}
        elif len(expected_output_names) == 0:
            native_outputs_as_map = {}
        else:
            native_outputs_as_map = {
                expected_output_names[i]: native_outputs[i]
                for i, _ in enumerate(native_outputs)
            }

        # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
        # built into the IDL that all the values of a literal map are of the same type.
        literals = {}
        for k, v in native_outputs_as_map.items():
            literal_type = self.interface.outputs[k].type
            py_type = self.python_interface.outputs[k]
            if isinstance(v, tuple):
                raise AssertionError(
                    f"Output({k}) in task {self.name} received a tuple {v}, instead of {py_type}"
                )
            literals[k] = TypeEngine.to_literal(ctx, v, py_type, literal_type)
        outputs_literal_map = _literal_models.LiteralMap(literals=literals)
        # After the execute has been successfully completed
        return outputs_literal_map
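
The output-naming convention used in the mapping above, in isolation (flytekit's default names follow the o0, o1, ... scheme):

expected_output_names = ["o0", "o1"]
native_outputs = (3, "x")
as_map = {expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)}
assert as_map == {"o0": 3, "o1": "x"}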
Example 16
def load_packages_and_modules(
    ss: SerializationSettings,
    project_root: Path,
    pkgs_or_mods: typing.List[str],
    options: typing.Optional[Options] = None,
) -> typing.List[RegistrableEntity]:
    """
    The project root is added as the first entry to sys.path, and then all the specified packages and modules
    are loaded along with all their submodules. The reason for prepending the entry is to ensure that the
    modules are loaded under their fully-resolved names.

    For example, using flytesnacks cookbook, if you are in core/ and you call this function with
    ``flyte_basics/hello_world.py control_flow/``, the ``hello_world`` module would be loaded
    as ``core.flyte_basics.hello_world`` even though you're already in the core/ folder.

    :param ss:
    :param project_root:
    :param pkgs_or_mods:
    :param options:
    :return: The list of registrable entities found in the given packages and modules
    """

    pkgs_and_modules = []
    for pm in pkgs_or_mods:
        p = Path(pm).resolve()
        rel_path_from_root = p.relative_to(project_root)
        # One day we should learn how to do this right. This is not the right way to load a python module
        # from a file. See pydoc.importfile for inspiration
        dot_delineated = os.path.splitext(rel_path_from_root)[0].replace(os.path.sep, ".")  # noqa

        logger.debug(
            f"User-specified arg {pm} resolved to relative path {rel_path_from_root}; loading it as {dot_delineated}"
        )
        pkgs_and_modules.append(dot_delineated)

    registrable_entities = serialize(pkgs_and_modules, ss, str(project_root), options)

    return registrable_entities
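
The path-to-module conversion in isolation (pure stdlib; the assert assumes a POSIX path separator):

import os

rel = "flyte_basics/hello_world.py"
dotted = os.path.splitext(rel)[0].replace(os.path.sep, ".")
assert dotted == "flyte_basics.hello_world"  # holds where os.path.sep == "/"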
Example 17
def find_common_root(
    pkgs_or_mods: typing.Union[typing.Tuple[str], typing.List[str]],
) -> Path:
    """
    Given an arbitrary list of folders and files, this function will use the script mode function to walk up
    the filesystem to find the first folder without an init file. If all the folders and files resolve to
    the same root folder, then that Path is returned. Otherwise an error is raised.

    :param pkgs_or_mods:
    :return: The common detected root path, the output of _find_project_root
    """
    project_root = None
    for pm in pkgs_or_mods:
        root = _find_project_root(pm)
        if project_root is None:
            project_root = root
        elif project_root != root:
            raise ValueError(
                f"Specified module {pm} has root {root} but {project_root} already specified"
            )

    logger.debug(f"Common root folder detected as {str(project_root)}")

    return project_root
Example 18
def register(
    registrable_entities: typing.List[RegistrableEntity],
    project: str,
    domain: str,
    version: str,
    client: SynchronousFlyteClient,
):
    # The incoming registrable entities are already in base protobuf form, not model form, so we use the
    # raw client's methods instead of the friendly client's methods by calling super
    for admin_entity in registrable_entities:
        try:
            if isinstance(admin_entity, _idl_admin_TaskSpec):
                ident, task_spec = hydrate_registration_parameters(
                    identifier_pb2.TASK, project, domain, version,
                    admin_entity)
                logger.debug(f"Creating task {ident}")
                super(SynchronousFlyteClient, client).create_task(
                    TaskCreateRequest(id=ident, spec=task_spec))
            elif isinstance(admin_entity, _idl_admin_WorkflowSpec):
                ident, wf_spec = hydrate_registration_parameters(
                    identifier_pb2.WORKFLOW, project, domain, version,
                    admin_entity)
                logger.debug(f"Creating workflow {ident}")
                super(SynchronousFlyteClient, client).create_workflow(
                    WorkflowCreateRequest(id=ident, spec=wf_spec))
            elif isinstance(admin_entity, _idl_admin_LaunchPlan):
                ident, admin_lp = hydrate_registration_parameters(
                    identifier_pb2.LAUNCH_PLAN, project, domain, version,
                    admin_entity)
                logger.debug(f"Creating launch plan {ident}")
                super(SynchronousFlyteClient, client).create_launch_plan(
                    LaunchPlanCreateRequest(id=ident, spec=admin_lp.spec))
            else:
                raise AssertionError(
                    f"Unknown entity of type {type(admin_entity)}")
        except FlyteEntityAlreadyExistsException:
            logger.info(f"{admin_entity} already exists")
        except Exception as e:
            logger.error(f"Failed to register entity {admin_entity} with error {e}")
            raise
Example 19
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        This function is largely similar to the base PythonTask, with the exception that we have to infer the Python
        interface before executing. Also, we refer to ``self.task_template`` rather than just ``self`` similar to task
        classes that derive from the base ``PythonTask``.
        """
        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with FlyteContextManager.with_context(
                ctx.with_execution_state(
                    ctx.execution_state.with_params(
                        user_space_params=new_user_params))) as exec_ctx:
            # Added: Have to reverse the Python interface from the task template Flyte interface
            # See docstring for more details.
            guessed_python_input_types = TypeEngine.guess_python_types(
                self.task_template.interface.inputs)
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, guessed_python_input_types)

            logger.info(
                f"Invoking FlyteTask executor {self.task_template.id.name} with inputs: {native_inputs}"
            )
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing {self.task_template.id.name}: {e}")
                raise

            logger.debug("Task executed successfully at user level")
            # Let's run the post_execute method. This may result in an IgnoreOutputs exception, which is
            # bubbled up to be handled at the callee layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)):
                return native_outputs

            expected_output_names = list(
                self.task_template.interface.outputs.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                # (The NamedTuple-specific handling present in the full implementation is omitted here.)
                native_outputs_as_map = {
                    expected_output_names[0]: native_outputs
                }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                native_outputs_as_map = {
                    expected_output_names[i]: native_outputs[i]
                    for i, _ in enumerate(native_outputs)
                }

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self.task_template.interface.outputs[k].type
                py_type = type(v)

                if isinstance(v, tuple):
                    raise AssertionError(
                        f"Output({k}) in task {self.task_template.id.name} received a tuple {v}, instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    raise AssertionError(
                        f"Failed to convert return value for var {k}") from e

            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
            # After the execute has been successfully completed
            return outputs_literal_map
Example 20
    def compile(self, **kwargs):
        """
        Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
        a 'closure' in the traditional sense of the word.
        """
        ctx = FlyteContextManager.current_context()
        self._input_parameters = transform_inputs_to_parameters(
            ctx, self.python_interface)
        all_nodes = []
        prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else ""

        with FlyteContextManager.with_context(
                ctx.with_compilation_state(
                    CompilationState(prefix=prefix,
                                     task_resolver=self))) as comp_ctx:
            # Construct the default input promise bindings, but then override with the provided inputs, if any
            input_kwargs = construct_input_promises(
                [k for k in self.interface.inputs.keys()])
            input_kwargs.update(kwargs)
            workflow_outputs = exception_scopes.user_entry_point(
                self._workflow_function)(**input_kwargs)
            all_nodes.extend(comp_ctx.compilation_state.nodes)

            # This little loop was added as part of the task resolver change. The task resolver interface itself is
            # more or less stateless (the future-proofing get_all_tasks function notwithstanding). However the
            # implementation of the TaskResolverMixin that this workflow class inherits from (ClassStorageTaskResolver)
            # does store state. This loop adds Tasks that are defined within the body of the workflow to the workflow
            # object itself.
            for n in comp_ctx.compilation_state.nodes:
                if isinstance(n.flyte_entity, PythonAutoContainerTask) and n.flyte_entity.task_resolver == self:
                    logger.debug(f"WF {self.name} saving task {n.flyte_entity.name}")
                    self.add(n.flyte_entity)

        # Iterate through the workflow outputs
        bindings = []
        output_names = list(self.interface.outputs.keys())
        # The reason the length 1 case is separate is because the one output might be a list. We don't want to
        # iterate through the list here, instead we should let the binding creation unwrap it and make a binding
        # collection/map out of it.
        if len(output_names) == 1:
            if isinstance(workflow_outputs, tuple):
                if len(workflow_outputs) != 1:
                    raise AssertionError(
                        f"The Workflow specification indicates only one return value, received {len(workflow_outputs)}"
                    )
                if self.python_interface.output_tuple_name is None:
                    raise AssertionError(
                        "Outputs specification for Workflow does not define a tuple, but return value is a tuple"
                    )
                workflow_outputs = workflow_outputs[0]
            t = self.python_interface.outputs[output_names[0]]
            b = binding_from_python_std(
                ctx,
                output_names[0],
                self.interface.outputs[output_names[0]].type,
                workflow_outputs,
                t,
            )
            bindings.append(b)
        elif len(output_names) > 1:
            if not isinstance(workflow_outputs, tuple):
                raise AssertionError(
                    "The Workflow specification indicates multiple return values, received only one"
                )
            if len(output_names) != len(workflow_outputs):
                raise Exception(
                    f"Length mismatch: expected {len(output_names)} outputs but received {len(workflow_outputs)}"
                )
            for i, out in enumerate(output_names):
                if isinstance(workflow_outputs[i], ConditionalSection):
                    raise AssertionError(
                        "A Conditional block (if-else) should always end with an `else_()` clause"
                    )
                t = self.python_interface.outputs[out]
                b = binding_from_python_std(
                    ctx,
                    out,
                    self.interface.outputs[out].type,
                    workflow_outputs[i],
                    t,
                )
                bindings.append(b)

        # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
        self._nodes = all_nodes
        self._output_bindings = bindings
        if not output_names:
            return None
        if len(output_names) == 1:
            return bindings[0]
        return tuple(bindings)
Example 21
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        This method translates Flyte's Type system based input values and invokes the actual call to the executor
        This method is also invoked during runtime.

        * ``VoidPromise`` is returned in the case when the task itself declares no outputs.
        * ``LiteralMap`` is returned when the task declares one or more outputs. Individual outputs
          may be None.
        * ``DynamicJobSpec`` is returned when a dynamic workflow is executed.
        """

        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)
        from flytekit.deck.deck import _output_deck

        # Create another execution context with the new user params, but let's keep the same working dir
        with FlyteContextManager.with_context(
                ctx.with_execution_state(
                    ctx.execution_state.with_params(
                        user_space_params=new_user_params))
                # type: ignore
        ) as exec_ctx:
            # TODO We could support default values here too - but not part of the plan right now
            # Translate the input literals to Python native
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, self.python_interface.inputs)

            # TODO: Logger should auto inject the current context information to indicate if the task is running within
            #   a workflow or a subworkflow etc
            logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing {self.name}: {e}")
                raise

            logger.debug("Task executed successfully at user level")
            # Let's run the post_execute method. This may result in an IgnoreOutputs exception, which is
            # bubbled up to be handled at the callee layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)):
                return native_outputs

            expected_output_names = list(self._outputs_interface.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                if self.python_interface.output_tuple_name and isinstance(
                        native_outputs, tuple):
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs[0]
                    }
                else:
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs
                    }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                native_outputs_as_map = {
                    expected_output_names[i]: native_outputs[i]
                    for i, _ in enumerate(native_outputs)
                }

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self._outputs_interface[k].type
                py_type = self.get_type_for_output_var(k, v)

                if isinstance(v, tuple):
                    raise TypeError(
                        f"Output({k}) in task {self.name} received a tuple {v}, instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    logger.error(
                        f"Failed to convert return value for var {k} with error {type(e)}: {e}"
                    )
                    raise TypeError(
                        f"Failed to convert return value for var {k} for function {self.name} with error {type(e)}: {e}"
                    ) from e

            INPUT = "input"
            OUTPUT = "output"

            input_deck = Deck(INPUT)
            for k, v in native_inputs.items():
                input_deck.append(
                    TypeEngine.to_html(ctx, v,
                                       self.get_type_for_input_var(k, v)))

            output_deck = Deck(OUTPUT)
            for k, v in native_outputs_as_map.items():
                output_deck.append(
                    TypeEngine.to_html(ctx, v,
                                       self.get_type_for_output_var(k, v)))

            if _internal.Deck.DISABLE_DECK.read() is not True and self.disable_deck is False:
                _output_deck(self.name.split(".")[-1], new_user_params)

            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
            # After the execute has been successfully completed
            return outputs_literal_map
Example 22
def extract_return_annotation(
        return_annotation: Union[Type, Tuple]) -> Dict[str, Type]:
    """
    The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to
    name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple.

        # Option 1
        nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int)
        def t(a: int, b: str) -> nt1: ...

        # Option 2
        def t(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): ...

        # Option 3
        def t(a: int, b: str) -> typing.Tuple[int, str]: ...

        # Option 4
        def t(a: int, b: str) -> (int, str): ...

        # Option 5
        def t(a: int, b: str) -> str: ...

        # Option 6
        def t(a: int, b: str) -> None: ...

        # Options 7/8
        def t(a: int, b: str) -> List[int]: ...
        def t(a: int, b: str) -> Dict[str, int]: ...

    Note that Options 1 and 2 are identical, just syntactic sugar. In the NamedTuple case, we'll use the names in the
    definition. In all other cases, we'll automatically generate output names, indexed starting at 0.
    """

    # Handle Option 6
    # We can think about whether we should add a default output name with type None in the future.
    if return_annotation is None or return_annotation is inspect.Signature.empty:
        return {}

    # This statement results in true for typing.NamedTuple, single and void return types, so this
    # handles Options 1, 2. Even though a NamedTuple is multi-valued for us, it's a single value for Python
    if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar):
        # isinstance / issubclass does not work for NamedTuple.
        # Options 1 and 2
        if hasattr(return_annotation, "_field_types"):
            logger.debug(f"Task returns named tuple {return_annotation}")
            return return_annotation._field_types

    if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple:
        # Handle option 3
        logger.debug(f"Task returns unnamed typing.Tuple {return_annotation}")
        if len(return_annotation.__args__) == 1:
            raise FlyteValidationException(
                "Tuples should be used to indicate multiple return values, found only one return variable."
            )
        return OrderedDict(
            zip(list(output_name_generator(len(return_annotation.__args__))),
                return_annotation.__args__))
    elif isinstance(return_annotation, tuple):
        # Handle option 4
        if len(return_annotation) == 1:
            raise FlyteValidationException(
                "Please don't use a tuple if you're just returning one thing.")
        return OrderedDict(
            zip(list(output_name_generator(len(return_annotation))),
                return_annotation))

    else:
        # Handle all other single return types (Option 5)
        logger.debug(f"Task returns a single output {return_annotation}")
        return {default_output_name(): return_annotation}
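
A hedged check of the conventions in the docstring; the o0/o1 names assume flytekit's default output-name generator:

import typing

nt = typing.NamedTuple("NT1", x_str=str, y_int=int)
extract_return_annotation(nt)                      # -> {"x_str": str, "y_int": int}
extract_return_annotation(typing.Tuple[int, str])  # -> {"o0": int, "o1": str}
extract_return_annotation(str)                     # -> {"o0": str}
extract_return_annotation(None)                    # -> {}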