Exemple #1
0
def _listup_functions(inputs: Sequence[T_NODE], outputs: Sequence[T_NODE]):
    """
    List the function nodes required to compute `outputs`, dependencies first.

    The graph is walked backwards from `outputs` with an explicit stack. A node contained in `inputs` is treated
    as having no predecessors, so nothing upstream of it is visited.

    Args:
        inputs: nodes at which the backward walk stops
        outputs: nodes from which the backward walk starts

    Returns:
        list of `T_FUNCTION` nodes; a function is appended only once its counter of unresolved predecessors is zero

    Raises:
        CyclicGraphError: when a node is popped again while its counter of unresolved predecessors is still positive
    """
    # NOTE(Kiikurage):
    # In chainer v1, "Variable" doesn't support "__eq__" method, so "list.__contains__" cannot be used for list of variables.
    # However, "Variable.__hash__" is implemented and "set.__contains__" is available.
    boundary = set(inputs)

    def predecessors_of(node: T_NODE) -> Sequence[T_NODE]:
        if node in boundary:
            return []

        if isinstance(node, T_VARIABLE):
            return [node.creator] if node.creator is not None else []

        return node.inputs

    ordered = []  # type: List[T_FUNCTION]
    pending = [(node, None) for node in outputs]  # type: List[Tuple[T_NODE, T_NODE]]
    unresolved = {}  # type: Dict[T_NODE, int]

    while pending:
        current, consumer = pending.pop()

        if current not in unresolved:
            # First visit: keep the node on the stack and schedule every predecessor that is not resolved yet.
            # The node is popped again after all of the entries pushed here have been processed.
            pending.append((current, consumer))

            sources = predecessors_of(current)
            unresolved[current] = 0
            for source in sources:
                if unresolved.get(source, 1) > 0:
                    unresolved[current] += 1
                    pending.append((source, current))
            continue

        if unresolved[current] != 0:
            # The node came back to the top of the stack while still waiting for predecessors.
            raise CyclicGraphError(
                "[ChainerConverter] Cycles are detected, but ChainerConverter cannot convert cyclic graph"
            )

        # Every scheduled predecessor is resolved: emit the node and release the consumer waiting for it.
        if isinstance(current, T_FUNCTION):
            ordered.append(current)

        if consumer is not None:
            unresolved[consumer] -= 1

    return ordered
Exemple #2
0
def _listup_operations(inputs: Sequence[T_NODE], outputs: Sequence[T_NODE]):
    """
    List the `tf.Operation` nodes required to compute `outputs`, dependencies first.

    The graph is walked backwards from `outputs` with an explicit stack. A node contained in `inputs` is treated
    as having no predecessors, so nothing upstream of it is visited.

    Args:
        inputs: nodes at which the backward walk stops
        outputs: nodes from which the backward walk starts

    Returns:
        list of `tf.Operation`; an operation is appended only once its counter of unresolved predecessors is zero

    Raises:
        CyclicGraphError: when a node is popped again while its counter of unresolved predecessors is still positive
    """
    # The membership test below runs for every visited node, so build the lookup once instead of scanning the
    # `inputs` sequence each time. Nodes are already used as keys of `dependency_count`, so they are hashable.
    input_set = set(inputs)

    def get_prev_nodes(node: T_NODE) -> Sequence[T_NODE]:
        if node in input_set:
            return []

        elif isinstance(node, tf.Tensor):
            # a tensor depends only on the operation `node.op`
            return [node.op]

        else:
            return node.inputs

    result = []  # type: List[tf.Operation]
    stack = [(node, None)
             for node in outputs]  # type: List[Tuple[T_NODE, T_NODE]]
    dependency_count = {}  # type: Dict[T_NODE, int]

    while stack:
        node_from, node_to = stack.pop()

        if node_from not in dependency_count:
            # First visit: keep the node on the stack and schedule every predecessor that is not resolved yet.
            # The node is popped again after all of the entries pushed here have been processed.
            stack.append((node_from, node_to))

            prev_nodes = get_prev_nodes(node_from)
            dependency_count[node_from] = 0
            for prev_node in prev_nodes:
                if dependency_count.get(prev_node, 1) > 0:
                    dependency_count[node_from] += 1
                    stack.append((prev_node, node_from))

        elif dependency_count[node_from] == 0:
            # Every scheduled predecessor is resolved: emit the node and release the consumer waiting for it.
            if isinstance(node_from, tf.Operation):
                result.append(node_from)

            if node_to is not None:
                dependency_count[node_to] -= 1

        else:
            # The node came back to the top of the stack while still waiting for predecessors.
            console.debug(
                "[TensorFlowConverter] Cycle is detected in computation graph")
            console.debug("cycle starting node:")
            console.debug(node_from)

            raise CyclicGraphError(
                "[TensorFlowConverter] Cycles are detected, but TensorFlowConverter cannot convert cyclic graph"
            )

    return result
Exemple #3
0
def _listup_functions(inputs: Sequence[T_NODE], outputs: Sequence[T_NODE]):
    """
    List the `Function` nodes required to compute `outputs`, dependencies first.

    The graph is walked backwards from `outputs` with an explicit stack. A node contained in `inputs` is treated
    as having no predecessors, so nothing upstream of it is visited.

    Args:
        inputs: nodes at which the backward walk stops
        outputs: nodes from which the backward walk starts

    Returns:
        list of `Function`; a function is appended only once its counter of unresolved predecessors is zero

    Raises:
        CyclicGraphError: when a node is popped again while its counter of unresolved predecessors is still positive
    """
    # The membership test below runs for every visited node, so build the lookup once instead of scanning the
    # `inputs` sequence each time. A set lookup goes through `__hash__` first rather than comparing the node
    # against every input with `==`. Nodes are already used as keys of `dependency_count`, so they are hashable.
    input_set = set(inputs)

    def get_prev_nodes(node: T_NODE) -> Sequence[T_NODE]:
        if node in input_set:
            return []

        elif isinstance(node, VariableNode):
            # a variable node depends only on its creator, if it has one
            return [] if node.creator is None else [node.creator]

        else:
            return node.inputs

    result = []  # type: List[Function]
    stack = [(node, None) for node in outputs]  # type: List[Tuple[T_NODE, T_NODE]]
    dependency_count = {}  # type: Dict[T_NODE, int]

    while stack:
        node_from, node_to = stack.pop()

        if node_from not in dependency_count:
            # First visit: keep the node on the stack and schedule every predecessor that is not resolved yet.
            # The node is popped again after all of the entries pushed here have been processed.
            stack.append((node_from, node_to))

            prev_nodes = get_prev_nodes(node_from)
            dependency_count[node_from] = 0
            for prev_node in prev_nodes:
                if dependency_count.get(prev_node, 1) > 0:
                    dependency_count[node_from] += 1
                    stack.append((prev_node, node_from))

        elif dependency_count[node_from] == 0:
            # Every scheduled predecessor is resolved: emit the node and release the consumer waiting for it.
            if isinstance(node_from, Function):
                result.append(node_from)

            if node_to is not None:
                dependency_count[node_to] -= 1

        else:
            # The node came back to the top of the stack while still waiting for predecessors.
            raise CyclicGraphError("[ChainerConverter] Cycles are detected, but ChainerConverter cannot convert cyclic graph")

    return result
Exemple #4
0
def _listup_functions(graph: IGraphProto) -> Sequence[INodeProto]:
    """
    List the node protos of `graph` required to compute its outputs, dependencies first.

    The walk alternates between value names (`str`) and the node protos that produce them (wrapped in `Container`),
    starting from the names of `graph.output` and moving backwards through `proto.input`.

    Args:
        graph: graph whose `node`, `input` and `output` fields are read

    Returns:
        list of node protos; a node is appended only once its counter of unresolved predecessors is zero

    Raises:
        CyclicGraphError: when an entry is popped again while its counter of unresolved predecessors is still positive
    """
    class Container:
        """
        Proto object is not hashable. this container supports hash operation with proto object.
        """
        def __init__(self, proto: INodeProto):
            self.proto = proto

        def __hash__(self):
            return hash(self.proto.name)

        def __eq__(self, other):
            return isinstance(other, Container) and self.proto == other.proto

    # value name -> container of the node proto which lists that name in its `output`
    creator_map = {}  # type: Dict[str, Container]
    for proto in graph.node:
        for name in proto.output:
            creator_map[name] = Container(proto)

    # Graph inputs are matched by `name`, the same attribute read from `graph.output` entries when the stack is
    # seeded below. Building the set once also avoids scanning `graph.input` for every visited node.
    # NOTE(review): the previous test `node in graph.input` compared a str / Container with the entries themselves;
    # presumably those are value-info protos, so it could never match. Confirm against IGraphProto.
    input_names = {value.name for value in graph.input}

    def get_prev_nodes(
            node: Union[Container, str]) -> Sequence[Union[Container, str]]:
        if node in input_names:
            return []

        elif isinstance(node, Container):
            # a node depends on the value names it consumes
            return node.proto.input

        else:
            # a value name depends on the node producing it, if any
            return [] if node not in creator_map else [creator_map[node]]

    result = []  # type: List[Container]
    stack = [
        (node.name, None) for node in graph.output
    ]  # type: List[Tuple[Union[Container, str], Union[Container, str]]]
    dependency_count = {}  # type: Dict[Union[Container, str], int]

    while stack:
        node_from, node_to = stack.pop()

        if node_from not in dependency_count:
            # First visit: keep the entry on the stack and schedule every predecessor that is not resolved yet.
            # The entry is popped again after all of the entries pushed here have been processed.
            stack.append((node_from, node_to))

            prev_nodes = get_prev_nodes(node_from)
            dependency_count[node_from] = 0
            for prev_node in prev_nodes:
                if dependency_count.get(prev_node, 1) > 0:
                    dependency_count[node_from] += 1
                    stack.append((prev_node, node_from))

        elif dependency_count[node_from] == 0:
            # Every scheduled predecessor is resolved: emit the node and release the consumer waiting for it.
            if isinstance(node_from, Container):
                result.append(node_from)

            if node_to is not None:
                dependency_count[node_to] -= 1

        else:
            # The entry came back to the top of the stack while still waiting for predecessors.
            raise CyclicGraphError("[ONNXConverter] Cycles are detected")

    return [r.proto for r in result]