コード例 #1
0
def get_entities_in_file(filename: str) -> Entities:
    """
    Return the names of the Flyte workflows and Flyte tasks defined in a file.

    The file is imported as a module whose dotted name is derived from its path
    relative to the current working directory (``pkg/sub/mod.py`` becomes
    ``pkg.sub.mod``). The import runs inside a new Flyte context built from the
    current one and inside ``module_loader.add_sys_path(os.getcwd())``.

    :param filename: Path to the Python file to inspect.
    :return: ``Entities(workflows, tasks)``. Each list holds the function names
        reported by ``tracker.extract_task_module`` for every module-level
        ``PythonFunctionWorkflow`` / ``PythonTask`` object. The names appear in
        sorted order of the attribute names those objects are bound to.
    """
    flyte_ctx = context_manager.FlyteContextManager.current_context().new_builder()
    module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".")
    with context_manager.FlyteContextManager.with_context(flyte_ctx):
        with module_loader.add_sys_path(os.getcwd()):
            # Keep the module object from this import rather than importing it a
            # second time outside of the context / sys.path setup.
            module = importlib.import_module(module_name)

    workflows = []
    tasks = []
    # dir(module) is the sorted list of namespace keys for an ordinary module.
    # Reading the namespace directly keeps that order. It also avoids a KeyError
    # if a module-level __dir__ reports names that are not in __dict__.
    namespace = vars(module)
    for attr_name in sorted(namespace):
        obj = namespace[attr_name]
        if isinstance(obj, PythonFunctionWorkflow):
            _, _, fn, _ = tracker.extract_task_module(obj)
            workflows.append(fn)
        elif isinstance(obj, PythonTask):
            _, _, fn, _ = tracker.extract_task_module(obj)
            tasks.append(fn)

    return Entities(workflows, tasks)
コード例 #2
0
    def loader_args(self, settings: SerializationSettings,
                    task: PythonAutoContainerTask) -> List[str]:
        """
        Build the loader arguments that identify ``task`` by module and name.

        :param settings: Unused here.
        :param task: The task to locate. For a ``PythonFunctionTask`` the wrapped
            ``task_function`` is inspected; otherwise the task itself must be a
            ``TrackedInstance``.
        :return: ``["task-module", <module>, "task-name", <name>]`` as reported by
            ``extract_task_module``.
        :raises TypeError: If ``task`` is neither a ``PythonFunctionTask`` nor a
            ``TrackedInstance``, so no module/name can be determined.
        """
        # Local import; presumably avoids a circular import with python_function_task.
        from flytekit.core.python_function_task import PythonFunctionTask

        if isinstance(task, PythonFunctionTask):
            target = task.task_function
        elif isinstance(task, TrackedInstance):
            target = task
        else:
            # Fail loudly instead of implicitly returning None from a List[str] method.
            raise TypeError(
                f"Cannot compute loader args for task of type {type(task).__name__}: "
                "expected a PythonFunctionTask or a TrackedInstance"
            )
        _, m, t, _ = extract_task_module(target)
        return ["task-module", m, "task-name", t]
コード例 #3
0
def test_extract_task_module(test_input, expected):
    """
    Check ``extract_task_module`` with ``FLYTE_PYTHON_PACKAGE_ROOT`` set to "auto".

    The previous flag value is restored whether the assertion passes or fails, so
    the override cannot leak into tests that run afterwards.
    """
    old = FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT
    FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT = "auto"
    try:
        # The last element is the full path of a local file, which is not stable across users / runs.
        assert extract_task_module(test_input)[:-1] == expected
    finally:
        FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT = old
コード例 #4
0
    def __init__(
        self,
        task_config: T,
        task_function: Callable,
        task_type="python-task",
        ignore_input_vars: Optional[List[str]] = None,
        execution_mode: Optional[ExecutionBehavior] = ExecutionBehavior.DEFAULT,
        task_resolver: Optional[TaskResolverMixin] = None,
        **kwargs,
    ):
        """
        :param T task_config: Configuration object for Task. Should be a unique type for that specific Task
        :param Callable task_function: Python function that has type annotations and works for the task
        :param str task_type: String task type to be associated with this Task
        :param Optional[List[str]] ignore_input_vars: When supplied, these input variables will be removed from the
            interface. This can be used to inject some client side variables only. Prefer using ExecutionParams
        :param Optional[ExecutionBehavior] execution_mode: Defines how the execution should behave, for example
            executing normally or specially handling a dynamic case.
        :param Optional[TaskResolverMixin] task_resolver: Resolver passed through to the base class. If, after
            base-class initialization, ``self._task_resolver`` is ``default_task_resolver`` (presumably the case
            when ``None`` is passed; confirm in the base class), ``task_function`` must not be a nested/local
            function. Test functions and functools-wrapped module-level functions are exempt.
        :param kwargs: Forwarded unchanged to the base class ``__init__``.
        :raises ValueError: If ``task_function`` is None, or if the default task resolver is in use and
            ``task_function`` is a nested/inner or local function.
        """
        if task_function is None:
            raise ValueError(
                "TaskFunction is a required parameter for PythonFunctionTask")
        self._native_interface = transform_function_to_interface(
            task_function, Docstring(callable_=task_function))
        # Inputs listed in ignore_input_vars are dropped from the interface registered with the base class.
        mutated_interface = self._native_interface.remove_inputs(
            ignore_input_vars)
        # Only the fully-qualified name (first element) is needed for the task name.
        name, _, _, _ = extract_task_module(task_function)
        super().__init__(
            task_type=task_type,
            name=name,
            interface=mutated_interface,
            task_config=task_config,
            task_resolver=task_resolver,
            **kwargs,
        )

        if self._task_resolver is default_task_resolver:
            # The default task resolver can't handle nested functions
            # TODO: Consider moving this to a can_handle function or something inside the resolver itself.
            if (not istestfunction(func=task_function)
                    and isnested(func=task_function)
                    and not is_functools_wrapped_module_level(task_function)):
                raise ValueError(
                    "TaskFunction cannot be a nested/inner or local function. "
                    "It should be accessible at a module level for Flyte to execute it. Test modules with "
                    "names beginning with `test_` are allowed to have nested tasks. "
                    "If you're decorating your task function with custom decorators, use functools.wraps "
                    "or functools.update_wrapper on the function wrapper. "
                    "Alternatively if you want to create your own tasks with custom behavior use the TaskResolverMixin"
                )
        self._task_function = task_function
        self._execution_mode = execution_mode
コード例 #5
0
def test_dont_use_wrapper_location():
    """
    The name recorded for ``get_data`` must point at ``decorator_usage``, where the
    task is looked up, and not at ``decorator_source`` (presumably the module that
    holds the wrapping decorator).
    """
    usage_module_name = "tests.flytekit.unit.core.flyte_functools.decorator_usage"
    usage_module = importlib.import_module(usage_module_name)
    task = usage_module.get_data

    assert "decorator_source" not in task.name
    assert "decorator_usage" in task.name

    full_name, module_name, function_name, _ = extract_task_module(task)
    expected = (
        "tests.flytekit.unit.core.flyte_functools.decorator_usage.get_data",
        usage_module_name,
        "get_data",
    )
    assert (full_name, module_name, function_name) == expected
コード例 #6
0
ファイル: script_mode.py プロジェクト: flyteorg/flytekit
def fast_register_single_script(
    wf_entity: WorkflowBase, create_upload_location_fn: typing.Callable
) -> typing.Tuple[_data_proxy_pb2.CreateUploadLocationResponse, bytes]:
    """
    Archive the sources around ``wf_entity``'s script and upload the archive.

    :param wf_entity: The workflow whose defining script locates the sources to package.
    :param create_upload_location_fn: Called as ``create_upload_location_fn(content_md5=md5)``.
        It must return an object with a ``signed_url`` attribute, which is the upload destination.
    :return: ``(upload_location, md5)``: the value returned by ``create_upload_location_fn``
        and the digest (first element of ``hash_file``) of the uploaded archive.
    """
    _, mod_name, _, script_full_path = extract_task_module(wf_entity)
    # Find project root by moving up the folder hierarchy until you cannot find a __init__.py file.
    source_path = _find_project_root(script_full_path)

    # Build the archive in a temporary directory. Hashing and uploading must happen
    # inside the ``with`` block, before the directory and the archive are deleted.
    with tempfile.TemporaryDirectory() as tmp_dir:
        archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz")
        compress_single_script(source_path, archive_fname, mod_name)

        flyte_ctx = context_manager.FlyteContextManager.current_context()
        md5, _ = hash_file(archive_fname)
        upload_location = create_upload_location_fn(content_md5=md5)
        flyte_ctx.file_access.put_data(archive_fname, upload_location.signed_url)

        return upload_location, md5
コード例 #7
0
ファイル: map_task.py プロジェクト: flyteorg/flytekit
    def __init__(
        self,
        python_function_task: PythonFunctionTask,
        concurrency: Optional[int] = None,
        min_success_ratio: Optional[float] = None,
        **kwargs,
    ):
        """
        :param python_function_task: This argument is implicitly passed and represents the repeatable function
        :param concurrency: If specified, this limits the number of mapped tasks that can run in parallel to the
            given batch size
        :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete
            successfully before terminating this task and marking it successful.
        :param kwargs: Forwarded to the base class ``__init__``. If ``metadata`` / ``security_ctx`` are not supplied,
            they default to the wrapped task's ``metadata`` / ``security_context`` when those are truthy.
        :raises ValueError: If the wrapped task declares more than one input or more than one output.
        """
        if len(python_function_task.python_interface.inputs.keys()) > 1:
            raise ValueError("Map tasks only accept python function tasks with 0 or 1 inputs")

        if len(python_function_task.python_interface.outputs.keys()) > 1:
            raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs")

        collection_interface = transform_interface_to_list_interface(python_function_task.python_interface)
        # self._ids is presumably a shared counter, so each mapper built from the same
        # function gets a distinct name suffix (confirm on the class).
        instance = next(self._ids)
        _, mod, f, _ = tracker.extract_task_module(python_function_task.task_function)
        name = f"{mod}.mapper_{f}_{instance}"

        self._cmd_prefix = None
        self._run_task = python_function_task
        self._max_concurrency = concurrency
        self._min_success_ratio = min_success_ratio
        self._array_task_interface = python_function_task.python_interface
        # Inherit settings from the wrapped task unless the caller overrides them.
        if "metadata" not in kwargs and python_function_task.metadata:
            kwargs["metadata"] = python_function_task.metadata
        if "security_ctx" not in kwargs and python_function_task.security_context:
            kwargs["security_ctx"] = python_function_task.security_context
        super().__init__(
            name=name,
            interface=collection_interface,
            task_type=SdkTaskType.CONTAINER_ARRAY_TASK,
            task_config=None,
            task_type_version=1,
            **kwargs,
        )
コード例 #8
0
    def __init__(
        self,
        workflow_function: Callable,
        metadata: Optional[WorkflowMetadata],
        default_metadata: Optional[WorkflowMetadataDefaults],
        docstring: Optional[Docstring] = None,
    ):
        """
        :param workflow_function: Python function defining the workflow. Its fully-qualified name (from
            ``extract_task_module``) becomes the workflow name. It is also passed to
            ``transform_function_to_interface`` to build the native interface.
        :param metadata: Passed to the base class as ``workflow_metadata``.
        :param default_metadata: Passed to the base class as ``workflow_metadata_defaults``.
        :param docstring: Optional docstring handed to ``transform_function_to_interface``.
        """
        # Only the fully-qualified name (first element) is needed here.
        name, _, _, _ = extract_task_module(workflow_function)
        self._workflow_function = workflow_function
        native_interface = transform_function_to_interface(workflow_function,
                                                           docstring=docstring)

        # TODO do we need this - can this not be in launchplan only?
        #    This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or
        #    we can re-evaluate.
        self._input_parameters = None
        super().__init__(
            name=name,
            workflow_metadata=metadata,
            workflow_metadata_defaults=default_metadata,
            python_interface=native_interface,
        )