示例#1
0
def test_toposort_overlap():
    """Toposort of a DAG where one tuple is reachable by two paths is valid."""
    # ``shared`` is reachable from ``root`` directly and through ``middle``.
    shared = (1, 2)
    middle = (shared, 2)
    root = (middle, shared, 2)

    ordering = [*toposort(root, _succ, _incl)]
    _check_toposort(ordering, root, _succ, _incl)
示例#2
0
def test_toposort():
    """Toposort of a small nested-tuple DAG passes ``_check_toposort``."""
    left = (1, 2)
    right = (3, 4, 5)
    inner = (right, left, 6)
    root = (left, inner)

    ordering = [*toposort(root, _succ, _incl)]
    _check_toposort(ordering, root, _succ, _incl)
示例#3
0
def test_toposort_bad_include():
    """An include callback that returns ``None`` makes toposort raise ValueError."""
    left = (1, 2)
    right = (3, 4, 5)
    inner = (right, left, 6)
    root = (left, inner)

    def include_nothing(_node):
        return None

    with pytest.raises(ValueError):
        # Exhaust the generator so the error surfaces inside the context.
        for _ in toposort(root, _succ, include_nothing):
            pass
示例#4
0
def test_toposort_incl():
    """Toposort stays valid with include callbacks that special-case pairs."""

    def _is_pair(node):
        return isinstance(node, tuple) and len(node) == 2

    def _incl_nf(x):
        # Pairs get NOFOLLOW; everything else FOLLOW.
        return NOFOLLOW if _is_pair(x) else FOLLOW

    def _incl_x(x):
        # Pairs get EXCLUDE; everything else FOLLOW.
        return EXCLUDE if _is_pair(x) else FOLLOW

    pair = (1, 2)
    nested = (pair, 3)
    root = (pair, nested, 4, 5)

    for include in (_incl_nf, _incl_x):
        ordering = list(toposort(root, _succ, include))
        _check_toposort(ordering, root, _succ, include)
示例#5
0
def test_toposort_cycle():
    """A node that lists itself as a successor makes toposort raise ValueError."""

    class Q:
        def __init__(self, x, y):
            self.x, self.y = x, y

    def qsucc(node):
        return [node.x, node.y]

    looped = Q(1, 2)
    looped.y = looped  # introduce the self-cycle

    with pytest.raises(ValueError):
        list(toposort(looped, qsucc, _incl))
示例#6
0
    def convert_func(self, graph):
        """Convert a graph into a ``relay.Function``.

        The parameters of ``graph`` (converted with ``self.on_parameter``)
        become the function parameters. The nodes yielded by ``toposort``
        (restricted by ``in_graph(graph)``) that are not already in
        ``self.node_map`` are handled in one of two ways:

        - Graph constants whose graph has no parent are mapped through
          ``self.graph_map``.
        - All other nodes are bound to fresh ``relay.var`` names.

        The bound nodes form a nested chain of ``relay.Let`` expressions whose
        innermost body is the converted output.

        Args:
            graph: The graph to convert.

        Returns:
            A ``relay.Function`` whose return type is
            ``to_relay_type(graph.output.abstract)``.

        Raises:
            AssertionError: If a node to bind is neither an apply, a graph
                constant nor a constant.
        """
        # Register parameters first so ``self.ref`` can resolve them below.
        for p in graph.parameters:
            self.node_map[p] = self.on_parameter(p)

        params = [self.ref(p) for p in graph.parameters]

        # First pass: walk nodes in topological order and decide how each one
        # will be referenced. No node is converted yet.
        seq = []
        for node in toposort(graph.output, NodeVisitor(), in_graph(graph)):
            if node in self.node_map:
                continue
            elif node.is_constant_graph() and node.value.parent is None:
                # Parentless graphs are referenced through ``self.graph_map``
                # instead of being converted inline.
                # NOTE(review): assumes graph_map already holds an entry for
                # every such graph when this runs — confirm with the caller.
                self.node_map[node] = self.graph_map[node.value]
            else:
                # Fresh variable. ``self.i`` is an instance counter, so names
                # stay distinct across calls on the same converter.
                self.node_map[node] = relay.var(f"seq.{self.i}")
                self.i += 1
                seq.append(node)

        out = self.ref(graph.output)

        # Second pass: wrap ``out`` in Let bindings from the last node back to
        # the first. The first node in topological order thus becomes the
        # outermost binding, visible to all bindings nested inside it.
        for op in reversed(seq):
            var = self.node_map[op]
            if op.is_apply():
                val = self.on_apply(op)
            elif op.is_constant_graph():
                # Graph with a parent: converted recursively and bound inline.
                val = self.convert_func(op.value)
            elif op.is_constant():
                val = self.on_constant(op)
                # This forces the rebuild of constants every time they
                # are encountered since they may be shared amongst
                # multiple graphs and it causes problems otherwise.
                del self.node_map[op]
            else:
                raise AssertionError(f"Bad node for sequence: {op}")
            out = relay.Let(var, val, out)

        return relay.Function(params,
                              out,
                              ret_type=to_relay_type(graph.output.abstract))
示例#7
0
    def compile(self):
        """Translate ``self.graph`` into Python source code.

        Returns:
            A pair ``(param_names, body)``:

            - ``param_names`` is the list of parameter names.
            - ``body`` is the function body code (nested list): constant
              definitions first, then closure definitions, then the
              statements for the graph's own nodes.

        Raises:
            AssertionError: If a node is neither an apply, a graph constant
                nor a constant.
        """
        graph = self.graph

        # Parameters are registered before any other node.
        param_names = []
        for param in graph.parameters:
            self._add_node(param)
            param_names.append(self.local_ref(param))

        # Keep, in topological order, only the nodes for which ``_add_node``
        # returns a truthy value; those get statements of their own below.
        pending = [
            node
            for node in toposort(graph.output, NodeVisitor(), in_graph(graph))
            if self._add_node(node)
        ]

        statements = []
        for node in pending:
            target = self.local_ref(node)
            if node.is_apply():
                code = self.on_apply(node)
            elif node.is_constant_graph():
                code = self.on_function(node.value, node)
            elif node.is_constant():
                code = self.on_constant(node)
            else:
                raise AssertionError(f"Unsupported node: {node}")

            # When the graph output is the final statement, return it directly
            # instead of assigning it to a name.
            if node is pending[-1] and node is graph.output:
                lhs = "return"
            else:
                lhs = f"{target} ="

            # A list of lines means "setup lines, then the value expression".
            if isinstance(code, list):
                statements.extend(code[:-1])
                statements.append(f"{lhs} {code[-1]}")
            elif code:
                statements.append(f"{lhs} {code}")

        # Nothing was emitted, e.g. the function just returns a parameter.
        if not statements:
            statements.append(f"return {self.ref(graph.output)}")

        # Only non-inlinable constants need a definition; the others are
        # written inline where they are used.
        constants_code = [
            f"{cst_name} = {cst_val}"
            for cst_name, cst_val in self.const_name_to_value.items()
            if not self._is_inline_const(cst_val)
        ]
        if constants_code:
            constants_code = ["# Constants", *constants_code, ""]

        closures_code = []
        for fn_name, (fn_params, fn_body) in self.closure_name_to_code.items():
            params_text = ", ".join(fn_params)
            closures_code += [f"def {fn_name}({params_text}):", fn_body]

        return param_names, constants_code + closures_code + statements
示例#8
0
文件: python.py 项目: notoraptor/myia
    def compile(self):
        """Compile graph to a Python function code.

        Return function parameters names (list) and function body code (nested list).

        The body is laid out as:

        1. Definitions of non-inlinable constants.
        2. Closure definitions collected in ``self.closure_name_to_code``.
        3. Statements for the nodes of ``self.graph``, in topological order.

        Raises:
            AssertionError: If a node is neither an apply, a graph constant nor
                a constant.
        """
        graph = self.graph
        # Register parameters.
        param_names = []
        for p in graph.parameters:
            # A parameter should always be a local variable.
            self.node_to_name[p] = self.get_label(p)
            param_names.append(self.local_ref(p))
        seq = []
        # Register nodes. Only nodes for which ``_add_node`` returns a truthy
        # value get code generated below.
        for node in toposort(
            graph.output, NodeVisitor(), in_graph(graph), allow_cycles=True
        ):
            if self._add_node(node):
                seq.append(node)
        output = []
        # Get code for each node.
        for op in seq:
            op_name = self.local_ref(op)
            if op.is_apply():
                op_code = self.on_apply(op)
            elif op.is_constant_graph():
                op_code = self.on_function(op.value, op)
            elif op.is_constant():
                op_code = self.on_constant(op)
            else:
                raise AssertionError(f"Unsupported node: {op}")
            # If latest op is graph.output, we can write code
            # to return it immediately.
            if op is seq[-1] is graph.output:
                prefix = "return"
            else:
                prefix = f"{op_name} ="
            # Code for a node may be a list of lines of code.
            # In such case, latest code line should be only the return value.
            if isinstance(op_code, list):
                build_code = op_code[:-1]
                return_code = op_code[-1]
                output.extend(build_code)
                output.append(f"{prefix} {return_code}")
            elif op_code:
                output.append(f"{prefix} {op_code}")
        # Output may be empty, for e.g. if function just returns a parameter.
        if not output:
            # Some functions may return a value that is not reachable from
            # the function code, so self.ref() may not find the node.
            # In that case, emit code that raises an exception instead:
            # calling such a function will then fail explicitly.
            try:
                output_name = self.ref(graph.output)
            except KeyError:
                # Embed the message with ``!r`` so the generated line is always
                # a valid Python string literal, even if the node's text
                # contains quotes, backslashes or newlines.
                message = (
                    f"Unreachable: {type(graph.output).__name__} {graph.output}"
                )
                output.append(f"raise RuntimeError({message!r})")
            else:
                output.append(f"return {output_name}")

        constants_code = []
        closures_code = []
        for cst_name, cst_val in self.const_name_to_value.items():
            if not self._is_inline_const(cst_val):
                # Write only non-inlinable constants.
                # Other are inlined, hence it's useless to define them.
                constants_code.append(f"{cst_name} = {cst_val}")
        if constants_code:
            constants_code = ["# Constants"] + constants_code + [""]
        for fn_name, (fn_params, fn_body) in self.closure_name_to_code.items():
            fn_signature = f"def {fn_name}({', '.join(fn_params)}):"
            closures_code.append(fn_signature)
            closures_code.append(fn_body)
        # Body code contains constants, then closures, then function code.
        return param_names, constants_code + closures_code + output