Example #1
def test_two(two_sample_inputs):
    my_input = two_sample_inputs[0]
    my_input_2 = two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        x = []
        for aa in a:
            x.append(aa.main_product)
        return x

    with FlyteContextManager.with_context(
        FlyteContextManager.current_context().with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        with FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx, d={"a": [my_input, my_input_2]}, type_hints={"a": List[MyInput]}
            )
            dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec.literals["o0"].collection.literals) == 2
Example #2
def test_levels():
    ctx = FlyteContextManager.current_context()
    b = ctx.new_builder()
    b.flyte_client = SampleTestClass(value=1)
    with FlyteContextManager.with_context(b) as outer:
        assert outer.flyte_client.value == 1
        b = outer.new_builder()
        b.flyte_client = SampleTestClass(value=2)
        with FlyteContextManager.with_context(b) as ctx:
            assert ctx.flyte_client.value == 2

        with FlyteContextManager.with_context(outer.with_new_compilation_state()) as ctx:
            assert ctx.flyte_client.value == 1
Example #3
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        By the time this function is invoked, the local_execute function should have unwrapped the Promises and Flyte
        literal wrappers so that the kwargs we are working with here are now Python native literal values. This
        function is also expected to return Python native literal values.

        Since the user code within a dynamic task constitutes a workflow, we have to first compile that workflow and
        then execute it.

        When running for real in production, the task would stop after the compilation step, and then create a file
        representing that newly generated workflow, instead of executing it.
        """
        ctx = FlyteContextManager.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            updated_exec_state = ctx.execution_state.with_params(
                mode=ExecutionState.Mode.TASK_EXECUTION)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(updated_exec_state)):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return exception_scopes.user_entry_point(task_function)(
                    **kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            return self.compile_into_workflow(ctx, task_function, **kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
            return exception_scopes.user_entry_point(task_function)(**kwargs)

        raise ValueError(
            f"Invalid execution provided, execution state: {ctx.execution_state}"
        )
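The docstring above describes three entry modes for a dynamic task. A minimal sketch of how the local path is exercised, assuming flytekit's @task/@dynamic/@workflow decorators; double, fan_out, and wf are illustrative names, not from the original:

from typing import List

from flytekit import dynamic, task, workflow


@task
def double(n: int) -> int:
    return n * 2


@dynamic
def fan_out(n: int) -> List[int]:
    # The body behaves like a workflow: each call creates a node when
    # compiling, or simply runs the function when executing locally.
    return [double(n=i) for i in range(n)]


@workflow
def wf(n: int) -> List[int]:
    return fan_out(n=n)


# A local call enters LOCAL_WORKFLOW_EXECUTION, so dynamic_execute flips the
# mode to TASK_EXECUTION and runs the body with raw inputs.
assert wf(n=3) == [0, 2, 4]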
Example #4
    def add_entity(self, entity: PythonAutoContainerTask, **kwargs) -> Node:
        """
        Anytime you add an entity, all the inputs to the entity must be bound.
        """
        # circular import
        from flytekit.core.node_creation import create_node

        ctx = FlyteContext.current_context()
        if ctx.compilation_state is not None:
            raise Exception("Can't already be compiling")
        with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
            n = create_node(entity=entity, **kwargs)

            def get_input_values(input_value):
                if isinstance(input_value, list):
                    input_promises = []
                    for x in input_value:
                        input_promises.extend(get_input_values(x))
                    return input_promises
                if isinstance(input_value, dict):
                    input_promises = []
                    for _, v in input_value.items():
                        input_promises.extend(get_input_values(v))
                    return input_promises
                else:
                    return [input_value]

            # Every time an entity is added, mark it as used.
            for input_value in get_input_values(kwargs):
                if input_value in self._unbound_inputs:
                    self._unbound_inputs.remove(input_value)
            return n
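For orientation, a hedged sketch of the imperative builder API that add_entity belongs to; t1 and the workflow name are hypothetical, and the pattern follows flytekit's imperative-workflow usage:

from flytekit import Workflow, task


@task
def t1(a: str) -> str:
    return a + "!"


wb = Workflow(name="example.imperative.wf")
wb.add_workflow_input("in1", str)
# Binding in1 here is what removes it from the builder's _unbound_inputs set.
node = wb.add_entity(t1, a=wb.inputs["in1"])
wb.add_workflow_output("out1", node.outputs["o0"])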
Example #5
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        By the time this function is invoked, the _local_execute function should have unwrapped the Promises and Flyte
        literal wrappers so that the kwargs we are working with here are now Python native literal values. This
        function is also expected to return Python native literal values.

        Since the user code within a dynamic task constitutes a workflow, we have to first compile that workflow and
        then execute it.

        When running for real in production, the task would stop after the compilation step, and then create a file
        representing that newly generated workflow, instead of executing it.
        """
        ctx = FlyteContextManager.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            updated_exec_state = ctx.execution_state.with_params(
                mode=ExecutionState.Mode.TASK_EXECUTION)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(updated_exec_state)):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return task_function(**kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            is_fast_execution = bool(
                ctx.execution_state and ctx.execution_state.additional_context
                and ctx.execution_state.additional_context.get(
                    "dynamic_addl_distro"))
            if is_fast_execution:
                ctx = ctx.with_serialization_settings(
                    SerializationSettings.new_builder(
                    ).with_fast_serialization_settings(
                        FastSerializationSettings(enabled=True)).build())

            return self.compile_into_workflow(ctx, task_function, **kwargs)
Example #6
def serialize(
    pkgs: typing.List[str],
    settings: SerializationSettings,
    local_source_root: typing.Optional[str] = None,
    options: typing.Optional[Options] = None,
) -> typing.List[RegistrableEntity]:
    """
    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param settings: SerializationSettings to be used.
    :param local_source_root: Where to start looking for the code.
    :param options:
    """

    ctx = FlyteContextManager.current_context().with_serialization_settings(
        settings)
    with FlyteContextManager.with_context(ctx) as ctx:
        # Scan all modules. The act of loading populates the global singleton that contains all objects.
        with module_loader.add_sys_path(local_source_root):
            click.secho(
                f"Loading packages {pkgs} under source root {local_source_root}",
                fg="yellow")
            module_loader.just_load_modules(pkgs=pkgs)

        registrable_entities = get_registrable_entities(ctx, options=options)
        click.secho(
            f"Successfully serialized {len(registrable_entities)} flyte objects",
            fg="green")
        return registrable_entities
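A hypothetical invocation of serialize; the project, domain, image, and package names below are placeholders:

settings = SerializationSettings(
    project="flytesnacks",
    domain="development",
    version="v1",
    image_config=ImageConfig(Image(name="default", fqn="ghcr.io/org/image", tag="v1")),
    env={},
)
entities = serialize(pkgs=["myapp.workflows"], settings=settings, local_source_root=".")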
Example #7
    def add_workflow_output(
        self, output_name: str, p: Union[Promise, List[Promise], Dict[str, Promise]], python_type: Optional[Type] = None
    ):
        """
        Add an output with the given name from the given node output.
        """
        if output_name in self._python_interface.outputs:
            raise FlyteValidationException(f"Output {output_name} already exists in workflow {self.name}")

        if python_type is None:
            if type(p) == list or type(p) == dict:
                raise FlyteValidationException(
                    f"If specifying a list or dict of Promises, you must specify the python_type for {output_name},"
                    f" starting with the container type (e.g. List[int])"
                )
            python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var]
            logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}")

        flyte_type = TypeEngine.to_literal_type(python_type=python_type)

        ctx = FlyteContext.current_context()
        if ctx.compilation_state is not None:
            raise Exception("Can't already be compiling")
        with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
            b = binding_from_python_std(
                ctx, output_name, expected_literal_type=flyte_type, t_value=p, t_value_type=python_type
            )
            self._output_bindings.append(b)
            self._python_interface = self._python_interface.with_outputs(extra_outputs={output_name: python_type})
            self._interface = transform_interface_to_typed_interface(self._python_interface)
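Continuing the hedged builder sketch from Example #4: a single Promise lets python_type be inferred from the producing node's interface, while a list of Promises must pass the container type explicitly (wb is the builder from that sketch; n1 and n2 are hypothetical nodes with int outputs):

# Inferred from the producing node's Python interface:
wb.add_workflow_output("one", n1.outputs["o0"])
# Container outputs need the explicit type, per the check above:
wb.add_workflow_output("many", [n1.outputs["o0"], n2.outputs["o0"]], python_type=List[int])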
Example #8
def test_dc_dyn_directory(folders_and_files_setup):
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    my_input_gcs = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/one"),
            external_data_dir=FlyteDirectory("gs://my-bucket/two"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    my_input_gcs_2 = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/three"),
            external_data_dir=FlyteDirectory("gs://my-bucket/four"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        x = []
        for aa in a:
            x.append(aa.apriori_config.external_data_dir)

        return x

    ctx = FlyteContextManager.current_context()
    cb = (
        ctx.new_builder()
        .with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
        .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
    )
    with FlyteContextManager.with_context(cb) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, type_hints={"a": List[MyInput]}
        )
        dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
        assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
        assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"
Example #9
def test_additional_context():
    ctx = FlyteContext.current_context()
    with FlyteContextManager.with_context(
            ctx.with_execution_state(ctx.new_execution_state().with_params(
                mode=ExecutionState.Mode.TASK_EXECUTION,
                additional_context={
                    1: "outer",
                    2: "foo"
                }))) as exec_ctx_outer:
        with FlyteContextManager.with_context(
                ctx.with_execution_state(
                    exec_ctx_outer.execution_state.with_params(
                        mode=ExecutionState.Mode.TASK_EXECUTION,
                        additional_context={
                            1: "inner",
                            3: "baz"
                        }))) as exec_ctx_inner:
            assert exec_ctx_inner.execution_state.additional_context == {
                1: "inner",
                2: "foo",
                3: "baz"
            }
Example #10
    def __call__(self, *args, **kwargs):
        # When a Task is called (i.e. __call__ is invoked), there are three things we may do:
        #  a. Task Execution Mode - just run the Python function as Python normally would. Flyte steps completely
        #     out of the way.
        #  b. Compilation Mode - this happens when the function is called as part of a workflow (potentially
        #     dynamic task?). Instead of running the user function, produce promise objects and create a node.
        #  c. Workflow Execution Mode - when a workflow is being run locally. Even though workflows are functions
        #     and everything should be able to be passed through naturally, we'll want to wrap output values of the
        #     function into objects, so that potential .with_cpu or other ancillary functions can be attached to do
        #     nothing. Subsequent tasks will have to know how to unwrap these. If by chance a non-Flyte task uses a
        #     task output as an input, things probably will fail pretty obviously.
        if len(args) > 0:
            raise _user_exceptions.FlyteAssertion(
                f"When calling tasks, only keyword args are supported. "
                f"Aborting execution as detected {len(args)} positional args {args}"
            )

        ctx = FlyteContextManager.current_context()
        if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:  # magic number: 1 is the standard compilation mode
            return self.compile(ctx, *args, **kwargs)
        elif (ctx.execution_state is not None and ctx.execution_state.mode
              == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION):
            if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
                if self.python_interface and self.python_interface.output_tuple_name:
                    variables = [
                        k for k in self.python_interface.outputs.keys()
                    ]
                    output_tuple = collections.namedtuple(
                        self.python_interface.output_tuple_name, variables)
                    nones = [
                        None for _ in self.python_interface.outputs.keys()
                    ]
                    return output_tuple(*nones)
                else:
                    # Should we return multiple None's here?
                    return None
            return self._local_execute(ctx, **kwargs)
        else:
            logger.warning("task run without context - executing raw function")
            new_user_params = self.pre_execute(ctx.user_space_params)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(
                        ctx.execution_state.with_params(
                            mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION,
                            user_space_params=new_user_params))):
                return self.execute(**kwargs)
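The final branch above is what lets a task called outside any Flyte context behave like a plain function. A minimal illustration; add is a hypothetical task:

from flytekit import task


@task
def add(a: int, b: int) -> int:
    return a + b


assert add(a=1, b=2) == 3  # no active context: the raw function runs
# add(1, 2) would raise FlyteAssertion, per the positional-argument check above.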
Example #11
    def add_entity(self, entity: Union[PythonTask, LaunchPlan, WorkflowBase],
                   **kwargs) -> Node:
        """
        Anytime you add an entity, all the inputs to the entity must be bound.
        """
        # circular import
        from flytekit.core.node_creation import create_node

        ctx = FlyteContext.current_context()
        if ctx.compilation_state is not None:
            raise Exception("Can't already be compiling")
        with FlyteContextManager.with_context(
                ctx.with_compilation_state(self.compilation_state)) as ctx:
            n = create_node(entity=entity, **kwargs)

            def get_input_values(input_value):
                if isinstance(input_value, list):
                    input_promises = []
                    for x in input_value:
                        input_promises.extend(get_input_values(x))
                    return input_promises
                if isinstance(input_value, dict):
                    input_promises = []
                    for _, v in input_value.items():
                        input_promises.extend(get_input_values(v))
                    return input_promises
                else:
                    return [input_value]

            # Every time an entity is added, mark it as used. The above function though will gather all the input
            # values but we're only interested in the ones that are Promises so let's filter for those.
            # There's probably a way to clean this up, maybe key off of the name instead of value?
            all_input_values = get_input_values(kwargs)
            for input_value in filter(lambda x: isinstance(x, Promise),
                                      all_input_values):
                if input_value in self._unbound_inputs:
                    self._unbound_inputs.remove(input_value)
            return n
Example #12
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        This function is largely similar to the base PythonTask, with the exception that we have to infer the Python
        interface before executing. Also, we refer to ``self.task_template`` rather than just ``self``, unlike task
        classes that derive from the base ``PythonTask``.
        """
        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with FlyteContextManager.with_context(
                ctx.with_execution_state(
                    ctx.execution_state.with_params(
                        user_space_params=new_user_params))) as exec_ctx:
            # Added: Have to reverse the Python interface from the task template Flyte interface
            # See docstring for more details.
            guessed_python_input_types = TypeEngine.guess_python_types(
                self.task_template.interface.inputs)
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, guessed_python_input_types)

            logger.info(
                f"Invoking FlyteTask executor {self.task_template.id.name} with inputs: {native_inputs}"
            )
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing {e}")
                raise e

            logger.debug("Task executed successfully in user level")
            # Let's run the post_execute method. This may result in an IgnoreOutputs exception, which is
            # bubbled up to be handled at the caller layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs,
                          _literal_models.LiteralMap) or isinstance(
                              native_outputs, _dynamic_job.DynamicJobSpec):
                return native_outputs

            expected_output_names = list(
                self.task_template.interface.outputs.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                # Deleted some stuff
                native_outputs_as_map = {
                    expected_output_names[0]: native_outputs
                }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                native_outputs_as_map = {
                    expected_output_names[i]: native_outputs[i]
                    for i, _ in enumerate(native_outputs)
                }

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self.task_template.interface.outputs[k].type
                py_type = type(v)

                if isinstance(v, tuple):
                    raise AssertionError(
                        f"Output({k}) in task{self.task_template.id.name} received a tuple {v}, instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    raise AssertionError(
                        f"failed to convert return value for var {k}") from e

            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
            # After the execute has been successfully completed
            return outputs_literal_map
Example #13
    def compile_into_workflow(
        self, ctx: FlyteContext, task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        if not ctx.compilation_state:
            cs = ctx.new_compilation_state("dynamic")
        else:
            cs = ctx.compilation_state.with_params(prefix="dynamic")

        with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
            # TODO: Resolve circular import
            from flytekit.common.translator import get_serializable

            workflow_metadata = WorkflowMetadata(
                on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
            defaults = WorkflowMetadataDefaults(
                interruptible=self.metadata.interruptible if self.metadata.
                interruptible is not None else False)

            self._wf = PythonFunctionWorkflow(task_function,
                                              metadata=workflow_metadata,
                                              default_metadata=defaults)
            self._wf.compile(**kwargs)

            wf = self._wf
            model_entities = OrderedDict()
            # See comment on reference entity checking a bit down below in this function.
            # This is the only circular dependency between the translator.py module and the rest of the flytekit
            # authoring experience.
            workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable(
                model_entities, ctx.serialization_settings, wf)

            # If no nodes were produced, let's just return the strict outputs
            if len(workflow_spec.template.nodes) == 0:
                return _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in workflow_spec.template.outputs
                    })

            # This is not great. The translator.py module is relied on here (see comment above) to get the tasks and
            # subworkflow definitions. However we want to ensure that reference tasks and reference sub workflows are
            # not used.
            # TODO: Replace None with a class.
            for value in model_entities.values():
                if value is None:
                    raise Exception(
                        "Reference tasks are not allowed in the dynamic - a network call is necessary "
                        "in order to retrieve the structure of the reference task."
                    )

            # Gather underlying TaskTemplates that get referenced. Launch plans are handled by propeller. Subworkflows
            # should already be in the workflow spec.
            tts = [
                v.template for v in model_entities.values()
                if isinstance(v, task_models.TaskSpec)
            ]

            if ctx.serialization_settings.should_fast_serialize():
                if (not ctx.execution_state
                        or not ctx.execution_state.additional_context
                        or not ctx.execution_state.additional_context.get(
                            "dynamic_addl_distro")):
                    raise AssertionError(
                        "Compilation for a dynamic workflow called in fast execution mode but no additional code "
                        "distribution could be retrieved")
                logger.warning(
                    f"ctx.execution_state.additional_context {ctx.execution_state.additional_context}"
                )
                for task_template in tts:
                    sanitized_args = []
                    for arg in task_template.container.args:
                        if arg == "{{ .remote_package_path }}":
                            sanitized_args.append(
                                ctx.execution_state.additional_context.get(
                                    "dynamic_addl_distro"))
                        elif arg == "{{ .dest_dir }}":
                            sanitized_args.append(
                                ctx.execution_state.additional_context.get(
                                    "dynamic_dest_dir", "."))
                        else:
                            sanitized_args.append(arg)
                    del task_template.container.args[:]
                    task_template.container.args.extend(sanitized_args)

            dj_spec = _dynamic_job.DynamicJobSpec(
                min_successes=len(workflow_spec.template.nodes),
                tasks=tts,
                nodes=workflow_spec.template.nodes,
                outputs=workflow_spec.template.outputs,
                subworkflows=workflow_spec.sub_workflows,
            )

            return dj_spec
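The fast-execution branch rewrites container args in place. A pure-Python sketch of just the placeholder substitution; the arg values below are hypothetical:

args = [
    "pyflyte-fast-execute",
    "--additional-distribution",
    "{{ .remote_package_path }}",
    "--dest-dir",
    "{{ .dest_dir }}",
]
additional_context = {"dynamic_addl_distro": "s3://bucket/distro.tar.gz"}

sanitized_args = []
for arg in args:
    if arg == "{{ .remote_package_path }}":
        sanitized_args.append(additional_context.get("dynamic_addl_distro"))
    elif arg == "{{ .dest_dir }}":
        # Falls back to "." when no destination directory was provided.
        sanitized_args.append(additional_context.get("dynamic_dest_dir", "."))
    else:
        sanitized_args.append(arg)

assert sanitized_args[2] == "s3://bucket/distro.tar.gz"
assert sanitized_args[4] == "."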
Example #14
    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
        """
        This method translates Flyte's type-system-based input values into Python natives and invokes the actual
        executor. This method is also invoked at runtime.

        * ``VoidPromise`` is returned when the task itself declares no outputs.
        * ``LiteralMap`` is returned when the task declares one or more outputs. Individual outputs may be None.
        * ``DynamicJobSpec`` is returned when a dynamic workflow is executed.
        """

        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with FlyteContextManager.with_context(
                ctx.with_execution_state(
                    ctx.execution_state.with_params(
                        user_space_params=new_user_params))) as exec_ctx:
            # TODO We could support default values here too - but not part of the plan right now
            # Translate the input literals to Python native
            native_inputs = TypeEngine.literal_map_to_kwargs(
                exec_ctx, input_literal_map, self.python_interface.inputs)

            # TODO: Logger should auto inject the current context information to indicate if the task is running within
            #   a workflow or a subworkflow etc
            logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
            try:
                native_outputs = self.execute(**native_inputs)
            except Exception as e:
                logger.exception(f"Exception when executing {e}")
                raise e

            logger.info(
                f"Task executed successfully in user level, outputs: {native_outputs}"
            )
            # Let's run the post_execute method. This may result in an IgnoreOutputs exception, which is
            # bubbled up to be handled at the caller layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            if isinstance(native_outputs,
                          _literal_models.LiteralMap) or isinstance(
                              native_outputs, _dynamic_job.DynamicJobSpec):
                return native_outputs

            expected_output_names = list(self._outputs_interface.keys())
            if len(expected_output_names) == 1:
                # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
                # length one. That convention is used for naming outputs - and single-length-NamedTuples are
                # particularly troublesome but elegant handling of them is not a high priority
                # Again, we're using the output_tuple_name as a proxy.
                if self.python_interface.output_tuple_name and isinstance(
                        native_outputs, tuple):
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs[0]
                    }
                else:
                    native_outputs_as_map = {
                        expected_output_names[0]: native_outputs
                    }
            elif len(expected_output_names) == 0:
                native_outputs_as_map = {}
            else:
                native_outputs_as_map = {
                    expected_output_names[i]: native_outputs[i]
                    for i, _ in enumerate(native_outputs)
                }

            # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
            # built into the IDL that all the values of a literal map are of the same type.
            literals = {}
            for k, v in native_outputs_as_map.items():
                literal_type = self._outputs_interface[k].type
                py_type = self.get_type_for_output_var(k, v)

                if isinstance(v, tuple):
                    raise AssertionError(
                        f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}"
                    )
                try:
                    literals[k] = TypeEngine.to_literal(
                        exec_ctx, v, py_type, literal_type)
                except Exception as e:
                    raise AssertionError(
                        f"failed to convert return value for var {k}") from e

            outputs_literal_map = _literal_models.LiteralMap(literals=literals)
            # After the execute has been successfully completed
            return outputs_literal_map
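The conversions at the heart of dispatch_execute can be exercised on their own; a small round-trip sketch using TypeEngine's public API:

from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine

ctx = FlyteContextManager.current_context()
lt = TypeEngine.to_literal_type(int)           # Python type -> Flyte LiteralType
lit = TypeEngine.to_literal(ctx, 42, int, lt)  # native value -> Literal
assert TypeEngine.to_python_value(ctx, lit, int) == 42  # and back again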
Example #15
    def compile(self, **kwargs):
        """
        Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
        a 'closure' in the traditional sense of the word.
        """
        ctx = FlyteContextManager.current_context()
        self._input_parameters = transform_inputs_to_parameters(
            ctx, self.python_interface)
        all_nodes = []
        prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else ""

        with FlyteContextManager.with_context(
                ctx.with_compilation_state(
                    CompilationState(prefix=prefix,
                                     task_resolver=self))) as comp_ctx:
            # Construct the default input promise bindings, but then override with the provided inputs, if any
            input_kwargs = construct_input_promises(
                [k for k in self.interface.inputs.keys()])
            input_kwargs.update(kwargs)
            workflow_outputs = exception_scopes.user_entry_point(
                self._workflow_function)(**input_kwargs)
            all_nodes.extend(comp_ctx.compilation_state.nodes)

            # This little loop was added as part of the task resolver change. The task resolver interface itself is
            # more or less stateless (the future-proofing get_all_tasks function notwithstanding). However the
            # implementation of the TaskResolverMixin that this workflow class inherits from (ClassStorageTaskResolver)
            # does store state. This loop adds Tasks that are defined within the body of the workflow to the workflow
            # object itself.
            for n in comp_ctx.compilation_state.nodes:
                if isinstance(n.flyte_entity, PythonAutoContainerTask
                              ) and n.flyte_entity.task_resolver == self:
                    logger.debug(
                        f"WF {self.name} saving task {n.flyte_entity.name}")
                    self.add(n.flyte_entity)

        # Iterate through the workflow outputs
        bindings = []
        output_names = list(self.interface.outputs.keys())
        # The reason the length 1 case is separate is because the one output might be a list. We don't want to
        # iterate through the list here, instead we should let the binding creation unwrap it and make a binding
        # collection/map out of it.
        if len(output_names) == 1:
            if isinstance(workflow_outputs, tuple):
                if len(workflow_outputs) != 1:
                    raise AssertionError(
                        f"The Workflow specification indicates only one return value, received {len(workflow_outputs)}"
                    )
                if self.python_interface.output_tuple_name is None:
                    raise AssertionError(
                        "Outputs specification for Workflow does not define a tuple, but return value is a tuple"
                    )
                workflow_outputs = workflow_outputs[0]
            t = self.python_interface.outputs[output_names[0]]
            b = binding_from_python_std(
                ctx,
                output_names[0],
                self.interface.outputs[output_names[0]].type,
                workflow_outputs,
                t,
            )
            bindings.append(b)
        elif len(output_names) > 1:
            if not isinstance(workflow_outputs, tuple):
                raise AssertionError(
                    "The Workflow specification indicates multiple return values, received only one"
                )
            if len(output_names) != len(workflow_outputs):
                raise Exception(
                    f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}"
                )
            for i, out in enumerate(output_names):
                if isinstance(workflow_outputs[i], ConditionalSection):
                    raise AssertionError(
                        "A Conditional block (if-else) should always end with an `else_()` clause"
                    )
                t = self.python_interface.outputs[out]
                b = binding_from_python_std(
                    ctx,
                    out,
                    self.interface.outputs[out].type,
                    workflow_outputs[i],
                    t,
                )
                bindings.append(b)

        # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
        self._nodes = all_nodes
        self._output_bindings = bindings
        if not output_names:
            return None
        if len(output_names) == 1:
            return bindings[0]
        return tuple(bindings)
Example #16
    def compile_into_workflow(
        self, ctx: FlyteContext, is_fast_execution: bool,
        task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        if not ctx.compilation_state:
            cs = ctx.new_compilation_state("dynamic")
        else:
            cs = ctx.compilation_state.with_params(prefix="dynamic")

        with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
            # TODO: Resolve circular import
            from flytekit.common.translator import get_serializable

            workflow_metadata = WorkflowMetadata(
                on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
            defaults = WorkflowMetadataDefaults(interruptible=False)

            self._wf = PythonFunctionWorkflow(task_function,
                                              metadata=workflow_metadata,
                                              default_metadata=defaults)
            self._wf.compile(**kwargs)

            wf = self._wf
            sdk_workflow = get_serializable(OrderedDict(),
                                            ctx.serialization_settings, wf,
                                            is_fast_execution)

            # If no nodes were produced, let's just return the strict outputs
            if len(sdk_workflow.nodes) == 0:
                return _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in sdk_workflow._outputs
                    })

            # Gather underlying tasks/workflows that get referenced. Launch plans are handled by propeller.
            tasks = set()
            sub_workflows = set()
            for n in sdk_workflow.nodes:
                self.aggregate(tasks, sub_workflows, n)

            if is_fast_execution:
                if (not ctx.execution_state
                        or not ctx.execution_state.additional_context
                        or not ctx.execution_state.additional_context.get(
                            "dynamic_addl_distro")):
                    raise AssertionError(
                        "Compilation for a dynamic workflow called in fast execution mode but no additional code "
                        "distribution could be retrieved")
                logger.warning(
                    f"ctx.execution_state.additional_context {ctx.execution_state.additional_context}"
                )
                sanitized_tasks = set()
                for task in tasks:
                    sanitized_args = []
                    for arg in task.container.args:
                        if arg == "{{ .remote_package_path }}":
                            sanitized_args.append(
                                ctx.execution_state.additional_context.get(
                                    "dynamic_addl_distro"))
                        elif arg == "{{ .dest_dir }}":
                            sanitized_args.append(
                                ctx.execution_state.additional_context.get(
                                    "dynamic_dest_dir", "."))
                        else:
                            sanitized_args.append(arg)
                    del task.container.args[:]
                    task.container.args.extend(sanitized_args)
                    sanitized_tasks.add(task)

                tasks = sanitized_tasks

            dj_spec = _dynamic_job.DynamicJobSpec(
                min_successes=len(sdk_workflow.nodes),
                tasks=list(tasks),
                nodes=sdk_workflow.nodes,
                outputs=sdk_workflow._outputs,
                subworkflows=list(sub_workflows),
            )

            return dj_spec
Example #17
    def compile_into_workflow(
        self, ctx: FlyteContext, task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        """
        In the case of dynamic workflows, this function will produce a workflow definition at execution time which will
        then proceed to be executed.
        """
        # TODO: circular import
        from flytekit.core.task import ReferenceTask

        if not ctx.compilation_state:
            cs = ctx.new_compilation_state(prefix="d")
        else:
            cs = ctx.compilation_state.with_params(prefix="d")

        with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
            # TODO: Resolve circular import
            from flytekit.tools.translator import get_serializable

            workflow_metadata = WorkflowMetadata(
                on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
            defaults = WorkflowMetadataDefaults(
                interruptible=self.metadata.interruptible if self.metadata.
                interruptible is not None else False)

            self._wf = PythonFunctionWorkflow(task_function,
                                              metadata=workflow_metadata,
                                              default_metadata=defaults)
            self._wf.compile(**kwargs)

            wf = self._wf
            model_entities = OrderedDict()
            # See comment on reference entity checking a bit down below in this function.
            # This is the only circular dependency between the translator.py module and the rest of the flytekit
            # authoring experience.
            workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable(
                model_entities, ctx.serialization_settings, wf)

            # If no nodes were produced, let's just return the strict outputs
            if len(workflow_spec.template.nodes) == 0:
                return _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in workflow_spec.template.outputs
                    })

            # Gather underlying TaskTemplates that get referenced.
            tts = []
            for entity, model in model_entities.items():
                # We only care about gathering tasks here. Launch plans are handled by
                # propeller. Subworkflows should already be in the workflow spec.
                if not isinstance(entity, Task) and not isinstance(
                        entity, task_models.TaskTemplate):
                    continue

                # Handle FlyteTask
                if isinstance(entity, task_models.TaskTemplate):
                    tts.append(entity)
                    continue

                # We are currently not supporting reference tasks since these will
                # require a network call to flyteadmin to populate the TaskTemplate
                # model
                if isinstance(entity, ReferenceTask):
                    raise Exception(
                        "Reference tasks are currently unsupported within dynamic tasks"
                    )

                if not isinstance(model, task_models.TaskSpec):
                    raise TypeError(
                        f"Unexpected type for serialized form of task. Expected {task_models.TaskSpec}, but got {type(model)}"
                    )

                # Store the valid task template so that we can pass it to the
                # DynamicJobSpec later
                tts.append(model.template)

            dj_spec = _dynamic_job.DynamicJobSpec(
                min_successes=len(workflow_spec.template.nodes),
                tasks=tts,
                nodes=workflow_spec.template.nodes,
                outputs=workflow_spec.template.outputs,
                subworkflows=workflow_spec.sub_workflows,
            )

            return dj_spec
Example #18
    def __call__(self, *args, **kwargs):
        """
        The call pattern for Workflows is close to, but not exactly, the call pattern for Tasks. For local execution,
        it goes

        __call__ -> _local_execute -> execute

        From execute, different things happen for the two Workflow styles. For PythonFunctionWorkflows, the Python
        function is run, for the ImperativeWorkflow, each node is run one at a time.
        """
        if len(args) > 0:
            raise AssertionError("Only Keyword Arguments are supported for Workflow executions")

        ctx = FlyteContextManager.current_context()

        # Get default arguments and override with kwargs passed in
        input_kwargs = self.python_interface.default_inputs_as_kwargs
        input_kwargs.update(kwargs)

        # The first condition is compilation.
        if ctx.compilation_state is not None:
            return create_and_link_node(ctx, entity=self, interface=self.python_interface, **input_kwargs)

        # This condition is hit when this workflow (self) is being called as part of a parent workflow's local run.
        # The context specifying the local workflow execution has already been set.
        elif (
            ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
        ):
            if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
                if self.python_interface and self.python_interface.output_tuple_name:
                    variables = [k for k in self.python_interface.outputs.keys()]
                    output_tuple = collections.namedtuple(self.python_interface.output_tuple_name, variables)
                    nones = [None for _ in self.python_interface.outputs.keys()]
                    return output_tuple(*nones)
                else:
                    return None
            # We are already in a local execution, just continue the execution context
            return self._local_execute(ctx, **input_kwargs)

        # Last is starting a local workflow execution
        else:
            # Run some sanity checks
            # Even though the _local_execute call generally expects inputs to be Promises, we don't have to do the
            # conversion here in this loop. The reason is because we don't prevent users from specifying inputs
            # as direct scalars, which means there's another Promise-generating loop inside _local_execute too
            for k, v in input_kwargs.items():
                if k not in self.interface.inputs:
                    raise ValueError(f"Received unexpected keyword argument {k}")
                if isinstance(v, Promise):
                    raise ValueError(f"Received a promise for a workflow call, when expecting a native value for {k}")

            result = None
            with FlyteContextManager.with_context(
                ctx.with_execution_state(
                    ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION)
                )
            ) as child_ctx:
                result = self._local_execute(child_ctx, **input_kwargs)

            expected_outputs = len(self.python_interface.outputs)
            if expected_outputs == 0:
                if result is None or isinstance(result, VoidPromise):
                    return None
                else:
                    raise Exception(f"Workflow local execution expected 0 outputs but something received {result}")

            if (1 < expected_outputs == len(result)) or (result is not None and expected_outputs == 1):
                if isinstance(result, Promise):
                    v = [v for k, v in self.python_interface.outputs.items()][0]  # get output native type
                    return TypeEngine.to_python_value(ctx, result.val, v)
                else:
                    for prom in result:
                        if not isinstance(prom, Promise):
                            raise Exception("should be promises")
                        native_list = [
                            TypeEngine.to_python_value(ctx, promise.val, self.python_interface.outputs[promise.var])
                            for promise in result
                        ]
                        return tuple(native_list)

            raise ValueError("expected outputs and actual outputs do not match")
Example #19
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """

    :param raw_output_data_prefix:
    :param checkpoint_path:
    :param prev_checkpoint:
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present, indicates that if a dynamic
      task were to run, it should set fast serialize to true and use these values in FastSerializationSettings
    :param dynamic_dest_dir: See above.
    :return:
    """
    exe_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    exe_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    exe_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    exe_wf = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    exe_lp = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")

    tk_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    tk_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    tk_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    tk_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    compressed_serialization_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()
    # Create directories
    user_workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=exe_project,
            domain=exe_domain,
            name=exe_name,
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            cfg=StatsConfig.auto(),
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
            tags={
                "exec_project": exe_project,
                "exec_domain": exe_domain,
                "exec_workflow": exe_wf,
                "exec_launchplan": exe_lp,
                "api_version": _api_version,
            },
        ),
        logging=user_space_logger,
        tmp_dir=user_workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    es = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    cb = ctx.new_builder().with_file_access(file_access).with_execution_state(es)

    if compressed_serialization_settings:
        ss = SerializationSettings.from_transport(compressed_serialization_settings)
        ssb = ss.new_builder()
        ssb.project = exe_project
        ssb.domain = exe_domain
        ssb.version = tk_version
        if dynamic_addl_distro:
            ssb.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        cb = cb.with_serialization_settings(ssb.build())

    with FlyteContextManager.with_context(cb) as ctx:
        yield ctx
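Since the function ends in a yield, it is presumably consumed as a context manager (in flytekit this style of setup function is wrapped with contextlib.contextmanager). A hedged sketch of the wiring, which assumes the FLYTE_INTERNAL_* environment variables the function reads have already been set by the task pod:

import contextlib

# The raw output prefix here is a hypothetical local path.
with contextlib.contextmanager(setup_execution)("/tmp/raw") as ctx:
    assert ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION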
Example #20
def setup_execution(
    raw_output_data_prefix: str,
    dynamic_addl_distro: str = None,
    dynamic_dest_dir: str = None,
):
    cloud_provider = _platform_config.CLOUD_PROVIDER.get()
    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
    _logging.getLogger().setLevel(log_level)

    ctx = FlyteContextManager.current_context()

    # Create directories
    user_workspace_dir = ctx.file_access.local_access.get_random_directory()
    _click.echo(f"Using user directory {user_workspace_dir}")
    pathlib.Path(user_workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    execution_parameters = ExecutionParameters(
        execution_id=_identifier.WorkflowExecutionIdentifier(
            project=_internal_config.EXECUTION_PROJECT.get(),
            domain=_internal_config.EXECUTION_DOMAIN.get(),
            name=_internal_config.EXECUTION_NAME.get(),
        ),
        execution_date=_datetime.datetime.utcnow(),
        stats=_get_stats(
            # Stats metric path will be:
            # registration_project.registration_domain.app.module.task_name.user_stats
            # and it will be tagged with execution-level values for project/domain/wf/lp
            "{}.{}.{}.user_stats".format(
                _internal_config.TASK_PROJECT.get()
                or _internal_config.PROJECT.get(),
                _internal_config.TASK_DOMAIN.get()
                or _internal_config.DOMAIN.get(),
                _internal_config.TASK_NAME.get()
                or _internal_config.NAME.get(),
            ),
            tags={
                "exec_project": _internal_config.EXECUTION_PROJECT.get(),
                "exec_domain": _internal_config.EXECUTION_DOMAIN.get(),
                "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(),
                "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(),
                "api_version": _api_version,
            },
        ),
        logging=_logging,
        tmp_dir=user_workspace_dir,
    )

    if cloud_provider == _constants.CloudProvider.AWS:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.GCP:
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(),
            remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix),
        )
    elif cloud_provider == _constants.CloudProvider.LOCAL:
        # A fake remote using the local disk will automatically be created
        file_access = _data_proxy.FileAccessProvider(
            local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get())
    else:
        raise Exception(f"Bad cloud provider {cloud_provider}")

    with FlyteContextManager.with_context(
            ctx.with_file_access(file_access)) as ctx:
        # TODO: This is copied from serialize, which means there's a similarity here I'm not seeing.
        env = {
            _internal_config.CONFIGURATION_PATH.env_var:
            _internal_config.CONFIGURATION_PATH.get(),
            _internal_config.IMAGE.env_var:
            _internal_config.IMAGE.get(),
        }

        serialization_settings = SerializationSettings(
            project=_internal_config.TASK_PROJECT.get(),
            domain=_internal_config.TASK_DOMAIN.get(),
            version=_internal_config.TASK_VERSION.get(),
            image_config=get_image_config(),
            env=env,
        )

        # The reason we need this is because of dynamic tasks. Even if we move compilation all to Admin,
        # if a dynamic task calls some task, t1, we have to write to the DJ Spec the correct task
        # identifier for t1.
        with FlyteContextManager.with_context(
                ctx.with_serialization_settings(
                    serialization_settings)) as ctx:
            # Because execution states do not look up the context chain, it has to be made last
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(
                        ctx.new_execution_state().with_params(
                            mode=ExecutionState.Mode.TASK_EXECUTION,
                            user_space_params=execution_parameters,
                            additional_context={
                                "dynamic_addl_distro": dynamic_addl_distro,
                                "dynamic_dest_dir": dynamic_dest_dir,
                            },
                        ))) as ctx:
                yield ctx