def _listup_functions(inputs: Sequence[T_NODE], outputs: Sequence[T_NODE]):
    """
    Topologically sort the functions of a chainer (v1) computation graph.

    The graph is walked backwards from ``outputs`` using an explicit stack (no recursion): a variable leads to its
    ``creator`` (if any), any other node leads to its ``inputs``. Nodes contained in ``inputs`` are not expanded.

    Args:
        inputs: nodes at which the backward walk stops.
        outputs: nodes the walk starts from.

    Returns:
        The visited ``T_FUNCTION`` nodes, each placed after the functions that produce its inputs.

    Raises:
        CyclicGraphError: when a node is popped again while some of its predecessors are still pending (a cycle).
    """
    # NOTE(Kiikurage):
    # chainer v1 "Variable" has no usable "__eq__", so "list.__contains__" cannot be used on a list of variables.
    # "Variable.__hash__" is implemented, though, so a set (hash-based "set.__contains__") works.
    stop_nodes = set(inputs)

    def predecessors(node: T_NODE) -> Sequence[T_NODE]:
        if node in stop_nodes:
            return []
        if isinstance(node, T_VARIABLE):
            if node.creator is None:
                return []
            return [node.creator]
        return node.inputs

    ordered = []  # type: List[T_FUNCTION]
    # Entries are (node, successor that pushed it); nodes taken from ``outputs`` have no successor.
    stack = [(node, None) for node in outputs]  # type: List[Tuple[T_NODE, T_NODE]]
    # node -> number of predecessors not yet emitted; a node becomes a key on its first visit.
    remaining = {}  # type: Dict[T_NODE, int]

    while stack:
        current, successor = stack.pop()

        if current in remaining:
            # Second visit: every pushed predecessor must have been emitted by now, otherwise we looped back.
            if remaining[current] != 0:
                raise CyclicGraphError(
                    "[ChainerConverter] Cycles are detected, but ChainerConverter cannot convert cyclic graph"
                )
            if isinstance(current, T_FUNCTION):
                ordered.append(current)
            if successor is not None:
                remaining[successor] -= 1
            continue

        # First visit: leave the node on the stack underneath its predecessors so it is revisited after them.
        stack.append((current, successor))
        upstream = predecessors(current)
        remaining[current] = 0
        for prev in upstream:
            # Unseen nodes (default 1) and still-pending nodes are pushed; already-emitted ones (count 0) are skipped.
            if remaining.get(prev, 1) > 0:
                remaining[current] += 1
                stack.append((prev, current))

    return ordered
def _listup_operations(inputs: Sequence[T_NODE], outputs: Sequence[T_NODE]):
    """
    Topologically sort the operations of a TensorFlow graph.

    Starting from ``outputs``, the graph is walked backwards with an explicit stack: a ``tf.Tensor`` leads to the op
    that produces it (``tensor.op``), any other node leads to its ``inputs``. Nodes contained in ``inputs`` are not
    expanded.

    Args:
        inputs: nodes at which the backward walk stops.
        outputs: nodes the walk starts from.

    Returns:
        The visited ``tf.Operation`` objects, each placed after the operations that produce its inputs.

    Raises:
        CyclicGraphError: when a node is popped again while some of its predecessors are still pending (a cycle).
            The offending node is logged through ``console.debug`` first.
    """

    def upstream_of(node: T_NODE) -> Sequence[T_NODE]:
        if node in inputs:
            return []
        if isinstance(node, tf.Tensor):
            return [node.op]
        return node.inputs

    ordered_ops = []  # type: List[tf.Operation]
    # Entries are (node, the node that requested it); nodes from ``outputs`` have no requester.
    work = [(node, None) for node in outputs]  # type: List[Tuple[T_NODE, T_NODE]]
    # node -> number of predecessors still to be emitted; set on a node's first visit.
    waiting_on = {}  # type: Dict[T_NODE, int]

    while work:
        node, requester = work.pop()

        if node not in waiting_on:
            # First visit: re-queue the node underneath its predecessors so it is popped again after them.
            work.append((node, requester))
            upstream = upstream_of(node)
            waiting_on[node] = 0
            for dep in upstream:
                # Push predecessors that are unseen (default 1) or still pending; skip already-emitted ones (0).
                if waiting_on.get(dep, 1) > 0:
                    waiting_on[node] += 1
                    work.append((dep, node))
            continue

        if waiting_on[node] != 0:
            console.debug("[TensorFlowConverter] Cycle is detected in computation graph")
            console.debug("cycle starting node:")
            console.debug(node)
            raise CyclicGraphError(
                "[TensorFlowConverter] Cycles are detected, but TensorFlowConverter cannot convert cyclic graph"
            )

        # All predecessors emitted: emit this node and release the node that was waiting on it.
        if isinstance(node, tf.Operation):
            ordered_ops.append(node)
        if requester is not None:
            waiting_on[requester] -= 1

    return ordered_ops
def _listup_functions(inputs: Sequence[T_NODE], outputs: Sequence[T_NODE]):
    """
    Topologically sort the functions of a chainer computation graph made of ``VariableNode`` / ``Function`` objects.

    The graph is walked backwards from ``outputs`` using an explicit stack: a ``VariableNode`` leads to its
    ``creator`` (if any), any other node leads to its ``inputs``. Nodes contained in ``inputs`` are not expanded.

    Args:
        inputs: nodes at which the backward walk stops.
        outputs: nodes the walk starts from.

    Returns:
        The visited ``Function`` nodes, each placed after the functions that produce its inputs.

    Raises:
        CyclicGraphError: when a node is popped again while some of its predecessors are still pending (a cycle).
    """
    # Hash-based membership: O(1) per visited node instead of rescanning ``inputs`` every time, and ``__eq__`` is
    # only consulted on hash collisions (list.__contains__ calls it on every element). Nodes already have to be
    # hashable because they are used as ``dependency_count`` keys below.
    input_set = set(inputs)

    def get_prev_nodes(node: T_NODE) -> Sequence[T_NODE]:
        if node in input_set:
            return []
        elif isinstance(node, VariableNode):
            return [] if node.creator is None else [node.creator]
        else:
            return node.inputs

    result = []  # type: List[Function]
    # Entries are (node, successor that pushed it); nodes from ``outputs`` have no successor.
    stack = [(node, None) for node in outputs]  # type: List[Tuple[T_NODE, T_NODE]]
    # node -> number of predecessors not yet emitted; a node becomes a key on its first visit.
    dependency_count = {}  # type: Dict[T_NODE, int]

    while stack:
        node_from, node_to = stack.pop()

        if node_from not in dependency_count:
            # First visit: put the node back underneath its predecessors so it is popped again after them.
            stack.append((node_from, node_to))
            prev_nodes = get_prev_nodes(node_from)
            dependency_count[node_from] = 0
            for prev_node in prev_nodes:
                # Push predecessors that are unseen (default 1) or still pending; skip already-emitted ones (0).
                if dependency_count.get(prev_node, 1) > 0:
                    dependency_count[node_from] += 1
                    stack.append((prev_node, node_from))

        elif dependency_count[node_from] == 0:
            # All predecessors emitted: emit this node and release the successor waiting on it.
            if isinstance(node_from, Function):
                result.append(node_from)
            if node_to is not None:
                dependency_count[node_to] -= 1

        else:
            raise CyclicGraphError("[ChainerConverter] Cycles are detected, but ChainerConverter cannot convert cyclic graph")

    return result
def _listup_functions(graph: IGraphProto) -> Sequence[INodeProto]:
    """
    Topologically sort the nodes of an ONNX graph.

    The graph is walked backwards from the names in ``graph.output`` using an explicit stack: a value name depends
    on the node that produces it, and a node depends on the value names it consumes. The walk stops at names listed
    in ``graph.input`` and at names that no node produces.

    Args:
        graph: the graph whose ``node``, ``input`` and ``output`` fields are read.

    Returns:
        The node protos reachable from the graph outputs, each placed after the nodes that produce its inputs.

    Raises:
        CyclicGraphError: when a node is popped again while some of its predecessors are still pending (a cycle).
    """

    class Container:
        """
        Proto objects are not hashable; this wrapper makes a node proto usable as a dict key.
        """

        def __init__(self, proto: INodeProto):
            self.proto = proto
            # Hash on the output names rather than ``proto.name``: node names are optional in ONNX and often empty,
            # which put every unnamed node into the same hash bucket (each dict lookup then deep-compared protos).
            # Protos that compare equal have equal outputs, so this stays consistent with ``__eq__``.
            self._hash = hash(tuple(proto.output))

        def __hash__(self):
            return self._hash

        def __eq__(self, other):
            return isinstance(other, Container) and self.proto == other.proto

    # ``graph.input`` holds ValueInfoProto messages, so value names must be compared against their ``name`` fields;
    # testing a str (or a Container) against the messages themselves never matches.
    input_names = {value_info.name for value_info in graph.input}

    # value name -> wrapper of the node producing it (one wrapper per node, shared by all of that node's outputs)
    creator_map = {}  # type: Dict[str, Container]
    for proto in graph.node:
        container = Container(proto)
        for name in proto.output:
            creator_map[name] = container

    def get_prev_nodes(node: Union[Container, str]) -> Sequence[Union[Container, str]]:
        if isinstance(node, Container):
            # A node depends on the value names it consumes.
            return node.proto.input
        if node in input_names:
            return []
        # A value name depends on the node producing it, if any.
        return [creator_map[node]] if node in creator_map else []

    result = []  # type: List[Container]
    # Entries are (node, successor that pushed it); graph outputs have no successor.
    stack = [
        (value_info.name, None) for value_info in graph.output
    ]  # type: List[Tuple[Union[Container, str], Union[Container, str]]]
    # node -> number of predecessors not yet emitted; a node becomes a key on its first visit.
    dependency_count = {}  # type: Dict[Union[Container, str], int]

    while stack:
        node_from, node_to = stack.pop()

        if node_from not in dependency_count:
            # First visit: put the node back underneath its predecessors so it is popped again after them.
            stack.append((node_from, node_to))
            prev_nodes = get_prev_nodes(node_from)
            dependency_count[node_from] = 0
            for prev_node in prev_nodes:
                # Push predecessors that are unseen (default 1) or still pending; skip already-emitted ones (0).
                if dependency_count.get(prev_node, 1) > 0:
                    dependency_count[node_from] += 1
                    stack.append((prev_node, node_from))

        elif dependency_count[node_from] == 0:
            # All predecessors emitted: emit this node and release the successor waiting on it.
            if isinstance(node_from, Container):
                result.append(node_from)
            if node_to is not None:
                dependency_count[node_to] -= 1

        else:
            raise CyclicGraphError("[ONNXConverter] Cycles are detected")

    return [r.proto for r in result]