def test_parameters_and_defaults():
    """Parameters derived from a raw signature: inputs without a default are required and carry no default
    literal; inputs with a default are optional and carry that value as a literal."""
    ctx = context_manager.FlyteContext.current_context()

    def parameters_for(fn):
        # signature -> typed interface -> Flyte parameter map
        interface = transform_signature_to_interface(inspect.signature(fn))
        return transform_inputs_to_parameters(ctx, interface).parameters

    # Case 1: no defaults at all.
    def z(a: int, b: str) -> typing.Tuple[int, str]:
        ...

    params = parameters_for(z)
    for key in ("a", "b"):
        assert params[key].required
        assert params[key].default is None

    # Case 2: only `b` has a default.
    def z(a: int, b: str = "hello") -> typing.Tuple[int, str]:
        ...

    params = parameters_for(z)
    assert params["a"].required
    assert params["a"].default is None
    assert not params["b"].required
    assert params["b"].default.scalar.primitive.string_value == "hello"

    # Case 3: every input has a default.
    def z(a: int = 7, b: str = "eleven") -> typing.Tuple[int, str]:
        ...

    params = parameters_for(z)
    assert not params["a"].required
    assert params["a"].default.scalar.primitive.integer == 7
    assert not params["b"].required
    assert params["b"].default.scalar.primitive.string_value == "eleven"
def create(
    cls,
    name: str,
    workflow: _annotated_workflow.WorkflowBase,
    default_inputs: Dict[str, Any] = None,
    fixed_inputs: Dict[str, Any] = None,
    schedule: _schedule_model.Schedule = None,
    notifications: List[_common_models.Notification] = None,
    auth_role: _common_models.AuthRole = None,
) -> LaunchPlan:
    """
    Create a named launch plan for ``workflow`` and register it in ``cls.CACHE``.

    :param name: Launch plan name; used as the key in ``cls.CACHE`` and must be unique.
    :param workflow: The workflow this launch plan launches.
    :param default_inputs: Python native default values for workflow inputs. These take precedence over defaults
        declared in the workflow function's signature. The caller's dict is not modified.
    :param fixed_inputs: Python native values that cannot change at launch time. The caller's dict is not modified.
    :param schedule: Optional schedule, passed through to the constructor.
    :param notifications: Optional notifications, passed through to the constructor.
    :param auth_role: Optional auth role, passed through to the constructor.
    :return: The newly constructed launch plan.
    :raises KeyError: (from indexing ``workflow.python_interface.inputs``) if ``default_inputs`` names an input the
        workflow does not declare.
    :raises AssertionError: if a launch plan named ``name`` was already created.
    """
    ctx = FlyteContextManager.current_context()
    default_inputs = default_inputs or {}
    fixed_inputs = fixed_inputs or {}

    # Default inputs come from two places: the original signature of the workflow function, and the
    # default_inputs argument to this function. The latter has higher precedence, so its parameters are
    # layered on top of the signature-derived ones below.
    wf_signature_parameters = transform_inputs_to_parameters(ctx, workflow.python_interface)

    # Build a throwaway Interface holding only the explicitly supplied defaults, so they are converted to
    # Parameters by the same code path as signature defaults.
    temp_inputs = {k: (workflow.python_interface.inputs[k], v) for k, v in default_inputs.items()}
    temp_interface = Interface(inputs=temp_inputs, outputs={})
    temp_signature = transform_inputs_to_parameters(ctx, temp_interface)
    wf_signature_parameters._parameters.update(temp_signature.parameters)

    # These are fixed inputs that cannot change at launch time. If the same argument is also in default inputs,
    # it'll be taken out from defaults in the LaunchPlan constructor.
    fixed_literals = translate_inputs_to_literals(
        ctx,
        incoming_values=fixed_inputs,
        flyte_interface_types=workflow.interface.inputs,
        native_types=workflow.python_interface.inputs,
    )
    fixed_lm = _literal_models.LiteralMap(literals=fixed_literals)

    lp = cls(
        name=name,
        workflow=workflow,
        parameters=wf_signature_parameters,
        fixed_inputs=fixed_lm,
        schedule=schedule,
        notifications=notifications,
        auth_role=auth_role,
    )

    # Convenience: the fixed-inputs LiteralMap is needed when serializing to protobuf, but for local execution
    # the original Python native values are kept as well so they don't have to be converted back each time.
    # Fixed inputs override defaults of the same name. A new dict is built on purpose: updating
    # ``default_inputs`` in place would mutate the caller's argument (``x or {}`` keeps the same object when x
    # is non-empty), leaking this plan's fixed inputs into any later plan created from the same dict.
    lp._saved_inputs = {**default_inputs, **fixed_inputs}

    if name in cls.CACHE:
        raise AssertionError(f"Launch plan named {name} was already created! Make sure your names are unique.")
    cls.CACHE[name] = lp
    return lp
def get_default_launch_plan(ctx: FlyteContext, workflow: _annotated_workflow.WorkflowBase) -> LaunchPlan:
    """
    Return the default launch plan for ``workflow``, building and caching it on first request.

    Users should probably call the get_or_create function defined below instead. The default launch plan only
    picks up whatever default values the workflow function signature declares (if any). It has no fixed inputs,
    notifications, or schedule, and uses the default auth information supplied during serialization.

    :param ctx: This is not flytekit.current_context(). It is an internal context object; users familiar with
        flytekit should feel free to use it however. Used here to derive the launch plan's parameters.
    :param workflow: The workflow to create a launch plan for. Its name is the cache key.
    :return: The launch plan stored under ``workflow.name`` in ``LaunchPlan.CACHE``.
    """
    wf_name = workflow.name
    if wf_name not in LaunchPlan.CACHE:
        # First request for this workflow: build the plan from signature defaults only.
        LaunchPlan.CACHE[wf_name] = LaunchPlan(
            name=wf_name,
            workflow=workflow,
            parameters=transform_inputs_to_parameters(ctx, workflow.python_interface),
            fixed_inputs=_literal_models.LiteralMap(literals={}),
        )
    return LaunchPlan.CACHE[wf_name]
def test_structured_dataset():
    """``Annotated[...]`` input/output types are preserved verbatim by the interface transformation, and an
    annotated input without a default is still a required parameter.

    NOTE(review): despite its name, this test only exercises ``typing.Annotated`` handling, not structured
    datasets.
    """
    ctx = context_manager.FlyteContext.current_context()

    # Annotations are kept inline (not hoisted into a local alias) so type-hint resolution sees them.
    def z(a: Annotated[int, "some annotation"]) -> Annotated[int, "some annotation"]:
        return a

    interface = transform_function_to_interface(z)
    param_a = transform_inputs_to_parameters(ctx, interface).parameters["a"]
    assert param_a.required
    assert param_a.default is None
    assert interface.inputs == {"a": Annotated[int, "some annotation"]}
    assert interface.outputs == {"o0": Annotated[int, "some annotation"]}
def get_default_launch_plan(ctx: FlyteContext, workflow: _annotated_workflow.Workflow) -> LaunchPlan:
    """
    Return the default launch plan for ``workflow``, building and caching it on first request.

    The plan's parameters come from the workflow's ``_native_interface``. It is created with an empty fixed-input
    literal map.

    :param ctx: Context used to derive the launch plan's parameters.
    :param workflow: The workflow to create a launch plan for. Its name is the cache key.
    :return: The launch plan stored under ``workflow.name`` in ``LaunchPlan.CACHE``.
    """
    wf_name = workflow.name
    if wf_name not in LaunchPlan.CACHE:
        # First request for this workflow: build the plan from its native interface only.
        LaunchPlan.CACHE[wf_name] = LaunchPlan(
            name=wf_name,
            workflow=workflow,
            parameters=transform_inputs_to_parameters(ctx, workflow._native_interface),
            fixed_inputs=_literal_models.LiteralMap(literals={}),
        )
    return LaunchPlan.CACHE[wf_name]
def test_parameter_change_to_pickle_type():
    """A plain user-defined class used as input and output type is mapped to ``FlytePickle`` on both sides, and
    the input remains a required parameter with no default."""
    ctx = context_manager.FlyteContext.current_context()

    class Foo:
        def __init__(self, name):
            self.name = name

    def z(a: Foo) -> Foo:
        ...

    interface = transform_function_to_interface(z)
    param_a = transform_inputs_to_parameters(ctx, interface).parameters["a"]
    assert param_a.required
    assert param_a.default is None
    # Both the output and the input type should now be FlytePickle[...] generics.
    assert interface.outputs["o0"].__origin__ == FlytePickle
    assert interface.inputs["a"].__origin__ == FlytePickle
def test_parameters_with_docstring():
    """``:param`` descriptions from a function's docstring end up as the descriptions of its Flyte parameters."""
    ctx = context_manager.FlyteContext.current_context()

    def z(a: int, b: str) -> typing.Tuple[int, str]:
        """
        function z

        :param a: foo
        :param b: bar
        :return: ramen
        """
        ...

    interface = transform_function_to_interface(z, Docstring(callable_=z))
    described = transform_inputs_to_parameters(ctx, interface).parameters
    expected = {"a": "foo", "b": "bar"}
    for param_name, text in expected.items():
        assert described[param_name].var.description == text
def compile(self, **kwargs):
    """
    Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics a
    'closure' in the traditional sense of the word.

    The workflow function is invoked once inside a new compilation state. The nodes it creates are collected, and
    its return value(s) are turned into output bindings. Both are saved on ``self._nodes`` and
    ``self._output_bindings``.

    :param kwargs: Values that override the default input promises handed to the workflow function.
    :return: ``None`` when the workflow declares no outputs, the single binding when it declares one, otherwise a
        tuple of bindings in the order of ``self.interface.outputs``.
    :raises AssertionError: when the returned value(s) don't match the declared outputs (a tuple of the wrong
        length, or a tuple when the output spec defines none, for a single output; a non-tuple for several
        outputs), or when a conditional section without ``else_()`` is returned.
    :raises Exception: when a returned tuple's length differs from the number of declared outputs.
    """
    ctx = FlyteContextManager.current_context()
    # Convert the workflow's Python input signature (including signature defaults) into Flyte parameters.
    self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface)
    all_nodes = []
    # Reuse the enclosing compilation state's prefix if there is one (presumably when compiled inside another
    # entity); otherwise start with an empty prefix.
    prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else ""

    with FlyteContextManager.with_context(
        ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self))
    ) as comp_ctx:
        # Construct the default input promise bindings, but then override with the provided inputs, if any
        input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()])
        input_kwargs.update(kwargs)
        # Run the user's workflow body; calls made inside it record nodes on comp_ctx.compilation_state.
        workflow_outputs = exception_scopes.user_entry_point(self._workflow_function)(**input_kwargs)
        all_nodes.extend(comp_ctx.compilation_state.nodes)

        # This little loop was added as part of the task resolver change. The task resolver interface itself is
        # more or less stateless (the future-proofing get_all_tasks function notwithstanding). However the
        # implementation of the TaskResolverMixin that this workflow class inherits from (ClassStorageTaskResolver)
        # does store state. This loop adds Tasks that are defined within the body of the workflow to the workflow
        # object itself.
        for n in comp_ctx.compilation_state.nodes:
            if isinstance(n.flyte_entity, PythonAutoContainerTask) and n.flyte_entity.task_resolver == self:
                logger.debug(f"WF {self.name} saving task {n.flyte_entity.name}")
                self.add(n.flyte_entity)

        # Iterate through the workflow outputs
        bindings = []
        output_names = list(self.interface.outputs.keys())
        # The reason the length 1 case is separate is because the one output might be a list. We don't want to
        # iterate through the list here, instead we should let the binding creation unwrap it and make a binding
        # collection/map out of it.
        # NOTE(review): binding_from_python_std below receives the outer ``ctx``, not ``comp_ctx`` — confirm this
        # is intentional.
        if len(output_names) == 1:
            if isinstance(workflow_outputs, tuple):
                if len(workflow_outputs) != 1:
                    raise AssertionError(
                        f"The Workflow specification indicates only one return value, received {len(workflow_outputs)}"
                    )
                if self.python_interface.output_tuple_name is None:
                    raise AssertionError(
                        "Outputs specification for Workflow does not define a tuple, but return value is a tuple"
                    )
                # A single declared output returned as a 1-tuple: bind the element itself.
                workflow_outputs = workflow_outputs[0]
            t = self.python_interface.outputs[output_names[0]]
            b = binding_from_python_std(
                ctx,
                output_names[0],
                self.interface.outputs[output_names[0]].type,
                workflow_outputs,
                t,
            )
            bindings.append(b)
        elif len(output_names) > 1:
            if not isinstance(workflow_outputs, tuple):
                raise AssertionError("The Workflow specification indicates multiple return values, received only one")
            if len(output_names) != len(workflow_outputs):
                raise Exception(f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}")
            for i, out in enumerate(output_names):
                # An unfinished if/else chain shows up as a ConditionalSection instead of a promise.
                if isinstance(workflow_outputs[i], ConditionalSection):
                    raise AssertionError("A Conditional block (if-else) should always end with an `else_()` clause")
                t = self.python_interface.outputs[out]
                b = binding_from_python_std(
                    ctx,
                    out,
                    self.interface.outputs[out].type,
                    workflow_outputs[i],
                    t,
                )
                bindings.append(b)

    # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
    self._nodes = all_nodes
    self._output_bindings = bindings
    if not output_names:
        return None
    if len(output_names) == 1:
        return bindings[0]
    return tuple(bindings)
def compile(self, **kwargs):
    """
    Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics a
    'closure' in the traditional sense of the word.

    The workflow function is invoked once inside a new compilation context. The nodes it creates are collected,
    and its return value(s) are turned into output bindings. Both are saved on ``self._nodes`` and
    ``self._output_bindings``.

    :param kwargs: Values that override the default input promises handed to the workflow function.
    :return: ``None`` when the workflow declares no outputs, the single binding when it declares one, otherwise a
        tuple of bindings in the order of ``self.interface.outputs``.
    :raises AssertionError: when the returned value(s) don't match the declared outputs (a tuple of length != 1
        for a single output, a non-tuple for several outputs), or when a conditional section without ``else_()``
        is returned.
    :raises Exception: when a returned tuple's length differs from the number of declared outputs.
    """
    ctx = FlyteContext.current_context()
    # Convert the workflow's native input signature (including signature defaults) into Flyte parameters.
    self._input_parameters = transform_inputs_to_parameters(ctx, self._native_interface)
    all_nodes = []
    # When an outer compilation is in progress, extend its prefix with this workflow's short name.
    prefix = f"{ctx.compilation_state.prefix}-{self.short_name}-" if ctx.compilation_state is not None else None
    with ctx.new_compilation_context(prefix=prefix) as comp_ctx:
        # Construct the default input promise bindings, but then override with the provided inputs, if any
        input_kwargs = construct_input_promises(list(self.interface.inputs.keys()))
        input_kwargs.update(kwargs)
        workflow_outputs = self._workflow_function(**input_kwargs)
        all_nodes.extend(comp_ctx.compilation_state.nodes)

        # Iterate through the workflow outputs
        bindings = []
        output_names = list(self.interface.outputs.keys())
        # The reason the length 1 case is separate is because the one output might be a list. We don't want to
        # iterate through the list here, instead we should let the binding creation unwrap it and make a binding
        # collection/map out of it.
        if len(output_names) == 1:
            if isinstance(workflow_outputs, tuple):
                if len(workflow_outputs) != 1:
                    raise AssertionError(
                        f"The Workflow specification indicates only one return value, received {len(workflow_outputs)}"
                    )
                # A single declared output returned as a 1-tuple (e.g. a one-field tuple/NamedTuple return spec):
                # bind the element itself rather than the wrapping tuple.
                workflow_outputs = workflow_outputs[0]
            t = self._native_interface.outputs[output_names[0]]
            b = binding_from_python_std(
                ctx,
                output_names[0],
                self.interface.outputs[output_names[0]].type,
                workflow_outputs,
                t,
            )
            bindings.append(b)
        elif len(output_names) > 1:
            if not isinstance(workflow_outputs, tuple):
                raise AssertionError("The Workflow specification indicates multiple return values, received only one")
            if len(output_names) != len(workflow_outputs):
                raise Exception(f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}")
            for i, out in enumerate(output_names):
                # An unfinished if/else chain shows up as a ConditionalSection instead of a promise.
                if isinstance(workflow_outputs[i], ConditionalSection):
                    raise AssertionError("A Conditional block (if-else) should always end with an `else_()` clause")
                t = self._native_interface.outputs[out]
                b = binding_from_python_std(
                    ctx,
                    out,
                    self.interface.outputs[out].type,
                    workflow_outputs[i],
                    t,
                )
                bindings.append(b)

    # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
    self._nodes = all_nodes
    self._output_bindings = bindings
    if not output_names:
        return None
    if len(output_names) == 1:
        return bindings[0]
    return tuple(bindings)
def test_parameters_and_defaults():
    """Parameters derived from a function: defaults become optional parameters with literal defaults,
    ``Annotated`` types pass through unchanged, and ``None`` defaults on Optional/Union types become ``Void``
    literals."""
    ctx = context_manager.FlyteContext.current_context()

    def interface_and_params(fn):
        # function -> typed interface -> Flyte parameter map; the interface is returned for type assertions.
        interface = transform_function_to_interface(fn)
        return interface, transform_inputs_to_parameters(ctx, interface).parameters

    # Case 1: no defaults at all.
    def z(a: int, b: str) -> typing.Tuple[int, str]:
        ...

    _, params = interface_and_params(z)
    for key in ("a", "b"):
        assert params[key].required
        assert params[key].default is None

    # Case 2: only `b` has a default.
    def z(a: int, b: str = "hello") -> typing.Tuple[int, str]:
        ...

    _, params = interface_and_params(z)
    assert params["a"].required
    assert params["a"].default is None
    assert not params["b"].required
    assert params["b"].default.scalar.primitive.string_value == "hello"

    # Case 3: every input has a default.
    def z(a: int = 7, b: str = "eleven") -> typing.Tuple[int, str]:
        ...

    _, params = interface_and_params(z)
    assert not params["a"].required
    assert params["a"].default.scalar.primitive.integer == 7
    assert not params["b"].required
    assert params["b"].default.scalar.primitive.string_value == "eleven"

    # Case 4: Annotated types survive the transformation verbatim.
    def z(a: Annotated[int, "some annotation"]) -> Annotated[int, "some annotation"]:
        return a

    interface, params = interface_and_params(z)
    assert params["a"].required
    assert params["a"].default is None
    assert interface.inputs == {"a": Annotated[int, "some annotation"]}
    assert interface.outputs == {"o0": Annotated[int, "some annotation"]}

    # Case 5: None defaults on Optional / Union[..., None] inputs become Void literals.
    def z(
        a: typing.Optional[int] = None,
        b: typing.Optional[str] = None,
        c: typing.Union[typing.List[int], None] = None,
    ) -> typing.Tuple[int, str]:
        ...

    _, params = interface_and_params(z)
    for key in ("a", "b", "c"):
        assert not params[key].required
        assert params[key].default.scalar.none_type == Void()