Exemplo n.º 1
0
    def __init__(
        self,
        local_path: typing.Optional[os.PathLike] = None,
        remote_path: typing.Optional[str] = None,
        supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE,
        downloader: typing.Optional[typing.Callable[[str, os.PathLike], None]] = None,
    ):
        """
        Create a schema handle backed by ``local_path`` and/or ``remote_path``.

        :param local_path: Local directory for the schema data. When omitted, a random local directory is requested
            from the current FlyteContext's ``file_access``.
        :param remote_path: Remote location of the schema; required in READ mode.
        :param supported_mode: Mode the schema is opened in (WRITE by default).
        :param downloader: Optional callable stored for later use (presumably invoked with a remote and a local
            path to fetch the data -- TODO confirm at the call site).
        :raises ValueError: if READ mode is requested without ``remote_path``, or if ``local_path`` is omitted and
            the current context has no ``file_access`` to create one.
        """
        if supported_mode == SchemaOpenMode.READ and remote_path is None:
            raise ValueError(
                "To create a FlyteSchema in read mode, remote_path is required"
            )

        if local_path is None:
            file_access = FlyteContext.current_context().file_access
            if file_access is None:
                if supported_mode == SchemaOpenMode.WRITE:
                    raise ValueError(
                        "To create a FlyteSchema in write mode, local_path is required"
                    )
                # Previously this fell through to an AttributeError on ``None``; fail with a clear message instead.
                raise ValueError(
                    "To create a FlyteSchema without a local_path, the current context must provide file access"
                )
            local_path = file_access.get_random_local_directory()

        self._local_path = local_path
        self._remote_path = remote_path
        self._supported_mode = supported_mode
        # This is a special attribute that indicates if the data was either downloaded or uploaded
        self._downloaded = False
        self._downloader = downloader
Exemplo n.º 2
0
def test_levels():
    """Child FlyteContexts override the fields they set and see the parent's value for the rest."""
    outer_client = SampleTestClass(value=1)
    with FlyteContext(flyte_client=outer_client) as outer_ctx:
        assert outer_ctx.flyte_client.value == 1

        # A child that supplies its own client sees that client.
        with FlyteContext(flyte_client=SampleTestClass(value=2)) as overriding_ctx:
            assert overriding_ctx.flyte_client.value == 2

        # A child that only sets compilation state still sees the outer client.
        with FlyteContext(compilation_state=CompilationState(prefix="")) as inheriting_ctx:
            assert inheriting_ctx.flyte_client.value == 1
Exemplo n.º 3
0
    def __call__(self, *args, **kwargs):
        """
        Dispatch a call on this reference entity according to the current FlyteContext.

        Only keyword arguments are accepted. Three situations are handled:
          a. Compilation mode (the call is part of a workflow, potentially a dynamic task): compile a node and
             return the promises produced by ``compile``.
          b. Local workflow execution: run through ``_local_execute``; if the enclosing branch was skipped, nothing
             runs and ``None`` is returned. Since this is a reference entity, it still needs to be mocked, otherwise
             an exception will be raised.
          c. Anything else: call ``execute`` directly with the keyword arguments.

        :raises FlyteAssertion: if any positional arguments are supplied.
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                f"Cannot call reference entity with args - detected {len(args)} positional args {args}"
            )

        ctx = FlyteContext.current_context()
        # NOTE(review): magic value 1 -- presumably the active compilation mode; confirm against CompilationState.
        compiling = ctx.compilation_state is not None and ctx.compilation_state.mode == 1
        if compiling:
            return self.compile(ctx, *args, **kwargs)

        state = ctx.execution_state
        if state is not None and state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            if state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
                return None
            return self._local_execute(ctx, **kwargs)

        logger.debug("Reference entity - running raw execute")
        return self.execute(**kwargs)
Exemplo n.º 4
0
    def create(
        cls,
        name: str,
        workflow: _annotated_workflow.Workflow,
        default_inputs: Dict[str, Any] = None,
        fixed_inputs: Dict[str, Any] = None,
        schedule: _schedule_model.Schedule = None,
        notifications: List[_common_models.Notification] = None,
        auth_role: _common_models.AuthRole = None,
    ) -> LaunchPlan:
        """
        Create a launch plan for ``workflow`` and register it in ``cls.CACHE`` under ``name``.

        :param name: Name of the launch plan; must not already be present in ``cls.CACHE``.
        :param workflow: The workflow this launch plan launches.
        :param default_inputs: Python native default values. They take precedence over the defaults declared in
            the workflow function's signature.
        :param fixed_inputs: Python native values that cannot change at launch time.
        :param schedule: Optional schedule, passed through to the constructor.
        :param notifications: Optional notifications, passed through to the constructor.
        :param auth_role: Optional auth role, passed through to the constructor.
        :return: The newly created (and cached) launch plan.
        :raises AssertionError: if a launch plan with the same name was already created.
        """
        # Fail fast: reject duplicate names before doing any translation work or constructing the plan.
        if name in cls.CACHE:
            raise AssertionError(
                f"Launch plan named {name} was already created! Make sure your names are unique."
            )

        ctx = FlyteContext.current_context()
        # Copy so the caller's dictionaries are never mutated by the merge below.
        default_inputs = dict(default_inputs or {})
        fixed_inputs = dict(fixed_inputs or {})
        # Default inputs come from two places, the original signature of the workflow function, and the default_inputs
        # argument to this function. We'll take the latter as having higher precedence.
        wf_signature_parameters = transform_inputs_to_parameters(
            ctx, workflow._native_interface)

        # Construct a new Interface object with just the default inputs given to get Parameters, maybe there's an
        # easier way to do this, think about it later.
        temp_inputs = {
            k: (workflow._native_interface.inputs[k], v)
            for k, v in default_inputs.items()
        }
        temp_interface = Interface(inputs=temp_inputs, outputs={})
        temp_signature = transform_inputs_to_parameters(ctx, temp_interface)
        wf_signature_parameters._parameters.update(temp_signature.parameters)

        # These are fixed inputs that cannot change at launch time. If the same argument is also in default inputs,
        # it'll be taken out from defaults in the LaunchPlan constructor
        fixed_literals = translate_inputs_to_literals(
            ctx,
            input_kwargs=fixed_inputs,
            interface=workflow.interface,
            native_input_types=workflow._native_interface.inputs,
        )
        fixed_lm = _literal_models.LiteralMap(literals=fixed_literals)

        lp = cls(
            name=name,
            workflow=workflow,
            parameters=wf_signature_parameters,
            fixed_inputs=fixed_lm,
            schedule=schedule,
            notifications=notifications,
            auth_role=auth_role,
        )

        # This is just a convenience - we'll need the fixed inputs LiteralMap for when serializing the Launch Plan out
        # to protobuf, but for local execution and such, why not save the original Python native values as well so
        # we don't have to reverse it back every time. Fixed inputs override defaults of the same name.
        lp._saved_inputs = {**default_inputs, **fixed_inputs}

        cls.CACHE[name] = lp
        return lp
Exemplo n.º 5
0
    def __call__(self, *args, **kwargs):
        """
        Invoke the workflow. Only keyword arguments are supported.

        Behaviour depends on the current FlyteContext:
          * Compilation (e.g. used as a sub-workflow): a node is created and linked, with the workflow's default
            inputs overridden by ``kwargs``.
          * Already inside a local workflow execution: execution simply continues via ``_local_execute``.
          * Otherwise: the workflow is run locally from Python native inputs and its outputs are converted back to
            Python native values -- ``None`` for no outputs, a single value for one output, a tuple otherwise.

        :raises AssertionError: if positional arguments are given.
        :raises ValueError: on unexpected keyword arguments, on promise inputs to a top-level local run, or when
            the produced outputs do not match the declared interface.
        """
        if len(args) > 0:
            raise AssertionError("Only Keyword Arguments are supported for Workflow executions")

        ctx = FlyteContext.current_context()

        # Handle subworkflows in compilation
        if ctx.compilation_state is not None:
            # Copy the defaults so the interface's own mapping is never mutated by this call.
            input_kwargs = dict(self._native_interface.default_inputs_as_kwargs)
            input_kwargs.update(kwargs)
            return create_and_link_node(ctx, entity=self, interface=self._native_interface, **input_kwargs)

        if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            # We are already in a local execution, just continue the execution context
            return self._local_execute(ctx, **kwargs)

        # When someone wants to run the workflow function locally. Assume that the inputs given are given as Python
        # native values. _local_execute will always translate Python native literals to Flyte literals, so no worries
        # there, but it'll return Promise objects.
        # Even though the _local_execute call generally expects inputs to be Promises, we don't have to do the
        # conversion here. Users may specify inputs as direct scalars, so _local_execute has its own
        # Promise-generating loop.
        for k, v in kwargs.items():
            if k not in self.interface.inputs:
                raise ValueError(f"Received unexpected keyword argument {k}")
            if isinstance(v, Promise):
                raise ValueError(f"Received a promise for a workflow call, when expecting a native value for {k}")

        # ``ctx`` is rebound to the local-execution context here; that rebound context is also what the output
        # conversion below uses.
        with ctx.new_execution_context(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) as ctx:
            result = self._local_execute(ctx, **kwargs)

        expected_outputs = len(self._native_interface.outputs)
        if expected_outputs == 0:
            if result is None or isinstance(result, VoidPromise):
                return None
            raise Exception(f"Workflow local execution expected 0 outputs but something received {result}")

        # Guard against ``None`` before calling ``len`` so a missing result gets the clear error below.
        outputs_match = result is not None and (expected_outputs == 1 or len(result) == expected_outputs)
        if not outputs_match:
            raise ValueError("expected outputs and actual outputs do not match")

        if isinstance(result, Promise):
            python_type = next(iter(self._native_interface.outputs.values()))
            return TypeEngine.to_python_value(ctx, result.val, python_type)

        promises = list(result)
        if not promises:
            raise ValueError("expected outputs and actual outputs do not match")
        # Validate every element before converting any of them.
        for promise in promises:
            if not isinstance(promise, Promise):
                raise Exception("should be promises")
        return tuple(
            TypeEngine.to_python_value(ctx, promise.val, self._native_interface.outputs[promise.var])
            for promise in promises
        )
Exemplo n.º 6
0
    def end_branch(self) -> Union[Condition, Promise]:
        """
        Close the branch that was most recently started; to be invoked after every branch has been visited.

        Local workflow execution:
            the branch is marked complete in the execution state. For the last case the conditional section is
            exited and the selected case's output promise is returned (start_branch always selects a case for the
            last branch). Otherwise the condition is returned so that further chaining can be done.
        Compilation:
            for the last case the conditional section is exited, a branch node is built and added to the
            compilation state, and its computed outputs are returned. Otherwise the condition is returned.

        :raises AssertionError: if the selected case resolved to neither an output promise nor an error, or if this
            is invoked outside of a workflow context.
        :raises ValueError: carrying the selected case's error when it has no output promise.
        """
        ctx = FlyteContext.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            ctx.execution_state.branch_complete()
            if not self._last_case:
                return self._condition
            ctx.execution_state.exit_conditional_section()
            chosen = self._selected_case
            if chosen.output_promise is not None:
                return chosen.output_promise
            if chosen.err is None:
                raise AssertionError("Bad conditional statements, did not resolve in a promise")
            raise ValueError(chosen.err)

        if not ctx.compilation_state:
            raise AssertionError("Branches can only be invoked within a workflow context!")

        # COMPILATION MODE
        if not self._last_case:
            return self._condition

        ctx.compilation_state.exit_conditional_section()
        branch_entity, promises = to_branch_node(self._name, self)
        # TODO: verify that the branch nodes recorded in the compilation state match the nodes in the branch node.

        # Only promises that are not yet ready become bindings; their referenced nodes become upstream nodes.
        bindings: typing.List[Binding] = []
        upstream_nodes = set()
        for p in promises:
            if p.is_ready:
                continue
            bindings.append(Binding(var=p.var, binding=BindingData(promise=p.ref)))
            upstream_nodes.add(p.ref.node)

        node_id = f"{ctx.compilation_state.prefix}node-{len(ctx.compilation_state.nodes)}"
        branch_node = Node(
            id=node_id,
            metadata=_core_wf.NodeMetadata(self._name, timeout=datetime.timedelta(), retries=RetryStrategy(0)),
            bindings=sorted(bindings, key=lambda b: b.var),
            upstream_nodes=list(upstream_nodes),  # type: ignore
            flyte_entity=branch_entity,
        )
        ctx.compilation_state.add_node(branch_node)
        return self._compute_outputs(branch_node)
Exemplo n.º 7
0
    def compile(self, **kwargs):
        """
        Compile the workflow function into nodes and output bindings.

        Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
        a 'closure' in the traditional sense of the word; inputs that are not supplied are bound to the promises
        built by ``construct_input_promises``.

        Side effects: sets ``self._input_parameters``, ``self._nodes`` and ``self._output_bindings``.

        :return: ``None`` if the workflow declares no outputs, the single binding for one output, otherwise a tuple
            of bindings in declaration order.
        :raises AssertionError: if the workflow returns a different number of values than it declares, or if an
            output is an unterminated conditional block.
        """
        ctx = FlyteContext.current_context()
        self._input_parameters = transform_inputs_to_parameters(ctx, self._native_interface)
        all_nodes = []
        prefix = f"{ctx.compilation_state.prefix}-{self.short_name}-" if ctx.compilation_state is not None else None
        with ctx.new_compilation_context(prefix=prefix) as comp_ctx:
            # Construct the default input promise bindings, but then override with the provided inputs, if any
            input_kwargs = construct_input_promises(list(self.interface.inputs.keys()))
            input_kwargs.update(kwargs)
            workflow_outputs = self._workflow_function(**input_kwargs)
            all_nodes.extend(comp_ctx.compilation_state.nodes)

        # Iterate through the workflow outputs
        bindings = []
        output_names = list(self.interface.outputs.keys())
        # The reason the length 1 case is separate is because the one output might be a list. We don't want to
        # iterate through the list here, instead we should let the binding creation unwrap it and make a binding
        # collection/map out of it.
        if len(output_names) == 1:
            if isinstance(workflow_outputs, tuple) and len(workflow_outputs) != 1:
                raise AssertionError(
                    f"The Workflow specification indicates only one return value, received {len(workflow_outputs)}"
                )
            # Same check as the multi-output path below: an unterminated if-else cannot be bound to an output.
            if isinstance(workflow_outputs, ConditionalSection):
                raise AssertionError("A Conditional block (if-else) should always end with an `else_()` clause")
            t = self._native_interface.outputs[output_names[0]]
            b = flytekit.annotated.promise.binding_from_python_std(
                ctx, output_names[0], self.interface.outputs[output_names[0]].type, workflow_outputs, t,
            )
            bindings.append(b)
        elif len(output_names) > 1:
            if not isinstance(workflow_outputs, tuple):
                raise AssertionError("The Workflow specification indicates multiple return values, received only one")
            if len(output_names) != len(workflow_outputs):
                raise Exception(f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}")
            for out, value in zip(output_names, workflow_outputs):
                if isinstance(value, ConditionalSection):
                    raise AssertionError("A Conditional block (if-else) should always end with an `else_()` clause")
                t = self._native_interface.outputs[out]
                b = flytekit.annotated.promise.binding_from_python_std(
                    ctx, out, self.interface.outputs[out].type, value, t,
                )
                bindings.append(b)

        # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
        self._nodes = all_nodes
        self._output_bindings = bindings
        if not output_names:
            return None
        if len(output_names) == 1:
            return bindings[0]
        return tuple(bindings)
Exemplo n.º 8
0
 def start_branch(self, c: Case, last_case: bool = False) -> Case:
     """
     Register ``c`` as the next branch of this if/else block and return it.

     During local workflow execution, the first case whose expression is absent or evaluates truthy (or the
     last case, unconditionally) becomes the selected case and the execution state is told to take that
     branch. Once a case has been selected, later cases are not evaluated.

     :param c: the case that represents this branch
     :param last_case: True if this is the last branch in the if/else block
     :return: the registered case
     """
     self._last_case = last_case
     self._cases.append(c)
     ctx = FlyteContext.current_context()
     local_run = ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
     # Only local runs evaluate expressions; a case that is already selected short-circuits the rest.
     if local_run and self._selected_case is None:
         if c.expr is None or c.expr.eval() or last_case:
             ctx.execution_state.take_branch()
             self._selected_case = c
     return c
Exemplo n.º 9
0
def current_context():
    """
    Use this method to get a handle on the user-space execution parameters available in a flyte task.

    Usage

    .. code-block:: python

        flytekit.current_context().logging.info(...)

    Available params are documented in :py:class:`flytekit.annotated.context_manager.ExecutionParameters`.
    Task plugins may attach extra attributes to these parameters in their ``pre_execute`` hook
    (via ``builder().add_attr(...)``).
    """
    return FlyteContext.current_context().user_space_params
Exemplo n.º 10
0
 def __init__(
     self,
     lhs: Union["Promise", Any],
     op: ComparisonOps,
     rhs: Union["Promise", Any],
 ):
     """
     Build the comparison ``lhs <op> rhs``.

     Promise operands are stored as-is. Any other operand is converted to a Flyte literal with
     ``TypeEngine.to_literal`` using its runtime Python type. A promise that is already ready must wrap a
     primitive scalar.

     :raises ValueError: if a ready promise does not hold a primitive scalar value.
     """
     self._op = op
     self._lhs = None
     self._rhs = None
     # Validate both operands before converting either of them.
     for operand in (lhs, rhs):
         if isinstance(operand, Promise) and operand.is_ready:
             if operand.val.scalar is None or operand.val.scalar.primitive is None:
                 raise ValueError(
                     "Only primitive values can be used in comparison")
     if isinstance(lhs, Promise):
         self._lhs = lhs
     else:
         self._lhs = type_engine.TypeEngine.to_literal(
             FlyteContext.current_context(), lhs, type(lhs), None)
     if isinstance(rhs, Promise):
         self._rhs = rhs
     else:
         self._rhs = type_engine.TypeEngine.to_literal(
             FlyteContext.current_context(), rhs, type(rhs), None)
Exemplo n.º 11
0
def test_file_format_getting_python_value():
    """A blob literal with format "txt" converts into a FlyteFile whose extension is "txt"."""
    transformer = TypeEngine.get_transformer(FlyteFile)
    ctx = FlyteContext.current_context()

    # This file probably won't exist, but it's okay: nothing is downloaded unless the returned value is read.
    blob_type = BlobType(format="txt", dimensionality=0)
    blob = Blob(metadata=BlobMetadata(type=blob_type), uri="file:///tmp/test")
    lv = Literal(scalar=Scalar(blob=blob))

    pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteFile["txt"])
    assert isinstance(pv, FlyteFile)
    assert pv.extension() == "txt"
Exemplo n.º 12
0
    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Add a ``DISTRIBUTED_TRAINING_CONTEXT`` attribute to the execution params when running distributed (the
        number of execution instances is > 1). Otherwise this is a single node execution and ``user_params`` is
        returned unchanged.

        In TASK_EXECUTION mode the context comes from ``DistributedTrainingContext.from_env()``; in any other
        mode ``DistributedTrainingContext.local_execute()`` is used.
        """
        if not self._is_distributed():
            return user_params

        logging.info("Distributed context detected!")
        exec_state = FlyteContext.current_context().execution_state
        # TASK_EXECUTION indicates we are actually in a remote execute environment (within sagemaker in this case).
        in_remote_task = exec_state and exec_state.mode == ExecutionState.Mode.TASK_EXECUTION
        dist_ctx = DistributedTrainingContext.from_env() if in_remote_task else DistributedTrainingContext.local_execute()
        return user_params.builder().add_attr("DISTRIBUTED_TRAINING_CONTEXT", dist_ctx).build()
Exemplo n.º 13
0
    def pre_execute(self,
                    user_params: ExecutionParameters) -> ExecutionParameters:
        """
        Create (or reuse) a Spark session and expose it as the ``SPARK_SESSION`` attribute of the execution params.

        Outside of a TASK_EXECUTION (local / notebook runs), the task's ``spark_conf`` is applied to the session
        with ``spark.master`` forced to ``local``.
        """
        import pyspark as _pyspark

        ctx = FlyteContext.current_context()
        sess_builder = _pyspark.sql.SparkSession.builder.appName(
            f"FlyteSpark: {user_params.execution_id}")
        # ``mode`` is the current execution mode; ``Mode`` is the enum class itself and never equals a member.
        if not (ctx.execution_state and ctx.execution_state.mode
                == ExecutionState.Mode.TASK_EXECUTION):
            # If either of above cases is not true, then we are in local execution of this task
            # Add system spark-conf for local/notebook based execution.
            spark_conf = _pyspark.SparkConf()
            # Set sequentially (not via an unordered set) so the forced local master deterministically wins.
            for k, v in self.task_config.spark_conf.items():
                spark_conf.set(k, v)
            spark_conf.set("spark.master", "local")
            # The conf only takes effect once it is handed to the session builder.
            sess_builder = sess_builder.config(conf=spark_conf)

        sess = sess_builder.getOrCreate()
        return user_params.builder().add_attr("SPARK_SESSION", sess).build()
Exemplo n.º 14
0
    def __call__(self, *args, **kwargs):
        """
        Run, compile or locally execute this task depending on the current FlyteContext.

        Only keyword arguments are accepted. The three cases are:
          a. Compilation mode -- the call is part of a workflow (potentially a dynamic task): instead of running the
             user function, produce promise objects and create a node.
          b. Local workflow execution -- run through ``_local_execute`` so outputs are wrapped for subsequent tasks.
             If the enclosing branch was skipped nothing runs: a namedtuple of ``None`` values is returned when the
             interface has a named output tuple, plain ``None`` otherwise.
          c. No workflow context -- ``pre_execute`` is applied to the user-space params and ``execute`` runs inside
             a LOCAL_TASK_EXECUTION context.

        :raises FlyteAssertion: if any positional arguments are supplied.
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                f"When calling tasks, only keyword args are supported. "
                f"Aborting execution as detected {len(args)} positional args {args}"
            )

        ctx = FlyteContext.current_context()
        # NOTE(review): magic value 1 -- presumably the active compilation mode; confirm against CompilationState.
        if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
            return self.compile(ctx, *args, **kwargs)

        exec_state = ctx.execution_state
        if exec_state is not None and exec_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            if exec_state.branch_eval_mode != BranchEvalMode.BRANCH_SKIPPED:
                return self._local_execute(ctx, **kwargs)
            interface = self.python_interface
            if not (interface and interface.output_tuple_name):
                # Should we return multiple None's here?
                return None
            output_names = list(interface.outputs.keys())
            skipped_outputs = collections.namedtuple(interface.output_tuple_name, output_names)
            return skipped_outputs(*([None] * len(output_names)))

        logger.warning("task run without context - executing raw function")
        new_user_params = self.pre_execute(ctx.user_space_params)
        with ctx.new_execution_context(
            mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION,
            execution_params=new_user_params,
        ):
            return self.execute(**kwargs)
    def compile_into_workflow(
        self, ctx: FlyteContext, task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        """
        Compile ``task_function`` into a workflow and serialize it, inside a compilation context prefixed "dynamic".

        The compiled workflow is kept on ``self._wf``.

        :return: a LiteralMap built from the output bindings when compilation produced no nodes; otherwise a
            DynamicJobSpec with the nodes, the aggregated tasks and sub-workflows, and the outputs.
        """
        with ctx.new_compilation_context(prefix="dynamic"):
            # TODO: Resolve circular import
            from flytekit.common.translator import get_serializable

            self._wf = Workflow(
                task_function,
                metadata=WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY),
                default_metadata=WorkflowMetadataDefaults(interruptible=False),
            )
            self._wf.compile(**kwargs)

            sdk_workflow = get_serializable(ctx.serialization_settings, self._wf)
            nodes = sdk_workflow.nodes

            # If no nodes were produced, let's just return the strict outputs
            if len(nodes) == 0:
                strict_outputs = {
                    output_binding.var: output_binding.binding.to_literal_model()
                    for output_binding in sdk_workflow._outputs
                }
                return _literal_models.LiteralMap(literals=strict_outputs)

            # Gather underlying tasks/workflows that get referenced. Launch plans are handled by propeller.
            tasks, sub_workflows = set(), set()
            for node in nodes:
                self.aggregate(tasks, sub_workflows, node)

            return _dynamic_job.DynamicJobSpec(
                min_successes=len(nodes),
                tasks=list(tasks),
                nodes=nodes,
                outputs=sdk_workflow._outputs,
                subworkflows=list(sub_workflows),
            )
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        Execute or compile the body of a dynamic task.

        By the time this function is invoked, the _local_execute function should have unwrapped the Promises and Flyte
        literal wrappers, so ``kwargs`` hold Python native values and Python native values are returned.

        The user code within a dynamic task constitutes a workflow:
          * In LOCAL_WORKFLOW_EXECUTION mode the function is run directly on the raw inputs inside a TASK_EXECUTION
            context.
          * In TASK_EXECUTION mode the function is compiled into a workflow (``compile_into_workflow``) instead of
            being executed.
          * In any other mode, or without an execution state, ``None`` is returned.
        """
        ctx = FlyteContext.current_context()
        state = ctx.execution_state
        mode = state.mode if state else None

        if mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            with ctx.new_execution_context(ExecutionState.Mode.TASK_EXECUTION):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return task_function(**kwargs)

        if mode == ExecutionState.Mode.TASK_EXECUTION:
            return self.compile_into_workflow(ctx, task_function, **kwargs)

        # NOTE(review): other modes silently yield None -- confirm this is intended.
        return None
Exemplo n.º 17
0
    def __call__(self, *args, **kwargs):
        """
        Call the launch plan with keyword arguments layered on top of its saved inputs.

        In compilation mode a node for this launch plan is created and linked against the workflow's interface;
        otherwise the call is simply forwarded to the underlying workflow.

        :raises AssertionError: if positional arguments are supplied.
        """
        if len(args) > 0:
            raise AssertionError(
                "Only Keyword Arguments are supported for launch plan executions"
            )

        ctx = FlyteContext.current_context()
        # Merge into a fresh dict: updating ``self.saved_inputs`` in place would, if it exposes the stored mapping,
        # leak this call's kwargs into every later call of the same launch plan.
        inputs = {**self.saved_inputs, **kwargs}

        if ctx.compilation_state is not None:
            return create_and_link_node(
                ctx,
                entity=self,
                interface=self.workflow._native_interface,
                **inputs)

        # Calling a launch plan should just forward the call to the workflow, nothing more (with the saved inputs).
        return self.workflow(**inputs)
Exemplo n.º 18
0
 def if_(self, expr: bool) -> Case:
     """
     Open the conditional section and start its first case with ``expr``.

     With an execution state in LOCAL_WORKFLOW_EXECUTION mode, the execution state enters a conditional
     section so that skipped branches result in tasks not being executed. Without an execution state but
     with a compilation state, the compilation state enters the conditional section instead.

     :raises NotImplementedError: if a branch is already in progress (nested branches are not yet supported).
     """
     ctx = FlyteContext.current_context()
     exec_state = ctx.execution_state
     if exec_state:
         # TODO implement nested branches
         if exec_state.branch_eval_mode is not None:
             raise NotImplementedError("Nested branches are not yet supported")
         if exec_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
             exec_state.enter_conditional_section()
     elif ctx.compilation_state:
         # TODO implement nested branches
         if ctx.compilation_state.is_in_a_branch():
             raise NotImplementedError("Nested branches are not yet supported")
         ctx.compilation_state.enter_conditional_section()
     return self._condition._if(expr)
Exemplo n.º 19
0
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        Translate Flyte literal inputs to Python native values, run the task, and translate the outputs back.

        This method is also invoked during runtime. ``pre_execute`` runs before and ``post_execute`` after the user
        code, inside a new execution context that keeps the current mode and working directory.

        Returns:
            A `LiteralMap` with one literal per declared output (an empty map when the task declares no outputs).
            A `LiteralMap` or `DynamicJobSpec` produced directly by ``execute``/``post_execute`` (e.g. by a dynamic
            task) is returned unchanged.

        Raises:
            AssertionError: if the number of returned values does not match the declared outputs, an output value
                is a tuple, or an output cannot be converted to a literal.
        """

        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with ctx.new_execution_context(
                mode=ctx.execution_state.mode,
                execution_params=new_user_params,
                working_dir=ctx.execution_state.working_dir,
        ) as exec_ctx:
            # TODO We could support default values here too - but not part of the plan right now
            # Translate the input literals to Python native
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, self.python_interface.inputs)

            # TODO: Logger should auto inject the current context information to indicate if the task is running within
            #   a workflow or a subworkflow etc
            logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing {e}")
                # Bare ``raise`` re-raises the active exception as-is.
                raise

            logger.info(
                f"Task executed successfully in user level, outputs: {native_outputs}"
            )
            # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is
            # bubbled up to be handled at the callee layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)):
                return native_outputs

            expected_output_names = list(self.interface.outputs.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                if self.python_interface.output_tuple_name and isinstance(
                        native_outputs, tuple):
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs[0]
                    }
                else:
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs
                    }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                # Check arity up front: positional indexing used to raise an opaque IndexError on too many values
                # and silently drop outputs on too few.
                returned_count = len(native_outputs) if hasattr(native_outputs, "__len__") else None
                if returned_count != len(expected_output_names):
                    raise AssertionError(
                        f"Task {self.name} declares {len(expected_output_names)} outputs {expected_output_names}, "
                        f"but returned {native_outputs}"
                    )
                native_outputs_as_map = dict(zip(expected_output_names, native_outputs))

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self.interface.outputs[k].type
                py_type = self.get_type_for_output_var(k, v)
                if isinstance(v, tuple):
                    raise AssertionError(
                        f"Output({k}) in task {self.name} received a tuple {v}, instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    raise AssertionError(
                        f"failed to convert return value for var {k}") from e

            # After the execute has been successfully completed
            return _literal_models.LiteralMap(literals=literals)
Exemplo n.º 20
0
def test_default():
    """The context returned by default must always carry a file-access provider."""
    default_file_access = FlyteContext.current_context().file_access
    assert default_file_access is not None
Exemplo n.º 21
0
def _handle_annotated_task(task_def: PythonTask, inputs: str,
                           output_prefix: str, raw_output_data_prefix: str) -> None:
    """
    Entrypoint for all PythonTask extensions.

    Prepares the runtime environment for a single task execution, then hands off to
    ``_dispatch_execute``:

    1. Sets the root logger level from ``LOGGING_LEVEL``. The internal config is read
       first, with the SDK config as a fallback.
    2. Creates a random local workspace directory. It is used as ``tmp_dir`` of the
       task's ``ExecutionParameters``.
    3. Builds the ``ExecutionParameters``: execution id, execution date, stats client,
       logging module and tmp dir.
    4. Picks a ``FileAccessProvider`` according to the configured cloud provider.
    5. Enters a file-access context, then a serialization-settings context, then a
       task-execution context (in that order). It calls ``_dispatch_execute`` inside
       the innermost one.

    :param task_def: The task to execute.
    :param inputs: Passed unchanged to ``_dispatch_execute``. Presumably the location of
        the serialized inputs; confirm against ``_dispatch_execute``.
    :param output_prefix: Passed unchanged to ``_dispatch_execute``.
    :param raw_output_data_prefix: Remote prefix handed to the S3 / GCS proxy. It is not
        used when the cloud provider is LOCAL.
    :raises Exception: If the configured cloud provider is not AWS, GCP or LOCAL.
    """
    _click.echo("Running native-typed task")
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    # Internal config takes precedence; fall back to the SDK-level setting when unset/falsy.
    log_level = _internal_config.LOGGING_LEVEL.get(
    ) or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContext.current_context()

    # Create directories
    # The workspace path comes from the *current* context's file access, i.e. before the
    # cloud-specific provider below is installed.
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    # Reported below as the "api_version" stats tag.
    from flytekit import __version__ as _api_version

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        # NOTE(review): utcnow() returns a naive datetime (no tzinfo attached).
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get()
                or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get()
                or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get()
                or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )

    # Choose the remote storage backend. AWS and GCP both get a proxy rooted at
    # raw_output_data_prefix; LOCAL passes no remote_proxy at all.
    if cloud_provider == _constants.CloudProvider.AWS:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.GCP:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    else:
        raise Exception(f"Bad cloud provider {cloud_provider}")

    # Each nested `with` below rebinds `ctx` to the context it just entered, and each
    # new context is requested from the previous one.
    with ctx.new_file_access_context(file_access_provider=file_access) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        # Environment variables recorded in the serialization settings: the config file
        # path and the container image.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var:
            _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var:
            _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with ctx.new_serialization_settings(
                serialization_settings=serialization_settings) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with ctx.new_execution_context(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                    execution_params=execution_parameters) as ctx:
                _dispatch_execute(ctx, task_def, inputs, output_prefix)
Exemplo n.º 22
0
def test_dict_transformer():
    """
    Exercise DictTransformer type inference and literal <-> python conversion.

    Covers:
    * ``get_literal_type`` for untyped dicts and non-str-keyed dicts, which are expected
      to map to a generic STRUCT;
    * ``get_literal_type`` for (nested) str-keyed dicts, which are expected to map to a
      chain of ``map_value_type`` links ending in the value's literal type;
    * a round trip of an empty dict;
    * the ``TypeError`` paths of ``to_python_value``;
    * a successful ``Dict[str, int]`` conversion, including the value it returns.
    """
    d = DictTransformer()

    def assert_struct(lit: LiteralType):
        # Dicts that cannot be represented as a str-keyed map must fall back to STRUCT.
        assert lit is not None
        assert lit.simple == SimpleType.STRUCT

    def recursive_assert(lit: LiteralType,
                         expected: LiteralType,
                         expected_depth: int = 1,
                         curr_depth: int = 0):
        # Follow the chain of map_value_type links and compare the innermost literal
        # type with ``expected``. ``expected_depth`` is only an upper bound on how many
        # links may be followed, not an exact depth.
        assert curr_depth <= expected_depth
        assert lit is not None
        if lit.map_value_type is None:
            assert lit == expected
            return
        recursive_assert(lit.map_value_type, expected, expected_depth,
                         curr_depth + 1)

    # Type inference
    assert_struct(d.get_literal_type(dict))
    assert_struct(d.get_literal_type(typing.Dict[int, int]))
    recursive_assert(d.get_literal_type(typing.Dict[str, str]),
                     LiteralType(simple=SimpleType.STRING))
    recursive_assert(d.get_literal_type(typing.Dict[str, int]),
                     LiteralType(simple=SimpleType.INTEGER))
    recursive_assert(d.get_literal_type(typing.Dict[str, datetime.datetime]),
                     LiteralType(simple=SimpleType.DATETIME))
    recursive_assert(d.get_literal_type(typing.Dict[str, datetime.timedelta]),
                     LiteralType(simple=SimpleType.DURATION))
    recursive_assert(d.get_literal_type(typing.Dict[str, dict]),
                     LiteralType(simple=SimpleType.STRUCT))
    recursive_assert(
        d.get_literal_type(typing.Dict[str, typing.Dict[str, str]]),
        LiteralType(simple=SimpleType.STRING),
        expected_depth=2,
    )
    # The inner non-str-keyed dict collapses to STRUCT, so the chain stops early.
    recursive_assert(
        d.get_literal_type(typing.Dict[str, typing.Dict[int, str]]),
        LiteralType(simple=SimpleType.STRUCT),
        expected_depth=2,
    )
    recursive_assert(
        d.get_literal_type(typing.Dict[str, typing.Dict[str,
                                                        typing.Dict[str,
                                                                    str]]]),
        LiteralType(simple=SimpleType.STRING),
        expected_depth=3,
    )
    recursive_assert(
        d.get_literal_type(typing.Dict[str, typing.Dict[str,
                                                        typing.Dict[str,
                                                                    dict]]]),
        LiteralType(simple=SimpleType.STRUCT),
        expected_depth=3,
    )
    recursive_assert(
        d.get_literal_type(typing.Dict[str, typing.Dict[str,
                                                        typing.Dict[int,
                                                                    dict]]]),
        LiteralType(simple=SimpleType.STRUCT),
        expected_depth=2,
    )

    ctx = FlyteContext.current_context()

    # An empty dict survives a python -> literal -> python round trip.
    lit = d.to_literal(ctx, {}, typing.Dict, LiteralType(SimpleType.STRUCT))
    pv = d.to_python_value(ctx, lit, typing.Dict)
    assert pv == {}

    # Literal to python
    with pytest.raises(TypeError):
        d.to_python_value(
            ctx, Literal(scalar=Scalar(primitive=Primitive(integer=10))), dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(), dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})),
                          dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})),
                          typing.Dict[int, str])

    # A well-formed str-keyed map converts value by value.
    converted = d.to_python_value(
        ctx,
        Literal(map=LiteralMap(
            literals={
                "x": Literal(scalar=Scalar(primitive=Primitive(integer=1)))
            })),
        typing.Dict[str, int],
    )
    assert converted == {"x": 1}
Exemplo n.º 23
0
def create_node(
        entity: Union[PythonTask, LaunchPlan, Workflow], *args,
        **kwargs) -> Union[Node, VoidPromise, Type[collections.namedtuple]]:
    """
    This is the function you want to call if you need to specify dependencies between tasks that don't consume
    and/or don't produce outputs. For example, if you have t1() and t2(), neither of which takes inputs or
    produces outputs, how do you specify that t2 should run before t1?

        t1_node = create_node(t1)
        t2_node = create_node(t2)

        t2_node.runs_before(t1_node)
        # OR
        t2_node >> t1_node

    This works for tasks that take inputs as well, say a ``t3(in1: int)``

        t3_node = create_node(t3, in1=some_int)  # basically calling t3(in1=some_int)

    You can still use this method to handle setting certain overrides

        t3_node = create_node(t3, in1=some_int).with_overrides(...)

    Outputs, if there are any, will be accessible. Given a `t4() -> (int, str)`:

        t4_node = create_node(t4)

        During compilation, t4_node.o0 holds the promise for the first output.
        t5(in1=t4_node.o0)

        During local workflow execution, the returned object is the entity's output named tuple, so the same
        attribute access works.
        t5(in1=t4_node.o0)

    @workflow
    def wf():
        create_node(sub_wf)
        create_node(wf2)

    @dynamic
    def sub_wf():
        create_node(other_sub)
        create_node(task)

    If t1 produces only one output, note that in local execution, you still get a wrapper object that
    needs to be dereferenced by the output name.

        t1_node = create_node(t1)
        t2(t1_node.o0)

    :param entity: The task, workflow or launch plan to invoke.
    :param kwargs: Inputs for the entity. Positional arguments are rejected.
    :return: During compilation, the newly created Node, with each output attached to it as an attribute.
        During local workflow execution, one of:
        * a VoidPromise;
        * the entity's result tuple, returned as-is;
        * the result wrapped in the entity's output named tuple.
    :raises FlyteAssertion: If positional arguments are passed.
    :raises AssertionError: If ``entity`` is not a PythonTask, Workflow or LaunchPlan.
    :raises Exception: If called outside compilation / local workflow execution, or if the outputs do not
        match the entity's interface.
    """
    if args:
        raise _user_exceptions.FlyteAssertion(
            "Only keyword args are supported to pass inputs to workflows and tasks. "
            f"Aborting execution as detected {len(args)} positional args {args}"
        )

    if not isinstance(entity, (PythonTask, Workflow, LaunchPlan)):
        raise AssertionError(
            f"Should be a PythonTask, Workflow or LaunchPlan but it's not: {type(entity)}"
        )

    # This function is only called from inside workflows and dynamic tasks.
    # That means there are two scenarios we need to take care of, compilation and local workflow execution.

    # When compiling, calling the entity will create a node.
    ctx = FlyteContext.current_context()
    # NOTE(review): `mode == 1` is a bare magic number; presumably the regular compilation mode of
    # CompilationState. Confirm and replace with a named constant if one exists.
    if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:

        outputs = entity(**kwargs)
        # This is always the output of create_and_link_node which returns create_task_output, which can be
        # VoidPromise, Promise, or our custom namedtuple of Promises.
        # The call above appended the new node to the compilation state, so the last node is ours.
        node = ctx.compilation_state.nodes[-1]

        # If a VoidPromise, just return the node.
        if isinstance(outputs, VoidPromise):
            return node

        # If a Promise or custom namedtuple of Promises, we need to attach each output as an attribute to the node.
        if entity.python_interface.outputs:
            if isinstance(outputs, tuple):
                for output_name in entity.python_interface.output_names:
                    attr = getattr(outputs, output_name)
                    if attr is None:
                        raise Exception(
                            f"Output {output_name} in outputs when calling {entity.name} is empty {attr}."
                        )
                    setattr(node, output_name, attr)
            else:
                output_names = entity.python_interface.output_names
                if len(output_names) != 1:
                    raise Exception(
                        f"Output of length 1 expected but {len(output_names)} found"
                    )

                setattr(node, output_names[0],
                        outputs)  # This should be a singular Promise

        return node

    # Handling local execution
    elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
        if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
            logger.warning(
                f"Manual node creation cannot be used in branch logic {entity.name}"
            )
            raise Exception(
                "Being more restrictive for now and disallowing manual node creation in branch logic"
            )

        # This is the output of __call__ under local execute conditions, i.e. the output of _local_execute,
        # which is the output of create_task_output with Promises containing values (or a VoidPromise).
        results = entity(**kwargs)

        # If it's a VoidPromise, let's just return it, it shouldn't get used anywhere and if it does, we want an error
        # The reason we return it if it's a tuple is to handle the case where the task returns a typing.NamedTuple.
        # In that case, it's already a tuple and we don't need to further tupletize.
        if isinstance(results, (VoidPromise, tuple)):
            return results

        output_names = entity.python_interface.output_names

        if not output_names:
            raise Exception(
                f"Non-VoidPromise received {results} but interface for {entity.name} doesn't have outputs"
            )

        if len(output_names) == 1:
            # A single output is still wrapped so callers can dereference it by name (see docstring).
            return entity.python_interface.output_tuple(results)

        return entity.python_interface.output_tuple(*results)

    else:
        raise Exception(
            f"Cannot use explicit run to call Flyte entities {entity.name}")