def test_comparison_refs():
    def dummy_node(node_id) -> Node:
        n = Node(
            node_id,
            metadata=None,
            bindings=[],
            upstream_nodes=[],
            flyte_entity=SQLTask(name="x", query_template="x", inputs={}),
        )
        n._id = node_id
        return n

    px = Promise("x", NodeOutput(var="x", node=dummy_node("n1")))
    py = Promise("y", NodeOutput(var="y", node=dummy_node("n2")))

    def print_expr(expr):
        print(f"{expr} is type {type(expr)}")

    print_expr(px == py)
    print_expr(px < py)
    print_expr((px == py) & (px < py))
    print_expr(((px == py) & (px < py)) | (px > py))
    print_expr(px < 5)
    print_expr(px >= 5)
def test_comparison_lits():
    px = Promise("x", TypeEngine.to_literal(None, 5, int, None))
    py = Promise("y", TypeEngine.to_literal(None, 8, int, None))

    def eval_expr(expr, expected: bool):
        print(f"{expr} evals to {expr.eval()}")
        assert expected == expr.eval()

    eval_expr(px == py, False)
    eval_expr(px < py, True)
    eval_expr((px == py) & (px < py), False)
    eval_expr(((px == py) & (px < py)) | (px > py), False)
    eval_expr(px < 5, False)
    eval_expr(px >= 5, True)
    eval_expr(py >= 5, True)
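# A minimal sketch building on the tests above: comparing Promises yields an expression object
# rather than a bool, `&`/`|` combine expressions, and expressions over literal-backed Promises
# can be evaluated locally via .eval(). Promise and TypeEngine are assumed to be the same classes
# imported by the tests above; this is illustrative, not part of the original test suite.
def example_promise_expression_composition():
    pa = Promise("a", TypeEngine.to_literal(None, 3, int, None))
    pb = Promise("b", TypeEngine.to_literal(None, 7, int, None))

    expr = (pa < pb) & (pb >= 5)  # builds an expression tree; nothing is evaluated yet
    assert expr.eval()  # literal-backed promises can be evaluated eagerly
    assert not ((pa == pb) | (pa > pb)).eval()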
def _workflow_fn_outputs_to_promise(
    ctx: FlyteContext,
    native_outputs: typing.Dict[str, type],  # Actually an OrderedDict
    typed_outputs: Dict[str, _interface_models.Variable],
    outputs: Union[Any, Tuple[Any]],
) -> List[Promise]:
    if len(native_outputs) == 1:
        if isinstance(outputs, tuple):
            if len(outputs) != 1:
                raise AssertionError(
                    f"The Workflow specification indicates only one return value, received {len(outputs)}"
                )
        else:
            outputs = (outputs,)

    if len(native_outputs) > 1:
        if not isinstance(outputs, tuple) or len(native_outputs) != len(outputs):  # Length check, clean up exception
            raise AssertionError(
                f"The workflow specification indicates {len(native_outputs)} return vals, but received {len(outputs)}"
            )

    # This recasts the Promises provided by the outputs of the workflow's tasks into the correct output names
    # of the workflow itself
    return_vals = []
    for (k, t), v in zip(native_outputs.items(), outputs):
        if isinstance(v, Promise):
            return_vals.append(v.with_var(k))
        else:
            # Found a return value that is not a Promise, so we need to transform it into one
            var = typed_outputs[k]
            return_vals.append(Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, var.type)))
    return return_vals
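# Hedged illustration of the recasting step above: a Promise produced by a task is typically named
# after the task's own output variable (e.g. "o0"), and with_var() rebinds it to the workflow's
# output name. The names "o0" and "wf_output" are invented for this example; Promise and
# TypeEngine are assumed to be the same classes used elsewhere in this module.
def example_recast_task_promise_to_workflow_output():
    task_promise = Promise("o0", TypeEngine.to_literal(None, 42, int, None))
    renamed = task_promise.with_var("wf_output")
    assert renamed.var == "wf_output"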
def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
    """
    Please see the _local_execute comments in the main task.
    """
    # Unwrap the kwargs values. After this, we essentially have a LiteralMap.
    # We need to do this because the inputs during local execution can be of two types:
    #   - Promises, which are essentially inputs from previous task executions, or
    #   - native constants, which are just bound to this specific task (default values for a task input).
    # Along with promises and constants, there could also be dictionaries or lists of promises or constants.
    kwargs = translate_inputs_to_literals(
        ctx, input_kwargs=kwargs, interface=self.typed_interface, native_input_types=self.interface.inputs
    )
    input_literal_map = _literal_models.LiteralMap(literals=kwargs)

    outputs_literal_map = self.unwrap_literal_map_and_execute(ctx, input_literal_map)

    # After running, we again have to wrap the outputs, if any, back into Promise objects
    outputs_literals = outputs_literal_map.literals
    output_names = list(self.interface.outputs.keys())
    if len(output_names) != len(outputs_literals):  # Length check, clean up exception
        raise AssertionError(f"Length difference {len(output_names)} {len(outputs_literals)}")

    # Tasks that don't return anything still return a VoidPromise
    if len(output_names) == 0:
        return VoidPromise(self.name)

    vals = [Promise(var, outputs_literals[var]) for var in output_names]
    return create_task_output(vals, self.interface)
def _compute_outputs(self, n: Node) -> Union[Promise, Tuple[Promise], VoidPromise]:
    output_var_sets: typing.List[typing.Set[str]] = []
    for c in self._cases:
        if c.output_promise is None and c.err is None:
            # One node returns a void output and no error, so we default to a void output
            return VoidPromise(n.id)
        if c.output_promise is not None:
            if isinstance(c.output_promise, tuple):
                output_var_sets.append(set([i.var for i in c.output_promise]))
            else:
                output_var_sets.append(set([c.output_promise.var]))

    curr = output_var_sets[0]
    if len(output_var_sets) > 1:
        for x in output_var_sets[1:]:
            curr = curr.intersection(x)

    promises = [Promise(var=x, val=NodeOutput(node=n, var=x)) for x in curr]
    # TODO: Is there a way to add the Python interface here? Currently, it's an optional arg.
    return create_task_output(promises)
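# A small, Flyte-free sketch of the set-intersection logic in _compute_outputs: only output
# variable names produced by every branch survive and become the conditional's outputs.
def example_branch_output_intersection():
    output_var_sets = [{"a", "b"}, {"a", "c"}, {"a", "b", "c"}]
    curr = output_var_sets[0]
    for x in output_var_sets[1:]:
        curr = curr.intersection(x)
    assert curr == {"a"}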
def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
    """
    This code is used only when we want to dispatch_execute with outputs from a previous node.
    For regular execution, dispatch_execute is invoked directly.
    """
    # Unwrap the kwargs values. After this, we essentially have a LiteralMap.
    # We need to do this because the inputs during local execution can be of two types:
    #   - Promises, which are essentially inputs from previous task executions, or
    #   - native constants, which are just bound to this specific task (default values for a task input).
    # Along with promises and constants, there could also be dictionaries or lists of promises or constants.
    kwargs = translate_inputs_to_literals(
        ctx, input_kwargs=kwargs, interface=self.interface, native_input_types=self.get_input_types()
    )
    input_literal_map = _literal_models.LiteralMap(literals=kwargs)

    outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
    outputs_literals = outputs_literal_map.literals

    # TODO: Maybe this is the part that should be done for local execution: pass the outputs to some special
    #  location. Otherwise we don't really need to, right? The higher-level execute could just handle the LiteralMap.

    # After running, we again have to wrap the outputs, if any, back into Promise objects
    output_names = list(self.interface.outputs.keys())
    if len(output_names) != len(outputs_literals):  # Length check, clean up exception
        raise AssertionError(f"Length difference {len(output_names)} {len(outputs_literals)}")

    # Tasks that don't return anything still return a VoidPromise
    if len(output_names) == 0:
        return VoidPromise(self.name)

    vals = [Promise(var, outputs_literals[var]) for var in output_names]
    return create_task_output(vals, self.python_interface)
def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
    """
    Performs local execution of a workflow. kwargs are expected to be Promises for the most part (unless
    someone has hardcoded something like my_wf(input_1=5)).

    :param ctx: The FlyteContext
    :param kwargs: parameters for the workflow itself
    """
    logger.info(f"Executing Workflow {self._name}, ctx {ctx.execution_state.Mode}")

    # This is done to support the invariant that workflow local executions always work with Promise objects
    # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
    for k, v in kwargs.items():
        if not isinstance(v, Promise):
            t = self._native_interface.inputs[k]
            kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type))

    function_outputs = self.execute(**kwargs)
    if (
        isinstance(function_outputs, VoidPromise)
        or function_outputs is None
        or len(self.python_interface.outputs) == 0
    ):
        # This is here because a workflow function may return a task that doesn't return anything:
        #   def wf():
        #       return t1()
        # or it may not return at all:
        #   def wf():
        #       t1()
        # In the former case we get the task's VoidPromise, in the latter we get None.
        return VoidPromise(self.name)

    # TODO: Can we refactor the task code to be similar to what's in this function?
    promises = _workflow_fn_outputs_to_promise(
        ctx, self._native_interface.outputs, self.interface.outputs, function_outputs
    )
    # TODO: With the native interface, create_task_output should be able to derive the typed interface, and it
    #  should be able to do the conversion of the output of the execute() call directly.
    return create_task_output(promises, self._native_interface)
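# Hedged sketch of the wrapping step above: a native Python value passed to a locally executed
# workflow (e.g. my_wf(input_1=5)) is converted into a literal-backed Promise before execute()
# runs. The context is passed as None here purely for illustration, mirroring the tests above;
# real workflow code supplies an actual FlyteContext and the interface's literal type.
def example_wrap_native_value_in_promise():
    wrapped = Promise(var="input_1", val=TypeEngine.to_literal(None, 5, int, None))
    assert isinstance(wrapped, Promise) and wrapped.var == "input_1"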
def construct_input_promises(inputs: List[str]):
    return {
        input_name: Promise(var=input_name, val=NodeOutput(node=GLOBAL_START_NODE, var=input_name))
        for input_name in inputs
    }
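# Example of what construct_input_promises produces: each workflow input name becomes a Promise
# whose value is a NodeOutput reference hanging off the global start node, so downstream nodes can
# bind to workflow inputs before any concrete values exist.
def example_construct_input_promises():
    start_promises = construct_input_promises(["a", "b"])
    assert set(start_promises.keys()) == {"a", "b"}
    assert all(isinstance(p, Promise) for p in start_promises.values())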