예제 #1
0
 def __init__(
     self,
     task_type: str,
     name: str,
     task_config: T,
     interface: Optional[Interface] = None,
     environment: Optional[Dict[str, str]] = None,
     **kwargs,
 ):
     """
     Initialise a task whose signature is described by a python-native interface.

     Args:
         task_type: unique task-type string for this extension. If a backend plugin handles the task, this must
                    be chosen in concert with that plugin's identifier
         name: unique name of this particular task instance
         task_config: plugin-specific configuration, stored as-is on ``self._task_config``
         interface: python-native ``(inputs) -> outputs`` signature of the task; ``self._python_interface``
                    falls back to an empty ``Interface()`` when it is omitted
         environment: environment variables (key/value pairs) to supply when the task executes; stored as ``{}``
                      when omitted
         **kwargs: forwarded unchanged to the parent initialiser
     """
     typed_interface = transform_interface_to_typed_interface(interface)
     super().__init__(task_type=task_type, name=name, interface=typed_interface, **kwargs)

     # Normalise omitted/falsy values to empty containers so later code need not special-case None.
     self._python_interface = interface or Interface()
     self._environment = environment or {}
     self._task_config = task_config
예제 #2
0
    def add_workflow_output(
        self, output_name: str, p: Union[Promise, List[Promise], Dict[str, Promise]], python_type: Optional[Type] = None
    ):
        """
        Add an output with the given name from the given node output.

        :param output_name: name of the new workflow output; must not already be declared on this workflow.
        :param p: the Promise, or list/dict of Promises, that the output is bound to.
        :param python_type: python type of the output. Required when ``p`` is a list or dict; otherwise it is
            inferred from the python interface of the entity that produced ``p``.
        :raises FlyteValidationException: if the output name already exists, or if a list/dict is passed
            without ``python_type``.
        :raises Exception: if called while a compilation state is already active on the current context.
        """
        if output_name in self._python_interface.outputs:
            raise FlyteValidationException(f"Output {output_name} already exists in workflow {self.name}")

        if python_type is None:
            # A container of Promises has no single upstream variable to infer the type from.
            if isinstance(p, (list, dict)):
                raise FlyteValidationException(
                    f"If specifying a list or dict of Promises, you must specify the python_type type for {output_name}"
                    f" starting with the container type (e.g. List[int])"
                )
            python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var]
            logger.debug("Inferring python type for wf output %s from Promise provided %s", output_name, python_type)

        flyte_type = TypeEngine.to_literal_type(python_type=python_type)

        ctx = FlyteContext.current_context()
        # The binding is built inside this workflow's own compilation state; nesting inside another is refused.
        if ctx.compilation_state is not None:
            raise Exception("Can't already be compiling")
        with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
            b = binding_from_python_std(
                ctx, output_name, expected_literal_type=flyte_type, t_value=p, t_value_type=python_type
            )
            self._output_bindings.append(b)
            # Keep the python-native and typed interfaces in sync with the new output.
            self._python_interface = self._python_interface.with_outputs(extra_outputs={output_name: python_type})
            self._interface = transform_interface_to_typed_interface(self._python_interface)
예제 #3
0
 def __init__(
     self,
     task_type: str,
     name: str,
     task_config: T,
     interface: Optional[Interface] = None,
     environment: Optional[dict] = None,
     **kwargs,
 ):
     """
     Initialise a task whose signature is described by a python-native interface.

     :param task_type: unique task-type string for this extension.
     :param name: unique name of this task instance.
     :param task_config: plugin-specific configuration, stored as-is on ``self._task_config``.
     :param interface: python-native signature; ``self._python_interface`` falls back to ``Interface()``.
     :param environment: environment variables (str -> str) for task execution; stored as ``{}`` when omitted
         or ``None``. Accepted explicitly so that it is consumed here instead of also being forwarded to the
         parent initialiser through ``**kwargs``.
     :param kwargs: forwarded unchanged to the parent initialiser.
     """
     super().__init__(
         task_type=task_type, name=name, interface=transform_interface_to_typed_interface(interface), **kwargs
     )
     self._python_interface = interface if interface else Interface()
     self._environment = environment if environment else {}
     self._task_config = task_config
예제 #4
0
 def add_workflow_input(self, input_name: str, python_type: Type) -> "Promise":
     """
     Adds an input to the workflow.

     Declares ``input_name`` with ``python_type`` on both the python-native and the typed interface, and
     creates a Promise that reads that input from the global start node.

     :param input_name: name of the new workflow input.
     :param python_type: python type of the input.
     :return: the newly created Promise (it is also stored in ``self._inputs`` and ``self._unbound_inputs``).
     :raises FlyteValidationException: if an input with this name was already added.
     """
     if input_name in self._inputs:
         raise FlyteValidationException(f"Input {input_name} has already been specified for wf {self.name}.")
     self._python_interface = self._python_interface.with_inputs(extra_inputs={input_name: python_type})
     self._interface = transform_interface_to_typed_interface(self._python_interface)
     promise = Promise(var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name))
     self._inputs[input_name] = promise
     self._unbound_inputs.add(promise)
     return promise
예제 #5
0
 def __init__(
     self,
     reference: Union[WorkflowReference, TaskReference, LaunchPlanReference],
     inputs: Optional[Dict[str, Union[Type[Any], Tuple[Type[Any], Any]]]],
     outputs: Dict[str, Type],
 ):
     """
     Wrap a reference to a task, workflow or launch plan together with its declared interface.

     :param reference: the entity reference; must be a WorkflowReference, TaskReference or LaunchPlanReference.
     :param inputs: input name -> type (or (type, default) tuple), passed straight to ``Interface``.
     :param outputs: output name -> type, passed straight to ``Interface``.
     :raises Exception: if ``reference`` is not one of the three supported reference kinds.
     """
     supported_kinds = (WorkflowReference, TaskReference, LaunchPlanReference)
     if not isinstance(reference, supported_kinds):
         raise Exception("Must be one of task, workflow, or launch plan")
     self._reference = reference
     native = Interface(inputs=inputs, outputs=outputs)
     self._native_interface = native
     self._interface = transform_interface_to_typed_interface(native)
예제 #6
0
 def __init__(
     self,
     name: str,
     workflow_metadata: WorkflowMetadata,
     workflow_metadata_defaults: WorkflowMetadataDefaults,
     python_interface: Interface,
     **kwargs,
 ):
     """
     Initialise a workflow from its name, metadata and python-native interface.

     The typed interface is derived from ``python_interface``; the input, node and output-binding
     collections start out empty. Remaining ``kwargs`` go to the parent initialiser.
     """
     # Identity and metadata.
     self._name, self._workflow_metadata, self._workflow_metadata_defaults = (
         name,
         workflow_metadata,
         workflow_metadata_defaults,
     )
     # Python-native interface plus its typed counterpart.
     self._python_interface = python_interface
     self._interface = transform_interface_to_typed_interface(python_interface)
     # Empty collections for inputs, unbound inputs, nodes and output bindings.
     self._inputs = dict()
     self._unbound_inputs = set()
     self._nodes = list()
     self._output_bindings: Optional[List[_literal_models.Binding]] = list()
     # Registered globally before the parent initialiser runs (same order as before).
     FlyteEntities.entities.append(self)
     super().__init__(**kwargs)
예제 #7
0
    def __init__(
        self,
        workflow_function: Callable,
        metadata: Optional[WorkflowMetadata],
        default_metadata: Optional[WorkflowMetadataDefaults],
    ):
        """
        Wrap a python function as a workflow.

        The workflow is named ``<module>.<function name>`` and its native interface is derived from the
        function's signature. Nodes and output bindings remain ``None`` until the workflow is compiled.
        """
        fn = workflow_function
        self._name = f"{fn.__module__}.{fn.__name__}"
        self._workflow_function = fn

        native_interface = transform_signature_to_interface(inspect.signature(fn))
        self._native_interface = native_interface
        self._interface = transform_interface_to_typed_interface(native_interface)

        # Filled in only by compilation.
        self._nodes = None
        self._output_bindings: Optional[List[_literal_models.Binding]] = None

        self._workflow_metadata = metadata
        self._workflow_metadata_defaults = default_metadata

        # TODO: this could live on the launch plan only; it is kept here so it need not be re-evaluated.
        self._input_parameters = None
        FlyteEntities.entities.append(self)
예제 #8
0
def create_and_link_node(
    ctx: FlyteContext,
    entity,
    interface: flyte_interface.Interface,
    timeout: Optional[datetime.timedelta] = None,
    retry_strategy: Optional[_literal_models.RetryStrategy] = None,
    **kwargs,
):
    """
    This method is used to generate a node with bindings. This is not used in the execution path.

    Every input declared by ``interface`` is bound to the value of the same name in ``kwargs``; a ``Node``
    for ``entity`` is then added to ``ctx.compilation_state`` and Promises for its outputs are returned.

    :param ctx: context whose ``compilation_state`` receives the node; must not be ``None``.
    :param entity: the entity being invoked; ``entity.name`` and ``entity.__module__`` name the node metadata.
    :param interface: python-native interface of ``entity``.
    :param timeout: node timeout; ``datetime.timedelta()`` is used when falsy.
    :param retry_strategy: retry strategy; ``RetryStrategy(0)`` is used when falsy.
    :param kwargs: input values keyed by input name.
    :return: ``VoidPromise(entity.name)`` when there are no outputs, else ``create_task_output(...)``.
    :raises FlyteAssertion: when not compiling, when an input is missing, or when extra inputs are supplied.
    :raises AssertionError: when an input value is a tuple.
    """
    if ctx.compilation_state is None:
        raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

    typed_interface = flyte_interface.transform_interface_to_typed_interface(interface)

    used_inputs = set()
    bindings = []
    for k in sorted(interface.inputs):
        var = typed_interface.inputs[k]
        if k not in kwargs:
            raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
        v = kwargs[k]
        # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
        # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
        # into the function.
        if isinstance(v, tuple):
            raise AssertionError(
                f"Variable({k}) for function({entity.name}) cannot receive a multi-valued tuple {v}."
                f" Check if the predecessor function returning more than one value?"
            )
        bindings.append(
            binding_from_python_std(
                ctx,
                var_name=k,
                expected_literal_type=var.type,
                t_value=v,
                t_value_type=interface.inputs[k],
            )
        )
        used_inputs.add(k)

    # Every used input is known to be present in kwargs, so the leftovers are kwargs minus the used inputs.
    extra_inputs = set(kwargs) - used_inputs
    if extra_inputs:
        raise _user_exceptions.FlyteAssertion(
            "Too many inputs were specified for the interface.  Extra inputs were: {}".format(extra_inputs)
        )

    # Detect upstream nodes
    # These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes.
    # dict.fromkeys de-duplicates like a set but keeps first-seen order, so the node's upstream list does not
    # depend on set iteration order.
    upstream_nodes = list(
        dict.fromkeys(
            input_val.ref.node
            for input_val in kwargs.values()
            if isinstance(input_val, Promise) and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
        )
    )

    node_metadata = _workflow_model.NodeMetadata(
        f"{entity.__module__}.{entity.name}",
        timeout or datetime.timedelta(),
        retry_strategy or _literal_models.RetryStrategy(0),
    )

    non_sdk_node = Node(
        # TODO: Better naming, probably a derivative of the function name.
        id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
        metadata=node_metadata,
        bindings=sorted(bindings, key=lambda b: b.var),
        upstream_nodes=upstream_nodes,
        flyte_entity=entity,
    )
    ctx.compilation_state.add_node(non_sdk_node)

    if len(typed_interface.outputs) == 0:
        return VoidPromise(entity.name)

    # One Promise per declared output, all pointing at the node just created.
    # TODO: If node id gets updated later, we have to make sure to update the NodeOutput model's ID, which
    #  is currently just a static str
    node_outputs = [
        Promise(output_name, NodeOutput(node=non_sdk_node, var=output_name))
        for output_name in typed_interface.outputs
    ]

    return create_task_output(node_outputs, interface)
예제 #9
0
 def reset_interface(self, inputs: Dict[str, Type], outputs: Dict[str, Type]):
     """Replace the python-native interface with one built from ``inputs``/``outputs`` and re-derive the typed one."""
     native = Interface(inputs=inputs, outputs=outputs)
     self._interface = native
     self._typed_interface = transform_interface_to_typed_interface(native)
예제 #10
0
def test_transform_interface_to_typed_interface_with_docstring():
    """Descriptions parsed from sphinx- and numpy-style docstrings must reach the typed interface variables."""

    def typed_interface_of(fn):
        # docstring -> python-native interface -> typed interface
        native = transform_function_to_interface(fn, Docstring(callable_=fn))
        return transform_interface_to_typed_interface(native)

    def check_descriptions(typed, inputs, outputs):
        for var_name, expected in inputs.items():
            assert typed.inputs.get(var_name).description == expected
        for var_name, expected in outputs.items():
            assert typed.outputs.get(var_name).description == expected

    # sphinx style
    def z(a: int, b: str) -> typing.Tuple[int, str]:
        """
        function z

        :param a: foo
        :param b: bar
        :return: ramen
        """
        ...

    check_descriptions(typed_interface_of(z), {"a": "foo", "b": "bar"}, {"o1": "ramen"})

    # numpy style, multiple return values, shared descriptions
    def z(a: int, b: str) -> typing.Tuple[int, str]:
        """
        function z

        Parameters
        ----------
        a : int
            foo
        b : str
            bar

        Returns
        -------
        out1, out2 : tuple
            ramen
        """
        ...

    check_descriptions(typed_interface_of(z), {"a": "foo", "b": "bar"}, {"o0": "ramen", "o1": "ramen"})

    # numpy style, multiple return values, named
    def z(a: int, b: str) -> typing.NamedTuple("NT", x_str=str, y_int=int):
        """
        function z

        Parameters
        ----------
        a : int
            foo
        b : str
            bar

        Returns
        -------
        x_str : str
            description for x_str
        y_int : int
            description for y_int
        """
        ...

    check_descriptions(
        typed_interface_of(z),
        {"a": "foo", "b": "bar"},
        {"x_str": "description for x_str", "y_int": "description for y_int"},
    )