예제 #1
0
    def call(self, command: str, inputs: Dict) -> Dict:
        """
        Will be invoked by a node when executing a graph.

        Subclasses of NodeContainable may override call() directly:
            import h1st as h1

            class MyClass(h1.NodeContainable):
                def call(self, command, inputs):
                    ...

        Or they may implement one method per graph execution flow. During execution the
        method whose name equals ``command`` is invoked with ``inputs``:
            class MyClass(h1.NodeContainable):
                def predict(self, inputs):
                    # this method is invoked when the graph runs the predict flow:
                    # graph.predict(...)
                    # or graph.execute(command='predict', input_data=...)
                    ...

        :param command: name of the graph execution flow (predict, train, ...) being run;
            a method with this name is looked up on ``self``
        :param inputs: input data to process for that flow
        :return: the dict returned by the looked-up method
        :raises GraphException: if ``self`` has no callable attribute named ``command``,
            or if that method returns something other than a dict
        """
        # Use a default so a missing method reaches the explicit GraphException below
        # instead of escaping as a bare AttributeError.
        func = getattr(self, command, None)
        if not callable(func):
            raise GraphException(f'class {self.__class__.__name__} must implement method "{command}"')

        result = func(inputs)
        if not isinstance(result, dict):
            raise GraphException(f'output of {self.__class__.__name__} must be a dict')

        return result
예제 #2
0
    def _add_and_connect(
            self,
            node: Union[Node, NodeContainable, None] = None,
            yes: Union[Node, NodeContainable, None] = None,
            no: Union[Node, NodeContainable, None] = None,
            id: str = None,
            from_: Union[Node, None] = None) -> Union[Node, List[Node]]:
        """
        Add node(s) to the graph and connect them from a source node.

        If ``node`` is truthy it is added alone (under ``id`` when given). Otherwise the
        truthy ones of ``yes`` / ``no`` are added as branches whose edges carry the labels
        'yes' / 'no'.

        :param node: node to add as a plain successor; when truthy, ``yes``/``no`` are ignored
        :param yes: node for the 'yes' branch (only used when ``node`` is falsy)
        :param no: node for the 'no' branch (only used when ``node`` is falsy)
        :param id: explicit id for ``node``; 'start' and 'end' may each be used only once
        :param from_: node to connect from; when falsy, ``self._last_added_node`` is used
        :return: the added node when exactly one node was added, otherwise the (possibly
            empty) list of added nodes
        :raises GraphException: on a repeated start/end id, any addition after an 'end'
            node exists, or a duplicate id
        """
        # Guard clauses, evaluated in this exact order so the first matching message wins.
        if id == 'start' and hasattr(self.nodes, 'start'):
            raise GraphException('Graph.start() may only be called once')
        if id == 'end' and hasattr(self.nodes, 'end'):
            raise GraphException('Graph.end() may only be called once')
        if hasattr(self.nodes, 'end'):
            raise GraphException('not allow to add a node after Graph.end()')
        if id and hasattr(self.nodes, id):
            raise GraphException(f'Node id={id} is duplicated')

        # Without an explicit source, chain from whatever was added last.
        source = from_ if from_ else self._last_added_node

        if node:
            added = self._wrap_and_add(node, id)
            self._connect_nodes(source, added)
            self._last_added_node = added
            return added

        # Branch mode: 'yes' is always processed before 'no'.
        branches = []
        for label, candidate in (('yes', yes), ('no', no)):
            if not candidate:
                continue
            branch_node = self._wrap_and_add(candidate)
            self._connect_nodes(source, branch_node, label)
            branches.append(branch_node)

        # Only a lone branch can serve as the default source for the next addition;
        # with two (or zero) branches the chain is deliberately broken.
        is_single = len(branches) == 1
        self._last_added_node = branches[0] if is_single else None
        return branches[0] if is_single else branches
예제 #3
0
    def _wrap_and_add(self,
                      node: Union[Node, NodeContainable],
                      id: str = None):
        """
        Register a node on this graph, wrapping a NodeContainable in an Action first.

        The id is taken from ``id`` if truthy, else from the (possibly wrapped) node's
        ``id``; if both are falsy, one is produced by ``self._generate_id``.

        :param node: a Node, or a NodeContainable to be wrapped in ``Action``
        :param id: explicit id for the node; overrides the node's own id
        :return: the registered Node (the Action wrapper when a NodeContainable was given)
        :raises GraphException: if ``node`` is neither a Node nor a NodeContainable
        """
        if not isinstance(node, (Node, NodeContainable)):
            raise GraphException(
                'object to add to a graph must be an instance of Node or NodeContainable'
            )

        wrapped = Action(node) if isinstance(node, NodeContainable) else node

        node_id = id or wrapped.id
        auto_generated = not node_id
        if auto_generated:
            node_id = self._generate_id(wrapped)
        # Records 0 for explicit ids and 1 for generated ones.
        # NOTE(review): presumably consulted by _generate_id; verify.
        self._used_node_ids[node_id] = 1 if auto_generated else 0

        wrapped._id = node_id
        wrapped.graph = self
        setattr(self.nodes, node_id, wrapped)
        return wrapped
예제 #4
0
파일: node.py 프로젝트: nguyenduyphuc/h1st
    def _execute(self, command: Optional[str], inputs: Dict[str, Any]) -> Dict:
        """
        Run the node via the parent implementation, then check that the output has a
        shape a decision node can route on.

        The output is accepted when it is a dict that either contains
        ``self._result_field`` (presumably ``'results'``):
            {
                'results': [{ 'prediction': True/False, ...}],
                'other_key': ...,
            }
        or has exactly one key:
            {
                'your_key': [{ 'prediction': True/False, ...}]
            }
        Only this top-level shape is checked here. The items and their ``'prediction'``
        values are not inspected.

        :param command: execution flow name, passed through to ``super()._execute``
        :param inputs: input data, passed through to ``super()._execute``
        :returns: the unmodified output of ``super()._execute``
        :raises GraphException: if the output is not a dict, or lacks the result field
            while having a number of keys other than one
        """
        result = super()._execute(command, inputs)

        has_valid_shape = isinstance(result, dict) and (
            self._result_field in result or len(result) == 1)
        if not has_valid_shape:
            # Name the field actually required instead of hard-coding "results".
            raise GraphException(
                f'output of {self._containable.__class__.__name__} must be a dict containing '
                f'"{self._result_field}" field or only one key'
            )

        return result
예제 #5
0
    def _connect_nodes(self,
                       from_: Node,
                       to: Node,
                       edge_label=None) -> None:
        """
        Append an edge from the ``from_`` node to the ``to`` node.

        :param from_: the source node; when falsy (e.g. None) no edge is created
        :param to: the destination node
        :param edge_label: 'yes', 'no', or None for an unlabeled edge
        :raises GraphException: if ``edge_label`` is not one of the supported labels
        """
        # Validate first so an unsupported label is reported even when there is no source node.
        if edge_label not in ('yes', 'no', None):
            raise GraphException(f'edge_label="{edge_label}" is not supported')

        if not from_:
            return

        # Edges live on the source node as (destination, label) tuples.
        from_.edges.append((to, edge_label))
예제 #6
0
    def _validate_output(self, node_output) -> bool:
        """
        This will ensure the result's structure is valid for decision node.
        
        node_output must be a dictionary containing 'results' key and each item will have a field whose name = 'prediction'
        with bool value to decide whether the item belongs to yes or no branch
            { 
                'results': [{ 'prediction': True/False, ...}],
                'other_key': ...,
            }

        or a dictionary containing only one key
            {
                'your_key': [{ 'prediction': True/False, ...}]
            }
        """
        if not isinstance(node_output, dict) or (
            (self._result_field not in node_output)
                and len(node_output.keys()) != 1):
            raise GraphException(
                f'output of {type(self._containable)} must be a dict containing "results" field or only one key'
            )

        return True
예제 #7
0
    def graph(self, value):
        """
        Attach this node to the graph ``value``.

        NOTE(review): presumably registered as the ``graph`` property setter (the decorator
        is not visible here).

        :param value: the graph that will own this node
        :raises GraphException: if ``self.graph`` is already truthy, i.e. the node is attached
        """
        already_attached = bool(self.graph)
        if already_attached:
            raise GraphException('This node belongs to another graph already')
        self._graph = value