Example #1
0
 def end_branch(self) -> Union[Condition, Promise]:
     """
     Close out the branch that was just visited.

     When the final case has been reached, the current context is popped via
     ``FlyteContextManager.pop_context()``. The condition object is returned
     whether or not this was the final case.
     """
     if not self._last_case:
         return self._condition
     # Final case reached: the conditional section is finished, drop its context.
     FlyteContextManager.pop_context()
     return self._condition
Example #2
0
 def end_branch(
     self
 ) -> Optional[Union[Condition, Tuple[Promise], Promise, VoidPromise]]:
     """
     Close out the branch that was just visited.

     For every case except the last, the condition is handed back so further
     cases can be chained onto it. On the last case the current context is
     popped and placeholder outputs are produced from
     ``self.compute_output_vars()``:

     - when it reports ``None``, a ``VoidPromise`` named after this section;
     - otherwise one unset ``Promise`` (``val=None``) per output variable,
       passed through ``create_task_output``.
     """
     if not self._last_case:
         return self._condition

     FlyteContextManager.pop_context()
     output_vars = self.compute_output_vars()
     if output_vars is None:
         return VoidPromise(self.name)
     return create_task_output([Promise(var=v, val=None) for v in output_vars])
Example #3
0
    def end_branch(
        self
    ) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidPromise]]:
        """
        This should be invoked after every branch has been visited.

        If this is not the last case, the condition is returned so that further
        cases can be chained onto it.

        If it is the last case, the branch context is popped, a branch node is
        built from this section via ``to_branch_node``, wrapped in a ``Node``
        whose bindings/upstream nodes come from the not-yet-ready promises,
        added to the current compilation state, and the outputs computed for
        that node (``self._compute_outputs``) are returned.
        """
        if not self._last_case:
            return self._condition

        # We have completed the conditional section, lets pop off the branch context.
        # The context is read only *after* the pop so the node id below is derived from
        # the compilation state the node is then added to.
        FlyteContextManager.pop_context()
        ctx = FlyteContextManager.current_context()

        # NOTE(review): the nodes recorded in the branch's compilation state are not
        # inspected here; this assumes every node created inside the conditional is
        # captured by to_branch_node (via to_case_block), including for nested
        # conditionals -- TODO confirm.
        node, promises = to_branch_node(self._name, self)

        bindings: typing.List[Binding] = []
        # Insertion-ordered dict instead of a set: dedupes upstream nodes exactly as a set
        # would (same __hash__/__eq__), but keeps a deterministic first-seen order so the
        # compiled node does not vary between runs (bindings are sorted for the same reason).
        upstream_nodes: typing.Dict[typing.Any, None] = {}
        for p in promises:
            if not p.is_ready:
                bindings.append(Binding(var=p.var, binding=BindingData(promise=p.ref)))
                upstream_nodes[p.ref.node] = None

        n = Node(
            id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",  # type: ignore
            metadata=_core_wf.NodeMetadata(self._name,
                                           timeout=datetime.timedelta(),
                                           retries=RetryStrategy(0)),
            bindings=sorted(bindings, key=lambda b: b.var),
            upstream_nodes=list(upstream_nodes),  # type: ignore
            flyte_entity=node,
        )
        FlyteContextManager.current_context().compilation_state.add_node(
            n)  # type: ignore
        return self._compute_outputs(n)
Example #4
0
 def end_branch(self) -> Union[Condition, Promise]:
     """
     Close out the branch that was just visited during local workflow execution.

     The execution state is always told the branch is complete. If more cases
     may follow, the condition is returned so chaining can continue. On the last
     case the current context is popped and the selected case decides the result
     (a case is expected to have been selected, see start_branch):

     - its output promise, when it has one;
     - otherwise ``ValueError(err)`` when it recorded an error;
     - otherwise ``AssertionError`` because it resolved to nothing.
     """
     ctx = FlyteContextManager.current_context()
     # Record completion of this branch on the execution state first.
     ctx.execution_state.branch_complete()

     if not self._last_case:
         return self._condition

     # The conditional section is finished; drop its context.
     FlyteContextManager.pop_context()
     selected = self._selected_case
     if selected.output_promise is not None:
         return selected.output_promise
     if selected.err is None:
         raise AssertionError(
             "Bad conditional statements, did not resolve in a promise")
     raise ValueError(selected.err)
Example #5
0
    def end_branch(self) -> Union[Condition, Promise]:
        """
        This should be invoked after every branch has been visited.

        Local workflow execution: the branch is first marked complete on the
        execution state. If this is the last case, the branch context is popped and
        the selected case's output promise is returned (a case should always be
        selected, see start_branch). Otherwise the condition is returned so that
        further chaining can be done.

        Compilation: if this is the last case, the branch context is popped, a
        branch node is built and added to the current compilation state, and its
        computed outputs are returned. Otherwise the condition is returned.

        Raises:
            AssertionError: if the selected case resolved to neither a promise nor
                an error, or if invoked outside a workflow context.
            ValueError: if the selected case resolved to an error (carries it).
        """
        ctx = FlyteContextManager.current_context()
        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            # LOCAL WORKFLOW EXECUTION
            # Let us mark the execution state as complete
            ctx.execution_state.branch_complete()
            if self._last_case:
                # We have completed the conditional section, lets pop off the branch context
                FlyteContextManager.pop_context()
                if self._selected_case.output_promise is None and self._selected_case.err is None:
                    raise AssertionError(
                        "Bad conditional statements, did not resolve in a promise"
                    )
                elif self._selected_case.output_promise is not None:
                    return self._selected_case.output_promise
                raise ValueError(self._selected_case.err)
            return self._condition
        elif ctx.compilation_state:
            # COMPILATION MODE
            # Not local execution: if this is the last case build and return the
            # branch node's outputs, otherwise return the condition for chaining.
            if self._last_case:
                # We have completed the conditional section, lets pop off the branch context
                FlyteContextManager.pop_context()
                # `ctx` above was captured *before* the pop, i.e. it is the branch context.
                # Re-read the context so the node id is derived from the same compilation
                # state the node is added to below (previously the id came from the popped
                # context while the node went into the current one).
                parent_ctx = FlyteContextManager.current_context()
                node, promises = to_branch_node(self._name, self)
                bindings: typing.List[Binding] = []
                # Insertion-ordered dict instead of a set: same deduplication, but a
                # deterministic first-seen order for the node's upstream list.
                upstream_nodes: typing.Dict[typing.Any, None] = {}
                for p in promises:
                    if not p.is_ready:
                        bindings.append(
                            Binding(var=p.var,
                                    binding=BindingData(promise=p.ref)))
                        upstream_nodes[p.ref.node] = None

                n = Node(
                    id=
                    f"{parent_ctx.compilation_state.prefix}node-{len(parent_ctx.compilation_state.nodes)}",  # type: ignore
                    metadata=_core_wf.NodeMetadata(
                        self._name,
                        timeout=datetime.timedelta(),
                        retries=RetryStrategy(0)),
                    bindings=sorted(bindings, key=lambda b: b.var),
                    upstream_nodes=list(upstream_nodes),  # type: ignore
                    flyte_entity=node,
                )
                FlyteContextManager.current_context(
                ).compilation_state.add_node(n)  # type: ignore
                return self._compute_outputs(n)
            return self._condition

        raise AssertionError(
            "Branches can only be invoked within a workflow context!")