def __init__(
    self,
    task_type: str,
    name: str,
    task_config: T,
    interface: Optional[Interface] = None,
    environment: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    Base initializer for a plugin task.

    Args:
        task_type: A string that defines a unique task-type for every new extension. If a backend
            plugin is required then this has to be done in-concert with the backend plugin identifier.
        name: A unique name for the task instantiation. This is unique for every instance of task.
        task_config: Configuration for the task. This is used to configure the specific plugin that
            handles this task.
        interface: A python native typed interface ``(inputs) -> outputs`` that declares the
            signature of the task.
        environment: Any environment variables that should be supplied during the execution of the
            task. Supplied as a dictionary of key/value pairs.
    """
    # The parent only understands the typed (IDL) interface, so convert before delegating.
    typed_interface = transform_interface_to_typed_interface(interface)
    super().__init__(task_type=task_type, name=name, interface=typed_interface, **kwargs)
    # Fall back to empty defaults when nothing was supplied.
    self._python_interface = interface or Interface()
    self._environment = environment or {}
    self._task_config = task_config
def add_workflow_output(
    self, output_name: str, p: Union[Promise, List[Promise], Dict[str, Promise]], python_type: Optional[Type] = None
):
    """
    Add an output with the given name from the given node output.

    Args:
        output_name: Name of the new workflow output; must not already exist on this workflow.
        p: A Promise (or a list/dict of Promises) produced by a node in this workflow.
        python_type: The python type of the output. Required when ``p`` is a list or dict of
            Promises; otherwise it is inferred from the producing node's interface.

    Raises:
        FlyteValidationException: If the output name already exists, or a container of Promises
            is supplied without an explicit ``python_type``.
    """
    if output_name in self._python_interface.outputs:
        raise FlyteValidationException(f"Output {output_name} already exists in workflow {self.name}")
    if python_type is None:
        # isinstance is the idiomatic type check (was ``type(p) == list or type(p) == dict``).
        # Also fixed the message: "python_type type" -> "python_type", and closed the paren.
        if isinstance(p, (list, dict)):
            raise FlyteValidationException(
                f"If specifying a list or dict of Promises, you must specify the python_type for {output_name}"
                f" starting with the container type (e.g. List[int])"
            )
        # Infer the output's type from the interface of the node that produced the promise.
        python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var]
        logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}")
    flyte_type = TypeEngine.to_literal_type(python_type=python_type)
    ctx = FlyteContext.current_context()
    if ctx.compilation_state is not None:
        raise Exception("Can't already be compiling")
    # Bind the promise to the new output under this workflow's own compilation state.
    with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
        b = binding_from_python_std(
            ctx, output_name, expected_literal_type=flyte_type, t_value=p, t_value_type=python_type
        )
        self._output_bindings.append(b)
        # Keep the native and typed interfaces in sync after adding the output.
        self._python_interface = self._python_interface.with_outputs(extra_outputs={output_name: python_type})
        self._interface = transform_interface_to_typed_interface(self._python_interface)
def __init__(
    self,
    task_type: str,
    name: str,
    task_config: T,
    interface: Optional[Interface] = None,
    environment: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    Args:
        task_type: A string that defines a unique task-type for every new extension.
        name: A unique name for the task instantiation.
        task_config: Configuration for the specific plugin that handles this task.
        interface: A python native typed interface ``(inputs) -> outputs`` declaring the
            task signature.
        environment: Environment variables supplied during task execution, as key/value pairs.
    """
    # BUG FIX: ``environment`` was previously pulled via ``kwargs.get("environment", {})``,
    # which left the key in **kwargs and forwarded it into super().__init__() as well.
    # An explicit parameter (matching the sibling initializer) binds it here instead.
    super().__init__(
        task_type=task_type,
        name=name,
        interface=transform_interface_to_typed_interface(interface),
        **kwargs,
    )
    self._python_interface = interface if interface else Interface()
    self._environment = environment if environment else {}
    self._task_config = task_config
def add_workflow_input(self, input_name: str, python_type: Type) -> Interface:
    """
    Adds an input to the workflow.

    Args:
        input_name: Name of the new workflow input; must not already be specified.
        python_type: The python type of the input.

    Raises:
        FlyteValidationException: If an input with this name already exists.
    """
    if input_name in self._inputs:
        raise FlyteValidationException(f"Input {input_name} has already been specified for wf {self.name}.")
    # Extend the native interface first, then regenerate the typed interface from it.
    self._python_interface = self._python_interface.with_inputs(extra_inputs={input_name: python_type})
    self._interface = transform_interface_to_typed_interface(self._python_interface)
    # The input is represented as a promise hanging off the global start node.
    new_promise = Promise(var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name))
    self._inputs[input_name] = new_promise
    # Track it as unbound until some node consumes it.
    self._unbound_inputs.add(new_promise)
    return new_promise
def __init__(
    self,
    reference: Union[WorkflowReference, TaskReference, LaunchPlanReference],
    inputs: Optional[Dict[str, Union[Type[Any], Tuple[Type[Any], Any]]]],
    outputs: Dict[str, Type],
):
    """
    Args:
        reference: A workflow, task, or launch plan reference identifying the remote entity.
        inputs: Mapping of input name to type (optionally with a default, as a tuple).
        outputs: Mapping of output name to type.

    Raises:
        Exception: If ``reference`` is not one of the three supported reference kinds.
    """
    # isinstance accepts a tuple of types; one check replaces the chained "not isinstance" clauses.
    if not isinstance(reference, (WorkflowReference, TaskReference, LaunchPlanReference)):
        raise Exception("Must be one of task, workflow, or launch plan")
    self._reference = reference
    # Keep both the python-native interface and its typed (IDL) counterpart.
    self._native_interface = Interface(inputs=inputs, outputs=outputs)
    self._interface = transform_interface_to_typed_interface(self._native_interface)
def __init__(
    self,
    name: str,
    workflow_metadata: WorkflowMetadata,
    workflow_metadata_defaults: WorkflowMetadataDefaults,
    python_interface: Interface,
    **kwargs,
):
    """
    Base workflow state: name, metadata, interfaces, and initially-empty graph structures.

    Args:
        name: Unique name for the workflow.
        workflow_metadata: Workflow-level metadata.
        workflow_metadata_defaults: Defaults applied to nodes within the workflow.
        python_interface: The python-native interface of the workflow.
    """
    self._name = name
    # Keep the native interface and its typed (IDL) counterpart side by side.
    self._python_interface = python_interface
    self._interface = transform_interface_to_typed_interface(python_interface)
    self._workflow_metadata = workflow_metadata
    self._workflow_metadata_defaults = workflow_metadata_defaults
    # Graph-construction state, populated as inputs/nodes/outputs are added.
    self._inputs = {}
    self._unbound_inputs = set()
    self._nodes = []
    self._output_bindings: Optional[List[_literal_models.Binding]] = []
    # Register globally so serialization can discover this entity.
    FlyteEntities.entities.append(self)
    super().__init__(**kwargs)
def __init__(
    self,
    workflow_function: Callable,
    metadata: Optional[WorkflowMetadata],
    default_metadata: Optional[WorkflowMetadataDefaults],
):
    """
    Wraps a python function as a workflow, deriving the name and interface from the
    function itself.

    Args:
        workflow_function: The function whose signature defines the workflow interface.
        metadata: Workflow-level metadata.
        default_metadata: Defaults applied to nodes within the workflow.
    """
    mod = workflow_function.__module__
    fn = workflow_function.__name__
    self._name = f"{mod}.{fn}"
    self._workflow_function = workflow_function
    # Derive both the native and typed interfaces from the function signature.
    native = transform_signature_to_interface(inspect.signature(workflow_function))
    self._native_interface = native
    self._interface = transform_interface_to_typed_interface(native)
    self._workflow_metadata = metadata
    self._workflow_metadata_defaults = default_metadata
    # These will get populated on compile only
    self._nodes = None
    self._output_bindings: Optional[List[_literal_models.Binding]] = None
    # TODO do we need this - can this not be in launchplan only?
    # This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or
    # we can re-evaluate.
    self._input_parameters = None
    FlyteEntities.entities.append(self)
def create_and_link_node(
    ctx: FlyteContext,
    entity,
    interface: flyte_interface.Interface,
    timeout: Optional[datetime.timedelta] = None,
    retry_strategy: Optional[_literal_models.RetryStrategy] = None,
    **kwargs,
):
    """
    This method is used to generate a node with bindings. This is not used in the execution path.

    Args:
        ctx: Flyte context; must have an active compilation state.
        entity: The task/workflow/launch-plan entity being linked into the graph.
        entity: The entity whose ``name`` and interface drive the node.
        interface: The entity's python-native interface.
        timeout: Optional node-level timeout, recorded in node metadata.
        retry_strategy: Optional retry strategy, recorded in node metadata.
        **kwargs: Values (literals or Promises) for every input declared by ``interface``.

    Returns:
        A ``VoidPromise`` when the entity declares no outputs, otherwise the task output promise(s).

    Raises:
        FlyteAssertion: When not compiling, when an input is missing, or when extra inputs are given.
        AssertionError: When a tuple is passed for a single input.
    """
    if ctx.compilation_state is None:
        raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

    used_inputs = set()
    bindings = []

    typed_interface = flyte_interface.transform_interface_to_typed_interface(interface)

    for k in sorted(interface.inputs):
        var = typed_interface.inputs[k]
        if k not in kwargs:
            raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
        v = kwargs[k]
        # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
        # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
        # into the function.
        if isinstance(v, tuple):
            raise AssertionError(
                f"Variable({k}) for function({entity.name}) cannot receive a multi-valued tuple {v}."
                f" Check if the predecessor function returning more than one value?"
            )
        bindings.append(
            binding_from_python_std(
                ctx, var_name=k, expected_literal_type=var.type, t_value=v, t_value_type=interface.inputs[k]
            )
        )
        used_inputs.add(k)

    # BUG FIX: the original contained a raw newline inside this string literal (a syntax
    # error); the message is now a single line. Plain set difference replaces ``^`` since
    # every member of used_inputs is by construction a key of kwargs.
    extra_inputs = set(kwargs.keys()) - used_inputs
    if len(extra_inputs) > 0:
        raise _user_exceptions.FlyteAssertion(
            "Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs)
        )

    # Detect upstream nodes
    # These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
    upstream_nodes = list(
        {
            input_val.ref.node
            for input_val in kwargs.values()
            if isinstance(input_val, Promise) and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
        }
    )

    node_metadata = _workflow_model.NodeMetadata(
        f"{entity.__module__}.{entity.name}",
        timeout or datetime.timedelta(),
        retry_strategy or _literal_models.RetryStrategy(0),
    )

    non_sdk_node = Node(
        # TODO: Better naming, probably a derivative of the function name.
        id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
        metadata=node_metadata,
        bindings=sorted(bindings, key=lambda b: b.var),
        upstream_nodes=upstream_nodes,
        flyte_entity=entity,
    )
    ctx.compilation_state.add_node(non_sdk_node)

    if len(typed_interface.outputs) == 0:
        return VoidPromise(entity.name)

    # Create a node output object for each output, they should all point to this node of course.
    node_outputs = []
    for output_name, output_var_model in typed_interface.outputs.items():
        # TODO: If node id gets updated later, we have to make sure to update the NodeOutput model's ID, which
        # is currently just a static str
        node_outputs.append(Promise(output_name, NodeOutput(node=non_sdk_node, var=output_name)))
        # Don't print this, it'll crash cuz sdk_node._upstream_node_ids might be None, but idl code will break

    return create_task_output(node_outputs, interface)
def reset_interface(self, inputs: Dict[str, Type], outputs: Dict[str, Type]):
    """Replace this entity's interface with the given input/output type maps."""
    # NOTE(review): here ``_interface`` holds the *native* Interface and the typed one goes
    # into ``_typed_interface`` — the opposite naming from the other classes in this file,
    # which keep the typed interface in ``_interface``. Confirm this asymmetry is intentional.
    self._interface = Interface(inputs=inputs, outputs=outputs)
    self._typed_interface = transform_interface_to_typed_interface(self._interface)
def test_transform_interface_to_typed_interface_with_docstring():
    """Verify that parameter/return descriptions parsed from function docstrings survive
    the transformation into a typed interface, for sphinx and numpy docstring styles."""

    # sphinx style
    # NOTE: the docstrings of the inner ``z`` functions are runtime data — they are parsed
    # by ``Docstring`` — so their exact text is part of the test's behavior.
    def z(a: int, b: str) -> typing.Tuple[int, str]:
        """
        function z
        :param a: foo
        :param b: bar
        :return: ramen
        """
        ...

    our_interface = transform_function_to_interface(z, Docstring(callable_=z))
    typed_interface = transform_interface_to_typed_interface(our_interface)
    assert typed_interface.inputs.get("a").description == "foo"
    assert typed_interface.inputs.get("b").description == "bar"
    # Tuple outputs are auto-named o0, o1, ...; the :return: description is checked on o1.
    assert typed_interface.outputs.get("o1").description == "ramen"

    # numpy style, multiple return values, shared descriptions
    def z(a: int, b: str) -> typing.Tuple[int, str]:
        """
        function z

        Parameters
        ----------
        a : int
            foo
        b : str
            bar

        Returns
        -------
        out1, out2 : tuple
            ramen
        """
        ...

    our_interface = transform_function_to_interface(z, Docstring(callable_=z))
    typed_interface = transform_interface_to_typed_interface(our_interface)
    assert typed_interface.inputs.get("a").description == "foo"
    assert typed_interface.inputs.get("b").description == "bar"
    # A shared return description applies to every auto-named output.
    assert typed_interface.outputs.get("o0").description == "ramen"
    assert typed_interface.outputs.get("o1").description == "ramen"

    # numpy style, multiple return values, named
    def z(a: int, b: str) -> typing.NamedTuple("NT", x_str=str, y_int=int):
        """
        function z

        Parameters
        ----------
        a : int
            foo
        b : str
            bar

        Returns
        -------
        x_str : str
            description for x_str
        y_int : int
            description for y_int
        """
        ...

    our_interface = transform_function_to_interface(z, Docstring(callable_=z))
    typed_interface = transform_interface_to_typed_interface(our_interface)
    assert typed_interface.inputs.get("a").description == "foo"
    assert typed_interface.inputs.get("b").description == "bar"
    # NamedTuple outputs keep their declared field names and per-field descriptions.
    assert typed_interface.outputs.get(
        "x_str").description == "description for x_str"
    assert typed_interface.outputs.get(
        "y_int").description == "description for y_int"