Code example #1
 def end_branch(self) -> Union[Condition, Promise]:
     """
     This should be invoked after every branch has been visited
     """
     if self._last_case:
         FlyteContextManager.pop_context()
     return self._condition
Code example #2
File: types.py Project: fediazgon/flytekit
    def __init__(
        self,
        local_path: os.PathLike = None,
        remote_path: str = None,
        supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE,
        downloader: typing.Callable[[str, os.PathLike], None] = None,
    ):

        if supported_mode == SchemaOpenMode.READ and remote_path is None:
            raise ValueError(
                "To create a FlyteSchema in read mode, remote_path is required"
            )
        if (supported_mode == SchemaOpenMode.WRITE and local_path is None
                and FlyteContextManager.current_context().file_access is None):
            raise ValueError(
                "To create a FlyteSchema in write mode, local_path is required"
            )

        if local_path is None:
            local_path = FlyteContextManager.current_context().file_access.get_random_local_directory()
        self._local_path = local_path
        self._remote_path = remote_path
        self._supported_mode = supported_mode
        # This is a special attribute that indicates if the data was either downloaded or uploaded
        self._downloaded = False
        self._downloader = downloader
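A minimal usage sketch for the constructor above (the remote URI is a placeholder; per the checks in __init__, READ mode requires remote_path, while WRITE mode can fall back to a random local directory from the current FlyteContext):

from flytekit.types.schema import FlyteSchema, SchemaOpenMode

# READ mode: remote_path is mandatory, per the first ValueError above.
reader = FlyteSchema(
    remote_path="s3://my-bucket/schemas/demo",  # placeholder URI
    supported_mode=SchemaOpenMode.READ,
)

# WRITE mode (the default): local_path is generated on demand.
writer = FlyteSchema()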
Code example #3
def test_two(two_sample_inputs):
    my_input = two_sample_inputs[0]
    my_input_2 = two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        x = []
        for aa in a:
            x.append(aa.main_product)
        return x

    with FlyteContextManager.with_context(
        FlyteContextManager.current_context().with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        with FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx, d={"a": [my_input, my_input_2]}, type_hints={"a": List[MyInput]}
            )
            dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec.literals["o0"].collection.literals) == 2
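MyInput and the two_sample_inputs fixture are defined elsewhere in this test module; a plausible sketch, inferred from the fields used here and in code example #9 (the exact decorators may differ):

from dataclasses import dataclass

from dataclasses_json import dataclass_json
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile


@dataclass_json
@dataclass
class MyProxyConfiguration:
    splat_data_dir: str
    apriori_file: str


@dataclass_json
@dataclass
class MyProxyParameters:
    id: str
    job_i_step: int


@dataclass_json
@dataclass
class MyAprioriConfiguration:
    static_data_dir: FlyteDirectory
    external_data_dir: FlyteDirectory


@dataclass_json
@dataclass
class MyInput:
    main_product: FlyteFile
    apriori_config: MyAprioriConfiguration
    proxy_config: MyProxyConfiguration
    proxy_params: MyProxyParameters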
Code example #4
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        By the time this function is invoked, the local_execute function should have unwrapped the Promises and Flyte
        literal wrappers so that the kwargs we are working with here are now Python native literal values. This
        function is also expected to return Python native literal values.

        Since the user code within a dynamic task constitutes a workflow, we have to first compile the workflow, and
        then execute that workflow.

        When running for real in production, the task would stop after the compilation step, and then create a file
        representing that newly generated workflow, instead of executing it.
        """
        ctx = FlyteContextManager.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            updated_exec_state = ctx.execution_state.with_params(
                mode=ExecutionState.Mode.TASK_EXECUTION)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(updated_exec_state)):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return exception_scopes.user_entry_point(task_function)(
                    **kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            return self.compile_into_workflow(ctx, task_function, **kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
            return exception_scopes.user_entry_point(task_function)(**kwargs)

        raise ValueError(
            f"Invalid execution provided, execution state: {ctx.execution_state}"
        )
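For context, a minimal dynamic task that would flow through dynamic_execute (names here are illustrative): in LOCAL_WORKFLOW_EXECUTION mode the body simply runs with raw inputs, while in TASK_EXECUTION mode it is compiled into a workflow spec.

from typing import List

from flytekit import dynamic, task, workflow


@task
def double(x: int) -> int:
    return x * 2


@dynamic
def fan_out(xs: List[int]) -> List[int]:
    # Each double() call becomes a node of the compiled (sub)workflow.
    return [double(x=x) for x in xs]


@workflow
def wf(xs: List[int]) -> List[int]:
    return fan_out(xs=xs)

# Running wf(xs=[1, 2, 3]) locally returns [2, 4, 6] via the
# LOCAL_WORKFLOW_EXECUTION branch above.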
Code example #5
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        By the time this function is invoked, the _local_execute function should have unwrapped the Promises and Flyte
        literal wrappers so that the kwargs we are working with here are now Python native literal values. This
        function is also expected to return Python native literal values.

        Since the user code within a dynamic task constitutes a workflow, we have to first compile the workflow, and
        then execute that workflow.

        When running for real in production, the task would stop after the compilation step, and then create a file
        representing that newly generated workflow, instead of executing it.
        """
        ctx = FlyteContextManager.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            updated_exec_state = ctx.execution_state.with_params(
                mode=ExecutionState.Mode.TASK_EXECUTION)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(updated_exec_state)):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return task_function(**kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            is_fast_execution = bool(
                ctx.execution_state and ctx.execution_state.additional_context
                and ctx.execution_state.additional_context.get(
                    "dynamic_addl_distro"))
            if is_fast_execution:
                ctx = ctx.with_serialization_settings(
                    SerializationSettings.new_builder(
                    ).with_fast_serialization_settings(
                        FastSerializationSettings(enabled=True)).build())

            return self.compile_into_workflow(ctx, task_function, **kwargs)
Code example #6
def serialize(
    pkgs: typing.List[str],
    settings: SerializationSettings,
    local_source_root: typing.Optional[str] = None,
    options: typing.Optional[Options] = None,
) -> typing.List[RegistrableEntity]:
    """
    See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the
    entity type.
    :param options:
    :param settings: SerializationSettings to be used
    :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization.
    :param local_source_root: Where to start looking for the code.
    """

    ctx = FlyteContextManager.current_context().with_serialization_settings(settings)
    with FlyteContextManager.with_context(ctx) as ctx:
        # Scan all modules. The act of loading populates the global singleton that contains all objects.
        with module_loader.add_sys_path(local_source_root):
            click.secho(
                f"Loading packages {pkgs} under source root {local_source_root}",
                fg="yellow")
            module_loader.just_load_modules(pkgs=pkgs)

        registrable_entities = get_registrable_entities(ctx, options=options)
        click.secho(
            f"Successfully serialized {len(registrable_entities)} flyte objects",
            fg="green")
        return registrable_entities
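A hedged invocation sketch; the package name, registry, and image coordinates are placeholders:

from flytekit.configuration import Image, ImageConfig, SerializationSettings  # import location varies by flytekit version

settings = SerializationSettings(
    project="flytesnacks",  # placeholder
    domain="development",   # placeholder
    version="v1",
    image_config=ImageConfig(Image(name="default", fqn="my.registry/app", tag="v1")),
    env={},
)
entities = serialize(pkgs=["myapp.workflows"], settings=settings, local_source_root=".")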
Code example #7
    def __init__(
        self,
        local_path: os.PathLike = None,
        remote_path: os.PathLike = None,
        supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE,
        downloader: typing.Callable[[str, os.PathLike], None] = None,
    ):

        if supported_mode == SchemaOpenMode.READ and remote_path is None:
            raise ValueError(
                "To create a FlyteSchema in read mode, remote_path is required"
            )
        if (supported_mode == SchemaOpenMode.WRITE and local_path is None
                and FlyteContextManager.current_context().file_access is None):
            raise ValueError(
                "To create a FlyteSchema in write mode, local_path is required"
            )

        local_path = local_path or FlyteContextManager.current_context().file_access.get_random_local_directory()
        self._local_path = local_path
        # Make this field public, so that the dataclass transformer can set a value for it
        # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
        self.remote_path = remote_path or FlyteContextManager.current_context().file_access.get_random_remote_path()
        self._supported_mode = supported_mode
        # This is a special attribute that indicates if the data was either downloaded or uploaded
        self._downloaded = False
        self._downloader = downloader
Code example #8
 def __init__(self, name: str):
     self._name = name
     self._cases: typing.List[Case] = []
     self._last_case = False
     self._condition = Condition(self)
     ctx = FlyteContextManager.current_context()
     # A new conditional section has been started, so let's push the context
     FlyteContextManager.push_context(
         ctx.enter_conditional_section().build())
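This push is undone by end_branch() on the last case (code examples #1, #11, and #28). Together they back the user-facing conditional construct; a small sketch of typical use:

from flytekit import conditional, task, workflow


@task
def pos(x: int) -> int:
    return x


@task
def neg(x: int) -> int:
    return -x


@workflow
def wf(x: int) -> int:
    # conditional() pushes a conditional-section context (the __init__
    # above); the then() on the else_() branch is the last case, so
    # end_branch() pops it again.
    return (
        conditional("sign")
        .if_(x >= 0)
        .then(pos(x=x))
        .else_()
        .then(neg(x=x))
    )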
Code example #9
def test_dc_dyn_directory(folders_and_files_setup):
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    my_input_gcs = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/one"),
            external_data_dir=FlyteDirectory("gs://my-bucket/two"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    my_input_gcs_2 = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/three"),
            external_data_dir=FlyteDirectory("gs://my-bucket/four"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        x = []
        for aa in a:
            x.append(aa.apriori_config.external_data_dir)

        return x

    ctx = FlyteContextManager.current_context()
    cb = (
        ctx.new_builder()
        .with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
        .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
    )
    with FlyteContextManager.with_context(cb) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, type_hints={"a": List[MyInput]}
        )
        dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
        assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
        assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"
Code example #10
def test_levels():
    ctx = FlyteContextManager.current_context()
    b = ctx.new_builder()
    b.flyte_client = SampleTestClass(value=1)
    with FlyteContextManager.with_context(b) as outer:
        assert outer.flyte_client.value == 1
        b = outer.new_builder()
        b.flyte_client = SampleTestClass(value=2)
        with FlyteContextManager.with_context(b) as ctx:
            assert ctx.flyte_client.value == 2

        with FlyteContextManager.with_context(outer.with_new_compilation_state()) as ctx:
            assert ctx.flyte_client.value == 1
Code example #11
 def end_branch(
     self
 ) -> Optional[Union[Condition, Tuple[Promise], Promise, VoidPromise]]:
     """
     This should be invoked after every branch has been visited
     """
     if self._last_case:
         FlyteContextManager.pop_context()
         curr = self.compute_output_vars()
         if curr is None:
             return VoidPromise(self.name)
         promises = [Promise(var=x, val=None) for x in curr]
         return create_task_output(promises)
     return self._condition
Code example #12
def test_fill_in_literal_type():
    class TempEncoder(StructuredDatasetEncoder):
        def __init__(self, fmt: str):
            super().__init__(MyDF, "tmpfs://", supported_format=fmt)

        def encode(
            self,
            ctx: FlyteContext,
            structured_dataset: StructuredDataset,
            structured_dataset_type: StructuredDatasetType,
        ) -> literals.StructuredDataset:
            return literals.StructuredDataset(uri="")

    StructuredDatasetTransformerEngine.register(TempEncoder("myavro"),
                                                default_for_type=True)
    lt = TypeEngine.to_literal_type(MyDF)
    assert lt.structured_dataset_type.format == "myavro"

    ctx = FlyteContextManager.current_context()
    fdt = StructuredDatasetTransformerEngine()
    sd = StructuredDataset(dataframe=42)
    l = fdt.to_literal(ctx, sd, MyDF, lt)
    # Test that the literal type is filled in even though the encode function above doesn't do it.
    assert l.scalar.structured_dataset.metadata.structured_dataset_type.format == "myavro"

    # Test that looking up encoders/decoders falls back to the "" encoder/decoder
    empty_format_temp_encoder = TempEncoder("")
    StructuredDatasetTransformerEngine.register(empty_format_temp_encoder,
                                                default_for_type=False)

    res = StructuredDatasetTransformerEngine.get_encoder(
        MyDF, "tmpfs", "rando")
    assert res is empty_format_temp_encoder
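MyDF is not shown here; it is presumably a minimal stand-in dataframe type whose only job is to key the encoder registry, e.g.:

class MyDF:
    # Placeholder dataframe type; only its identity matters to the
    # StructuredDatasetTransformerEngine registry in this test.
    pass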
Code example #13
File: workflow.py Project: dylanwilder/flytekit
    def add_workflow_output(
        self, output_name: str, p: Union[Promise, List[Promise], Dict[str, Promise]], python_type: Optional[Type] = None
    ):
        """
        Add an output with the given name from the given node output.
        """
        if output_name in self._python_interface.outputs:
            raise FlyteValidationException(f"Output {output_name} already exists in workflow {self.name}")

        if python_type is None:
            if type(p) == list or type(p) == dict:
                raise FlyteValidationException(
                    f"If specifying a list or dict of Promises, you must specify the python_type for {output_name}"
                    f" starting with the container type (e.g. List[int])"
                )
            python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var]
            logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}")

        flyte_type = TypeEngine.to_literal_type(python_type=python_type)

        ctx = FlyteContext.current_context()
        if ctx.compilation_state is not None:
            raise Exception("Can't already be compiling")
        with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
            b = binding_from_python_std(
                ctx, output_name, expected_literal_type=flyte_type, t_value=p, t_value_type=python_type
            )
            self._output_bindings.append(b)
            self._python_interface = self._python_interface.with_outputs(extra_outputs={output_name: python_type})
            self._interface = transform_interface_to_typed_interface(self._python_interface)
Code example #14
File: workflow.py Project: dylanwilder/flytekit
    def add_entity(self, entity: PythonAutoContainerTask, **kwargs) -> Node:
        """
        Anytime you add an entity, all the inputs to the entity must be bound.
        """
        # circular import
        from flytekit.core.node_creation import create_node

        ctx = FlyteContext.current_context()
        if ctx.compilation_state is not None:
            raise Exception("Can't already be compiling")
        with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
            n = create_node(entity=entity, **kwargs)

            def get_input_values(input_value):
                if isinstance(input_value, list):
                    input_promises = []
                    for x in input_value:
                        input_promises.extend(get_input_values(x))
                    return input_promises
                if isinstance(input_value, dict):
                    input_promises = []
                    for _, v in input_value.items():
                        input_promises.extend(get_input_values(v))
                    return input_promises
                else:
                    return [input_value]

            # Every time an entity is added, mark it as used.
            for input_value in get_input_values(kwargs):
                if input_value in self._unbound_inputs:
                    self._unbound_inputs.remove(input_value)
            return n
Code example #15
def test_to_python_value_without_incoming_columns():
    # make a literal with a type with no columns
    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(pd.DataFrame)
    df = generate_pandas()
    fdt = StructuredDatasetTransformerEngine()
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 0

    # declare a new type that only has one column
    # get the dataframe, make sure it has the column that was asked for.
    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
    sd = fdt.to_python_value(ctx, lit, subset_sd_type)
    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 1

    # check when columns are not specified, should pull both and add column information.
    # todo: see the todos in the open_as, and iter_as functions in StructuredDatasetTransformerEngine
    #  we have to recreate the literal because the test case above filled in the metadata
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    sd = fdt.to_python_value(ctx, lit, StructuredDataset)
    assert sd.metadata.structured_dataset_type.columns == []
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 2

    # should also work if subset type is just an annotated pd.DataFrame
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
    sub_df = fdt.to_python_value(ctx, lit, subset_pd_type)
    assert sub_df.shape[1] == 1
Code example #16
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs to the execution in the standard python format as dictated by the type engine.
        """
        if self._inputs is None:
            client = _flyte_engine.get_client()
            execution_data = client.get_execution_data(self.id)

            # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
            input_map: _literal_models.LiteralMap = _literal_models.LiteralMap({})
            if bool(execution_data.full_inputs.literals):
                input_map = execution_data.full_inputs
            elif execution_data.inputs.bytes > 0:
                with _common_utils.AutoDeletingTempDir() as tmp_dir:
                    tmp_name = _os.path.join(tmp_dir.name, "inputs.pb")
                    _data_proxy.Data.get_data(execution_data.inputs.url,
                                              tmp_name)
                    input_map = _literal_models.LiteralMap.from_flyte_idl(
                        _common_utils.load_proto_from_file(
                            _literals_pb2.LiteralMap, tmp_name))
            lp_id = self.spec.launch_plan
            workflow = _workflow.FlyteWorkflow.fetch(lp_id.project,
                                                     lp_id.domain, lp_id.name,
                                                     lp_id.version)
            self._inputs = TypeEngine.literal_map_to_kwargs(
                ctx=FlyteContextManager.current_context(),
                lm=input_map,
                python_types=TypeEngine.guess_python_types(
                    workflow.interface.inputs),
            )
        return self._inputs
Code example #17
    def save(self, cp: typing.Union[Path, str, io.BufferedReader]):
        # We have to lazy load, until we fix the imports
        from flytekit.core.context_manager import FlyteContextManager

        fa = FlyteContextManager.current_context().file_access
        if isinstance(cp, (Path, str)):
            if isinstance(cp, str):
                cp = Path(cp)
            if cp.is_dir():
                fa.upload_directory(str(cp), self._checkpoint_dest)
            else:
                fname = cp.stem + cp.suffix
                rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, fname)
                fa.upload(str(cp), rpath)
            return

        if not isinstance(cp, io.IOBase):
            raise ValueError(f"Only a valid path or IOBase type (reader) should be provided, received {type(cp)}")

        p = Path(self._td.name)
        dest_cp = p.joinpath(self.TMP_DST_PATH)
        with dest_cp.open("wb") as f:
            f.write(cp.read())

        rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, self.TMP_DST_PATH)
        fa.upload(str(dest_cp), rpath)
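A hedged sketch of calling save() from inside a task, assuming the checkpoint object is reachable from the task context as in recent flytekit releases:

from pathlib import Path

import flytekit
from flytekit import task


@task
def train() -> int:
    ctx = flytekit.current_context()
    cp = ctx.checkpoint  # assumed accessor; name may vary by version
    state = Path(ctx.working_directory) / "state.txt"
    state.write_text("step=42")
    cp.save(state)  # uploads the file to the checkpoint destination
    return 42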
Code example #18
def test_to_python_value_with_incoming_columns():
    # make a literal with a type that has two columns
    original_type = Annotated[pd.DataFrame, kwtypes(name=str, age=int)]
    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(original_type)
    df = generate_pandas()
    fdt = StructuredDatasetTransformerEngine()
    lit = fdt.to_literal(ctx, df, python_type=original_type, expected=lt)
    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 2

    # declare a new type that only has one column
    # get the dataframe, make sure it has the column that was asked for.
    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
    sd = fdt.to_python_value(ctx, lit, subset_sd_type)
    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 1

    # check when columns are not specified, should pull both and add column information.
    sd = fdt.to_python_value(ctx, lit, StructuredDataset)
    assert len(sd.metadata.structured_dataset_type.columns) == 2

    # should also work if subset type is just an annotated pd.DataFrame
    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
    sub_df = fdt.to_python_value(ctx, lit, subset_pd_type)
    assert sub_df.shape[1] == 1
Code example #19
    def inputs(self) -> Dict[str, Any]:
        """
        Returns the inputs of the task execution in the standard Python format that is produced by
        the type engine.
        """
        from flytekit.control_plane.tasks.task import FlyteTask

        if self._inputs is None:
            client = _flyte_engine.get_client()
            execution_data = client.get_task_execution_data(self.id)

            # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned.
            input_map = _literal_models.LiteralMap({})
            if bool(execution_data.full_inputs.literals):
                input_map = execution_data.full_inputs
            elif execution_data.inputs.bytes > 0:
                with _common_utils.AutoDeletingTempDir() as tmp_dir:
                    tmp_name = os.path.join(tmp_dir.name, "inputs.pb")
                    _data_proxy.Data.get_data(execution_data.inputs.url,
                                              tmp_name)
                    input_map = _literal_models.LiteralMap.from_flyte_idl(
                        _common_utils.load_proto_from_file(
                            _literals_pb2.LiteralMap, tmp_name))

            task = FlyteTask.fetch(self.id.task_id.project,
                                   self.id.task_id.domain,
                                   self.id.task_id.name,
                                   self.id.task_id.version)
            self._inputs = TypeEngine.literal_map_to_kwargs(
                ctx=FlyteContextManager.current_context(),
                lm=input_map,
                python_types=TypeEngine.guess_python_types(
                    task.interface.inputs),
            )
        return self._inputs
Code example #20
def test_enum_type():
    t = TypeEngine.to_literal_type(Color)
    assert t is not None
    assert t.enum_type is not None
    assert t.enum_type.values
    assert t.enum_type.values == [c.value for c in Color]

    ctx = FlyteContextManager.current_context()
    lv = TypeEngine.to_literal(ctx, Color.RED, Color, TypeEngine.to_literal_type(Color))
    assert lv
    assert lv.scalar
    assert lv.scalar.primitive.string_value == "red"

    v = TypeEngine.to_python_value(ctx, lv, Color)
    assert v
    assert v == Color.RED

    v = TypeEngine.to_python_value(ctx, lv, str)
    assert v
    assert v == "red"

    with pytest.raises(ValueError):
        TypeEngine.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(string_value=str(Color.RED)))), Color)

    with pytest.raises(ValueError):
        TypeEngine.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(string_value="bad"))), Color)

    with pytest.raises(AssertionError):
        TypeEngine.to_literal_type(UnsupportedEnumValues)
コード例 #21
0
    def create(
        cls,
        name: str,
        workflow: _annotated_workflow.WorkflowBase,
        default_inputs: Dict[str, Any] = None,
        fixed_inputs: Dict[str, Any] = None,
        schedule: _schedule_model.Schedule = None,
        notifications: List[_common_models.Notification] = None,
        auth_role: _common_models.AuthRole = None,
    ) -> LaunchPlan:
        ctx = FlyteContextManager.current_context()
        default_inputs = default_inputs or {}
        fixed_inputs = fixed_inputs or {}
        # Default inputs come from two places, the original signature of the workflow function, and the default_inputs
        # argument to this function. We'll take the latter as having higher precedence.
        wf_signature_parameters = transform_inputs_to_parameters(
            ctx, workflow.python_interface)

        # Construct a new Interface object with just the default inputs given to get Parameters, maybe there's an
        # easier way to do this, think about it later.
        temp_inputs = {}
        for k, v in default_inputs.items():
            temp_inputs[k] = (workflow.python_interface.inputs[k], v)
        temp_interface = Interface(inputs=temp_inputs, outputs={})
        temp_signature = transform_inputs_to_parameters(ctx, temp_interface)
        wf_signature_parameters._parameters.update(temp_signature.parameters)

        # These are fixed inputs that cannot change at launch time. If the same argument is also in default inputs,
        # it'll be taken out from defaults in the LaunchPlan constructor
        fixed_literals = translate_inputs_to_literals(
            ctx,
            incoming_values=fixed_inputs,
            flyte_interface_types=workflow.interface.inputs,
            native_types=workflow.python_interface.inputs,
        )
        fixed_lm = _literal_models.LiteralMap(literals=fixed_literals)

        lp = cls(
            name=name,
            workflow=workflow,
            parameters=wf_signature_parameters,
            fixed_inputs=fixed_lm,
            schedule=schedule,
            notifications=notifications,
            auth_role=auth_role,
        )

        # This is just a convenience - we'll need the fixed inputs LiteralMap for when serializing the Launch Plan out
        # to protobuf, but for local execution and such, why not save the original Python native values as well so
        # we don't have to reverse it back every time.
        default_inputs.update(fixed_inputs)
        lp._saved_inputs = default_inputs

        if name in cls.CACHE:
            raise AssertionError(
                f"Launch plan named {name} was already created! Make sure your names are unique."
            )
        cls.CACHE[name] = lp
        return lp
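A minimal call sketch against the signature above; wf and its inputs are hypothetical:

# Assume wf is a @workflow with inputs `region: str` and `threshold: float`.
lp = LaunchPlan.create(
    "daily_run",
    wf,
    default_inputs={"threshold": 0.5},   # overridable at launch time
    fixed_inputs={"region": "us-east"},  # baked in; removed from defaults
)
# A second create() with the name "daily_run" raises, per the CACHE check.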
Code example #22
    def __call__(self, *args, **kwargs):
        # When a Task is () aka __called__, there are three things we may do:
        #  a. Task Execution Mode - just run the Python function as Python normally would. Flyte steps completely
        #     out of the way.
        #  b. Compilation Mode - this happens when the function is called as part of a workflow (potentially
        #     dynamic task?). Instead of running the user function, produce promise objects and create a node.
        #  c. Workflow Execution Mode - when a workflow is being run locally. Even though workflows are functions
        #     and everything should be able to be passed through naturally, we'll want to wrap output values of the
        #     function into objects, so that potential .with_cpu or other ancillary functions can be attached to do
        #     nothing. Subsequent tasks will have to know how to unwrap these. If by chance a non-Flyte task uses a
        #     task output as an input, things probably will fail pretty obviously.
        if len(args) > 0:
            raise _user_exceptions.FlyteAssertion(
                f"When calling tasks, only keyword args are supported. "
                f"Aborting execution as detected {len(args)} positional args {args}"
            )

        ctx = FlyteContextManager.current_context()
        if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
            return self.compile(ctx, *args, **kwargs)
        elif (ctx.execution_state is not None and ctx.execution_state.mode
              == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION):
            if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
                if self.python_interface and self.python_interface.output_tuple_name:
                    variables = [
                        k for k in self.python_interface.outputs.keys()
                    ]
                    output_tuple = collections.namedtuple(
                        self.python_interface.output_tuple_name, variables)
                    nones = [
                        None for _ in self.python_interface.outputs.keys()
                    ]
                    return output_tuple(*nones)
                else:
                    # Should we return multiple None's here?
                    return None
            return self._local_execute(ctx, **kwargs)
        else:
            logger.warning("task run without context - executing raw function")
            new_user_params = self.pre_execute(ctx.user_space_params)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(
                        ctx.execution_state.with_params(
                            mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION,
                            user_space_params=new_user_params))):
                return self.execute(**kwargs)
Code example #23
 def all(self) -> DF:
     if self._dataframe_type is None:
         raise ValueError(
             "No dataframe type set. Use open() to set the local dataframe type you want to use."
         )
     ctx = FlyteContextManager.current_context()
     return flyte_dataset_transformer.open_as(ctx, self.literal,
                                              self._dataframe_type,
                                              self.metadata)
Code example #24
 def iter(self) -> Generator[DF, None, None]:
     if self._dataframe_type is None:
         raise ValueError(
             "No dataframe type set. Use open() to set the local dataframe type you want to use."
         )
     ctx = FlyteContextManager.current_context()
     return flyte_dataset_transformer.iter_as(
         ctx,
         self.literal,
         self._dataframe_type,
         updated_metadata=self.metadata)
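Both accessors require a prior open() call to pin the local dataframe type (see code examples #15 and #18). A hedged helper sketch; the import location of StructuredDataset may vary by flytekit version:

import pandas as pd

from flytekit.types.structured import StructuredDataset  # location may vary


def read_all(sd: StructuredDataset) -> pd.DataFrame:
    # open() pins the local dataframe type; all() then decodes the
    # backing literal in one shot.
    return sd.open(pd.DataFrame).all()


def stream_chunks(sd: StructuredDataset):
    # iter() yields dataframes piece by piece, provided the registered
    # decoder supports streaming for the chosen type.
    yield from sd.open(pd.DataFrame).iter()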
Code example #25
File: test_protobuf.py Project: flyteorg/flytekit
def test_pb_guess_python_type():
    artifact_tag = catalog_pb2.CatalogArtifactTag(artifact_id="artifact_1", name="artifact_name")

    x = {"a": artifact_tag}
    lt = TypeEngine.to_literal_type(catalog_pb2.CatalogArtifactTag)
    gt = TypeEngine.guess_python_type(lt)
    assert gt == catalog_pb2.CatalogArtifactTag
    ctx = FlyteContextManager.current_context()
    lm = TypeEngine.dict_to_literal_map(ctx, x, {"a": gt})
    pv = TypeEngine.to_python_value(ctx, lm.literals["a"], gt)
    assert pv == artifact_tag
Code example #26
File: map_task.py Project: dylanwilder/flytekit
 def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]:
     """
     We override this method from flytekit.core.base_task Task because the dispatch_execute method uses this
      interface to construct outputs. Each instance of a container_array task will, however, produce outputs
     according to the underlying run_task interface and the array plugin handler will actually create a collection
     from these individual outputs as the final output value.
     """
     ctx = FlyteContextManager.current_context()
     if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
         # In workflow execution mode we actually need to use the parent (mapper) task output interface.
         return self._python_interface.outputs[k]
     return self._run_task._python_interface.outputs[k]
Code example #27
    def t1() -> FlyteDirectory:
        user_ctx = FlyteContextManager.current_context().user_space_params
        # Create a local directory to work with
        p = os.path.join(user_ctx.working_directory, "test_wf")
        if os.path.exists(p):
            shutil.rmtree(p)
        pathlib.Path(p).mkdir(parents=True)
        for i in range(1, 6):
            with open(os.path.join(p, f"{i}.txt"), "w") as fh:
                fh.write(f"I'm file {i}\n")

        return FlyteDirectory(p)
Code example #28
    def end_branch(
        self
    ) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidPromise]]:
        """
        This should be invoked after every branch has been visited.
        If this is not a local workflow execution, we should check whether this is the last case.
        If so, return the promise; otherwise return the condition.
        """
        if self._last_case:
            # We have completed the conditional section, lets pop off the branch context
            FlyteContextManager.pop_context()
            ctx = FlyteContextManager.current_context()
            # Question: This is commented out because we don't need it? Nodes created in the conditional
            #   compilation state are captured in the to_case_block? Always?
            #   Is this still true of nested conditionals? Is that why propeller compiler is complaining?
            # branch_nodes = ctx.compilation_state.nodes
            node, promises = to_branch_node(self._name, self)
            # Verify branch_nodes == nodes in bn
            bindings: typing.List[Binding] = []
            upstream_nodes = set()
            for p in promises:
                if not p.is_ready:
                    bindings.append(
                        Binding(var=p.var, binding=BindingData(promise=p.ref)))
                    upstream_nodes.add(p.ref.node)

            n = Node(
                id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",  # type: ignore
                metadata=_core_wf.NodeMetadata(self._name,
                                               timeout=datetime.timedelta(),
                                               retries=RetryStrategy(0)),
                bindings=sorted(bindings, key=lambda b: b.var),
                upstream_nodes=list(upstream_nodes),  # type: ignore
                flyte_entity=node,
            )
            FlyteContextManager.current_context().compilation_state.add_node(n)  # type: ignore
            return self._compute_outputs(n)
        return self._condition
Code example #29
def test_format_correct():
    class TempEncoder(StructuredDatasetEncoder):
        def __init__(self):
            super().__init__(pd.DataFrame, S3, "avro")

        def encode(
            self,
            ctx: FlyteContext,
            structured_dataset: StructuredDataset,
            structured_dataset_type: StructuredDatasetType,
        ) -> literals.StructuredDataset:
            return literals.StructuredDataset(
                uri="/tmp/avro",
                metadata=StructuredDatasetMetadata(structured_dataset_type))

    ctx = FlyteContextManager.current_context()
    df = pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})

    annotated_sd_type = Annotated[StructuredDataset, "avro",
                                  kwtypes(name=str, age=int)]
    df_literal_type = TypeEngine.to_literal_type(annotated_sd_type)
    assert df_literal_type.structured_dataset_type is not None
    assert len(df_literal_type.structured_dataset_type.columns) == 2
    assert df_literal_type.structured_dataset_type.columns[0].name == "name"
    assert df_literal_type.structured_dataset_type.columns[0].literal_type.simple is not None
    assert df_literal_type.structured_dataset_type.columns[1].name == "age"
    assert df_literal_type.structured_dataset_type.columns[1].literal_type.simple is not None
    assert df_literal_type.structured_dataset_type.format == "avro"

    sd = annotated_sd_type(df)
    with pytest.raises(ValueError):
        TypeEngine.to_literal(ctx,
                              sd,
                              python_type=annotated_sd_type,
                              expected=df_literal_type)

    StructuredDatasetTransformerEngine.register(TempEncoder(),
                                                default_for_type=False)
    sd2 = annotated_sd_type(df)
    sd_literal = TypeEngine.to_literal(ctx,
                                       sd2,
                                       python_type=annotated_sd_type,
                                       expected=df_literal_type)
    assert sd_literal.scalar.structured_dataset.metadata.structured_dataset_type.format == "avro"

    @task
    def t1() -> Annotated[StructuredDataset, "avro"]:
        return StructuredDataset(dataframe=df)

    assert t1().file_format == "avro"
Code example #30
File: map_task.py Project: dylanwilder/flytekit
    def _outputs_interface(self) -> Dict[Any, Variable]:
        """
        We override this method from PythonTask because the dispatch_execute method uses this
        interface to construct outputs. Each instance of a container_array task will, however, produce outputs
        according to the underlying run_task interface and the array plugin handler will actually create a collection
        from these individual outputs as the final output value.
        """

        ctx = FlyteContextManager.current_context()
        if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            # In workflow execution mode we actually need to use the parent (mapper) task output interface.
            return self.interface.outputs
        return self._run_task.interface.outputs