示例#1
0
    def _apply_and_replace_all_child_nodes(
        self, fn: "Callable[[DAGNode], T]"
    ) -> "DAGNode":
        """Build a copy of this node whose direct children are mapped by ``fn``.

        Only the first level of DAGNode children found in the bound args and
        kwargs is transformed; nothing deeper is visited. For a recursive
        transformation over the DAG, use ``_apply_recursive()``.

        Args:
            fn: Applied exactly once to every distinct child node; its result
                takes the place of that child in the copied args/kwargs.

        Returns:
            A new DAGNode (produced by ``self._copy``) built from the replaced
            args/kwargs and this node's options.
        """
        # One scanner instance must drive both the find and the replace step:
        # replace_nodes() only receives the table, so it relies on the state
        # this scanner recorded during find_nodes().
        scanner = _PyObjScanner()
        replacements = {}
        bound = [self._bound_args, self._bound_kwargs]
        for child in scanner.find_nodes(bound):
            if child in replacements:
                # Already mapped; ``fn`` must not run twice for one child.
                continue
            replacements[child] = fn(child)

        args_after, kwargs_after = scanner.replace_nodes(replacements)
        return self._copy(args_after, kwargs_after, self.get_options())
示例#2
0
    def _apply_functional(
        self,
        source_input_list: Any,
        predictate_fn: Callable,
        apply_fn: Callable,
    ):
        """Map ``apply_fn`` over selected DAGNodes nested in the given inputs.

        No DAGNode is mutated or copied here; instead, replaced inputs are
        produced in which each selected node is substituted by the value
        ``apply_fn`` returned for it.

        Args:
            source_input_list: Inputs to scan for (possibly nested) DAGNode
                instances.
            predictate_fn: Called on every DAGNode occurrence found; only
                nodes for which it returns a truthy value are passed to
                ``apply_fn``. Useful for filtering by node type.
            apply_fn: Called once per distinct selected node; its return
                value replaces that node. Example::

                    apply_fn = lambda node: node._get_serve_deployment_handle(
                        node._deployment, node._bound_other_args_to_resolve
                    )

        Returns:
            The result of ``scanner.replace_nodes`` for the table built here:
            the inputs with each node that passed ``predictate_fn`` replaced
            by the output of ``apply_fn``.
        """
        scanner = _PyObjScanner()
        table = {}
        for candidate in scanner.find_nodes(source_input_list):
            # The predicate is evaluated first (for every occurrence); only
            # then are nodes that were already mapped skipped.
            if not predictate_fn(candidate) or candidate in table:
                continue
            table[candidate] = apply_fn(candidate)
        return scanner.replace_nodes(table)
示例#3
0
    def _get_all_child_nodes(self) -> Set["DAGNode"]:
        """Return the set of DAGNodes referenced by this node's args and kwargs.

        Nodes nested inside containers are included too; for example, in
        ``f.remote(a, [b])`` the result contains both ``a`` and ``b``.

        Returns:
            Set of child DAGNode instances found in ``self._bound_args`` and
            ``self._bound_kwargs``.
        """
        scanner = _PyObjScanner()
        return set(scanner.find_nodes([self._bound_args, self._bound_kwargs]))
示例#4
0
    def _get_all_child_nodes(self) -> Set["DAGNode"]:
        """Return the set of nodes referenced by the args, kwargs, and
        other_args_to_resolve of this node, even if they're deeply nested.

        Examples:
            f.remote(a, [b]) -> {a, b}
            f.remote(a, [b], key={"nested": [c]}) -> {a, b, c}

        Returns:
            Set of child DAGNode instances found in ``self._bound_args``,
            ``self._bound_kwargs`` and ``self._bound_other_args_to_resolve``.
        """
        scanner = _PyObjScanner()
        return set(
            scanner.find_nodes([
                self._bound_args,
                self._bound_kwargs,
                self._bound_other_args_to_resolve,
            ])
        )
示例#5
0
    def _get_all_child_nodes(self) -> List["DAGNode"]:
        """Return the list of nodes referenced by the args, kwargs, and
        other_args_to_resolve of this node, even if they're deeply nested.

        Each node appears once, in the order it was first found.

        Examples:
            f.remote(a, [b]) -> [a, b]
            f.remote(a, [b], key={"nested": [c]}) -> [a, b, c]

        Returns:
            Deduplicated list of child DAGNode instances.
        """
        scanner = _PyObjScanner()
        found = scanner.find_nodes([
            self._bound_args,
            self._bound_kwargs,
            self._bound_other_args_to_resolve,
        ])
        # We use a List instead of a Set here; the reason is explained in
        # `_get_toplevel_child_nodes`. dict.fromkeys() deduplicates in O(n)
        # (instead of O(n^2) list-membership checks) while keeping the
        # first-seen order, so the returned list is the same — assuming
        # DAGNode's hash is consistent with its equality.
        return list(dict.fromkeys(found))