def test_toposort_overlap():
    """Nodes reachable through several paths still get a valid ordering."""
    leaf = (1, 2)
    middle = (leaf, 2)
    root = (middle, leaf, 2)
    ordering = list(toposort(root, _succ, _incl))
    _check_toposort(ordering, root, _succ, _incl)
def test_toposort():
    """Check the ordering produced for a small DAG of nested tuples."""
    pair = (1, 2)
    triple = (3, 4, 5)
    inner = (triple, pair, 6)
    root = (pair, inner)
    ordering = list(toposort(root, _succ, _incl))
    _check_toposort(ordering, root, _succ, _incl)
def test_toposort_bad_include():
    """An include callback returning None must make toposort raise ValueError."""
    pair = (1, 2)
    triple = (3, 4, 5)
    inner = (triple, pair, 6)
    root = (pair, inner)

    def bad_include(node):
        # None is not FOLLOW, NOFOLLOW or EXCLUDE; toposort is expected
        # to reject it.
        return None

    with pytest.raises(ValueError):
        list(toposort(root, _succ, bad_include))
def test_toposort_incl():
    """Exercise toposort with NOFOLLOW and EXCLUDE include policies."""

    def _is_pair(node):
        return isinstance(node, tuple) and len(node) == 2

    def _incl_nf(node):
        # Pairs are marked NOFOLLOW; everything else is followed.
        return NOFOLLOW if _is_pair(node) else FOLLOW

    def _incl_x(node):
        # Pairs are marked EXCLUDE; everything else is followed.
        return EXCLUDE if _is_pair(node) else FOLLOW

    pair = (1, 2)
    wrapped = (pair, 3)
    root = (pair, wrapped, 4, 5)
    # Same sequence as checking each policy one after the other.
    for policy in (_incl_nf, _incl_x):
        ordering = list(toposort(root, _succ, policy))
        _check_toposort(ordering, root, _succ, policy)
def test_toposort_cycle():
    """A node that is its own successor must make toposort raise ValueError."""

    class Q:
        def __init__(self, x, y):
            self.x = x
            self.y = y

    def qsucc(node):
        return [node.x, node.y]

    root = Q(1, 2)
    # Make root point back to itself to create a cycle.
    root.y = root
    with pytest.raises(ValueError):
        list(toposort(root, qsucc, _incl))
def convert_func(self, graph):
    """Convert a graph.

    Parameters are converted with ``self.on_parameter``. Every other node
    visited from ``graph.output`` inside ``graph`` either reuses an
    existing mapping or is bound to a fresh ``seq.N`` relay variable. The
    function body is a chain of nested ``relay.Let`` expressions that
    finally evaluates to the reference for ``graph.output``.
    """
    for param in graph.parameters:
        self.node_map[param] = self.on_parameter(param)
    relay_params = [self.ref(param) for param in graph.parameters]

    pending = []
    for node in toposort(graph.output, NodeVisitor(), in_graph(graph)):
        if node in self.node_map:
            # Already mapped (e.g. one of the parameters above).
            continue
        if node.is_constant_graph() and node.value.parent is None:
            # Graph constant without a parent graph: take its entry
            # from self.graph_map instead of binding a new variable.
            self.node_map[node] = self.graph_map[node.value]
            continue
        self.node_map[node] = relay.var(f"seq.{self.i}")
        self.i += 1
        pending.append(node)

    # Must be resolved before constants are removed from node_map below.
    body = self.ref(graph.output)
    # Walk backwards so each Let wraps the expression already built for
    # the nodes that come after it in the sequence.
    for op in pending[::-1]:
        bound_var = self.node_map[op]
        if op.is_apply():
            value = self.on_apply(op)
        elif op.is_constant_graph():
            value = self.convert_func(op.value)
        elif op.is_constant():
            value = self.on_constant(op)
            # Drop the mapping so constants are rebuilt every time they
            # are encountered: they may be shared amongst multiple graphs
            # and keeping them mapped causes problems otherwise.
            del self.node_map[op]
        else:
            raise AssertionError(f"Bad node for sequence: {op}")
        body = relay.Let(bound_var, value, body)

    return relay.Function(
        relay_params, body, ret_type=to_relay_type(graph.output.abstract)
    )
def compile(self):
    """Compile graph to a Python function code.

    Return function parameters names (list) and function body code
    (nested list).

    The body consists of, in order:
    - definitions of non-inlinable constants;
    - closure definitions, each a ``def`` line followed by the closure body;
    - one statement per node newly registered from ``graph.output``.

    The body always ends with a ``return`` statement for ``graph.output``.
    """
    graph = self.graph

    # Register parameters.
    param_names = []
    for p in graph.parameters:
        self._add_node(p)
        param_names.append(self.local_ref(p))

    # Register nodes; keep, in topological order, those _add_node accepted.
    seq = []
    for node in toposort(graph.output, NodeVisitor(), in_graph(graph)):
        if self._add_node(node):
            seq.append(node)

    output = []
    # Whether a statement returning graph.output has been emitted.
    returned = False
    # Get code for each node.
    for op in seq:
        op_name = self.local_ref(op)
        if op.is_apply():
            op_code = self.on_apply(op)
        elif op.is_constant_graph():
            op_code = self.on_function(op.value, op)
        elif op.is_constant():
            op_code = self.on_constant(op)
        else:
            raise AssertionError(f"Unsupported node: {op}")

        # If latest op is graph.output, we can write code
        # to return it immediately.
        is_output = op is seq[-1] is graph.output
        prefix = "return" if is_output else f"{op_name} ="

        # Code for a node may be a list of lines of code.
        # In such case, latest code line should be only the return value.
        if isinstance(op_code, list):
            output.extend(op_code[:-1])
            output.append(f"{prefix} {op_code[-1]}")
        elif op_code:
            output.append(f"{prefix} {op_code}")
        else:
            # No code emitted for this node.
            continue
        if is_output:
            returned = True

    # Emit an explicit return when no statement returned graph.output.
    # This covers an empty body (e.g. the function just returns a
    # parameter). It also covers bodies where graph.output emitted no code
    # of its own, or was not the last registered node; such bodies used to
    # end with no return statement at all.
    if not returned:
        output.append(f"return {self.ref(graph.output)}")

    # Write only non-inlinable constants: the others are inlined where
    # used, hence it's useless to define them.
    constants_code = [
        f"{cst_name} = {cst_val}"
        for cst_name, cst_val in self.const_name_to_value.items()
        if not self._is_inline_const(cst_val)
    ]
    if constants_code:
        constants_code = ["# Constants"] + constants_code + [""]

    closures_code = []
    for fn_name, (fn_params, fn_body) in self.closure_name_to_code.items():
        closures_code.append(f"def {fn_name}({', '.join(fn_params)}):")
        closures_code.append(fn_body)

    # Body code contains constants, then closures, then function code.
    return param_names, constants_code + closures_code + output
def compile(self):
    """Compile graph to a Python function code.

    Return function parameters names (list) and function body code
    (nested list).

    The body consists of, in order:
    - definitions of non-inlinable constants;
    - closure definitions, each a ``def`` line followed by the closure body;
    - one statement per node newly registered from ``graph.output``.

    Cycles are allowed during the topological traversal. The body ends
    with a ``return`` of ``graph.output``, or with a ``raise`` statement
    when ``graph.output`` cannot be referenced.
    """
    graph = self.graph

    # Register parameters.
    param_names = []
    for p in graph.parameters:
        # A parameter should always be a local variable.
        self.node_to_name[p] = self.get_label(p)
        param_names.append(self.local_ref(p))

    # Register nodes; keep, in topological order, those _add_node accepted.
    seq = []
    for node in toposort(
        graph.output, NodeVisitor(), in_graph(graph), allow_cycles=True
    ):
        if self._add_node(node):
            seq.append(node)

    output = []
    # Whether a statement returning graph.output has been emitted.
    returned = False
    # Get code for each node.
    for op in seq:
        op_name = self.local_ref(op)
        if op.is_apply():
            op_code = self.on_apply(op)
        elif op.is_constant_graph():
            op_code = self.on_function(op.value, op)
        elif op.is_constant():
            op_code = self.on_constant(op)
        else:
            raise AssertionError(f"Unsupported node: {op}")

        # If latest op is graph.output, we can write code
        # to return it immediately.
        is_output = op is seq[-1] is graph.output
        prefix = "return" if is_output else f"{op_name} ="

        # Code for a node may be a list of lines of code.
        # In such case, latest code line should be only the return value.
        if isinstance(op_code, list):
            output.extend(op_code[:-1])
            output.append(f"{prefix} {op_code[-1]}")
        elif op_code:
            output.append(f"{prefix} {op_code}")
        else:
            # No code emitted for this node.
            continue
        if is_output:
            returned = True

    # Emit an explicit return when no statement returned graph.output.
    # This covers an empty body (e.g. the function just returns a
    # parameter). It also covers bodies where graph.output emitted no code
    # of its own, or was not the last registered node; such bodies used to
    # end with no return statement at all.
    if not returned:
        # Some functions may return a value that is not reachable from the
        # function code, so self.ref() may not find the node. In that case
        # emit a raising statement: calling such a function then fails
        # explicitly instead of silently.
        try:
            output_name = self.ref(graph.output)
        except KeyError:
            message = f"Unreachable: {type(graph.output).__name__} {graph.output}"
            # repr() escapes quotes/backslashes in the node's text so the
            # generated line is always valid Python.
            output.append(f"raise RuntimeError({message!r})")
        else:
            output.append(f"return {output_name}")

    # Write only non-inlinable constants: the others are inlined where
    # used, hence it's useless to define them.
    constants_code = [
        f"{cst_name} = {cst_val}"
        for cst_name, cst_val in self.const_name_to_value.items()
        if not self._is_inline_const(cst_val)
    ]
    if constants_code:
        constants_code = ["# Constants"] + constants_code + [""]

    closures_code = []
    for fn_name, (fn_params, fn_body) in self.closure_name_to_code.items():
        closures_code.append(f"def {fn_name}({', '.join(fn_params)}):")
        closures_code.append(fn_body)

    # Body code contains constants, then closures, then function code.
    return param_names, constants_code + closures_code + output