Example #1
0
    def downgrade(self):
        """
        Downgrade hypergraph diagram to :class:`discopy.rigid.Diagram`.

        The diagram is first made progressive, then encoded as a graph with
        one node per port (port ``i`` is joined to port ``bijection[i]``),
        one ``"box"`` node per box (non-rigid boxes are rebuilt as
        :class:`rigid.Box` with the same name, types, dagger flag and data)
        and one ``"box"`` node per scalar spider, placed after all boxes.
        The graph is then read back with ``drawing.nx2diagram``.

        Examples
        --------
        >>> x = Ty('x')
        >>> v = Box('v', Ty(), x @ x)
        >>> print((v >> Swap(x, x) >> v[::-1]).downgrade())
        v >> Swap(x, x) >> v[::-1]
        >>> print((Id(x) @ Swap(x, x) >> v[::-1] @ Id(x)).downgrade())
        Id(x) @ Swap(x, x) >> v[::-1] @ Id(x)
        """
        diagram = self.make_progressive()
        # Read ``ports`` once: it is accessed as an attribute, and when it is
        # a computed property, indexing it inside the comprehension below
        # would rebuild the full port list for every edge endpoint.
        ports = diagram.ports
        graph = Graph()
        graph.add_nodes_from(ports)
        graph.add_edges_from([(ports[i], ports[j])
                              for i, j in enumerate(diagram.bijection)])
        graph.add_nodes_from([
            Node("box",
                 depth=depth,
                 box=box if isinstance(box, rigid.Box) else rigid.Box(
                     box.name,
                     box.dom,
                     box.cod,
                     _dagger=box.is_dagger,
                     data=box.data)) for depth, box in enumerate(diagram.boxes)
        ])
        # Scalar spiders become (0, 0) rigid spiders, layered after the boxes.
        n_boxes = len(diagram.boxes)
        graph.add_nodes_from([
            Node("box",
                 depth=n_boxes + i,
                 box=rigid.Spider(0, 0, diagram.spider_types[s]))
            for i, s in enumerate(diagram.scalar_spiders)
        ])
        return drawing.nx2diagram(graph, rigid.Ty, rigid.Id)
Example #2
0
    def ports(self):
        """
        The ports in a diagram.

        Ports are listed in this order: one ``"input"`` node per object of
        ``self.dom``; then, for each box by depth, its ``"dom"`` nodes
        followed by its ``"cod"`` nodes; then one ``"output"`` node per
        object of ``self.cod``.

        Examples
        --------
        >>> x, y, z = types("x y z")
        >>> f, g = Box('f', x, y @ y), Box('g', y @ y, z)
        >>> for port in (f >> g).ports: print(port)
        Node('input', i=0, obj=Ob('x'))
        Node('dom', depth=0, i=0, obj=Ob('x'))
        Node('cod', depth=0, i=0, obj=Ob('y'))
        Node('cod', depth=0, i=1, obj=Ob('y'))
        Node('dom', depth=1, i=0, obj=Ob('y'))
        Node('dom', depth=1, i=1, obj=Ob('y'))
        Node('cod', depth=1, i=0, obj=Ob('z'))
        Node('output', i=0, obj=Ob('z'))
        """
        inputs = [
            Node("input", i=i, obj=obj) for i, obj in enumerate(self.dom)]
        # A single nested comprehension flattens in linear time;
        # ``sum(lists, [])`` re-copies the accumulated list at every step.
        box_ports = [
            Node(kind, depth=depth, i=i, obj=obj)
            for depth, box in enumerate(self.boxes)
            for kind, typ in (("dom", box.dom), ("cod", box.cod))
            for i, obj in enumerate(typ)]
        outputs = [
            Node("output", i=i, obj=obj) for i, obj in enumerate(self.cod)]
        return inputs + box_ports + outputs
Example #3
0
    def draw(self, seed=None, k=.25, path=None):
        """
        Draw a hypergraph diagram.

        Node positions come from :meth:`spring_layout`. The ports of every
        box that is not a :class:`rigid.Spider` are then shifted by +0.25
        (``"dom"``) or -0.25 (``"cod"``) in y and spread horizontally around
        the box. Spider and box nodes are drawn as labelled white circles;
        all other nodes are drawn with size 0.

        Parameters
        ----------
        seed : int, optional
            Seed forwarded to :meth:`spring_layout`.
        k : float, optional
            Forwarded to :meth:`spring_layout`.
        path : str, optional
            If given, the figure is saved to this path and closed; otherwise
            it is shown with ``plt.show()``.

        Examples
        --------
        >>> x, y, z = types('x y z')
        >>> f = Box('f', x, y @ z)
        >>> f.draw(
        ...     path='docs/_static/imgs/hypergraph/box.png', seed=42)

        .. image:: ../_static/imgs/hypergraph/box.png
            :align: center

        >>> (Spider(2, 2, x) >> f @ Id(x)).draw(
        ...     path='docs/_static/imgs/hypergraph/diagram.png', seed=42)

        .. image:: ../_static/imgs/hypergraph/diagram.png
            :align: center
        """
        graph, pos = self.spring_layout(seed=seed, k=k)
        for i, (dom_wires, cod_wires) in enumerate(self.box_wires):
            # The box position is not modified in this loop, so read it once.
            box_x, box_y = pos[Node("box", i=i)]
            # Ports of spider boxes stay on the spider node itself.
            offset_ports = not isinstance(self.boxes[i], rigid.Spider)
            for kind, wires in [("dom", dom_wires), ("cod", cod_wires)]:
                for j, _ in enumerate(wires):
                    x, y = box_x, box_y
                    if offset_ports:
                        y += .25 if kind == "dom" else -.25
                        # Centre the row of ports on the box.
                        x -= .25 * ((len(wires) - 1) / 2 - j)
                    pos[Node(kind, i=i, j=j)] = x, y
        labels = {
            node: self.spider_types[node.i] if node.kind == "spider" else
            self.boxes[node.i].name if node.kind == "box" else ""
            for node in graph.nodes
        }
        nodelist = list(graph.nodes)
        node_size = [
            300 if node.kind in ["spider", "box"] else 0 for node in nodelist
        ]
        draw_networkx(graph,
                      pos=pos,
                      labels=labels,
                      nodelist=nodelist,
                      node_size=node_size,
                      node_color="white",
                      edgecolors="black")
        if path is not None:
            plt.savefig(path)
            plt.close()
        else:
            # Only show when not saving: after close(), show() would either do
            # nothing or display other, unrelated open figures.
            plt.show()
Example #4
0
def draw_controlled_gate(backend, positions, node, **params):
    """
    Draws a :class:`discopy.quantum.gates.Controlled` gate.

    The control is drawn as a black dot halfway along wire 0 of ``node.box``.
    The controlled gate is drawn halfway along wire 1, and the two are joined
    by a wire. A controlled ``X`` is drawn as a circled plus sign; any other
    controlled gate is drawn with :func:`draw_box`.

    Parameters
    ----------
    backend
        Drawing backend providing ``draw_wire`` and ``draw_node``.
    positions : dict
        Maps port nodes ``Node(kind, obj=..., i=..., depth=...)`` to their
        (x, y) coordinates.
    node : Node
        The ``"box"`` node of the controlled gate.
    **params
        Drawing options. ``nodesize`` (default 1) scales the drawn nodes; all
        options are forwarded to :func:`draw_box`.

    Returns
    -------
    The backend (the one returned by :func:`draw_box` when it is called).
    """
    box, depth = node.box, node.depth
    nodesize = params.get("nodesize", 1)

    # Control wire (index 0): the control dot sits halfway along it.
    dom = Node("dom", obj=box.dom[0], i=0, depth=depth)
    cod = Node("cod", obj=box.cod[0], i=0, depth=depth)
    middle = positions[dom][0], (positions[dom][1] + positions[cod][1]) / 2

    controlled_box = add_drawing_attributes(box.controlled.downgrade())
    controlled = Node("box", box=controlled_box, depth=depth)

    # Target wire (index 1): port nodes are keyed by the object of their own
    # wire (obj=box.dom[i] for port i), so use index 1 for the object too.
    c_dom = Node("dom", obj=box.dom[1], i=1, depth=depth)
    c_cod = Node("cod", obj=box.cod[1], i=1, depth=depth)
    target = (positions[c_dom][0],
              (positions[c_dom][1] + positions[c_cod][1]) / 2)

    if controlled_box.name == "X":  # CX gets drawn as a circled plus sign.
        backend.draw_wire(positions[c_dom], positions[c_cod])
        # NOTE(review): the tiny offset presumably keeps the circle and the
        # plus from landing on exactly the same point in the backend — confirm.
        eps = 1e-10
        perturbed_target = target[0], target[1] + eps
        backend.draw_node(*perturbed_target,
                          shape="circle",
                          color="white",
                          edgecolor="black",
                          nodesize=2 * nodesize)
        backend.draw_node(*target,
                          shape="plus",
                          nodesize=2 * nodesize)
        left_of_target = target
    else:
        left_of_target = target[0] - .25, target[1]
        # Presumably draw_box looks up the controlled box's single port under
        # the ``dom``/``cod`` keys; point them at the target wire — confirm.
        fake_positions = {
            controlled: target,
            dom: positions[c_dom],
            cod: positions[c_cod]
        }
        backend = draw_box(backend, fake_positions, controlled, **params)
    backend.draw_wire(positions[dom], positions[cod])
    # TODO change bend_in and bend_out for tikz backend
    backend.draw_wire(middle, left_of_target, bend_in=True, bend_out=True)
    backend.draw_node(*middle,
                      color="black",
                      shape="circle",
                      nodesize=nodesize)
    return backend
Example #5
0
def draw_discard(backend, positions, node, **params):
    """
    Draws a :class:`discopy.quantum.circuit.Discard` box.

    Three horizontal strokes are drawn. The first runs from 0.25 left of the
    first ``"dom"`` port to 0.25 right of the last one, at the box node's y
    plus 0.25. Each following stroke is 0.1 lower in y and shortened by 0.1
    at both ends. ``params`` is accepted for a uniform drawing-callback
    signature and is not used.
    """
    box, depth = node.box, node.depth
    last = len(box.dom) - 1
    first_port = Node("dom", obj=box.dom[0], i=0, depth=depth)
    last_port = Node("dom", obj=box.dom[last], i=last, depth=depth)
    x_start = positions[first_port][0] - .25
    x_end = positions[last_port][0] + .25
    y_top = positions[node][1] + .25
    for step in range(3):
        shift = .1 * step
        y = y_top - shift
        backend.draw_wire((x_start + shift, y), (x_end - shift, y))
    return backend
Example #6
0
def draw_brakets(backend, positions, node, **params):
    """
    Draws a :class:`discopy.quantum.gates.Bra` or
    :class:`discopy.quantum.gates.Ket` box.

    A box with a non-empty domain is treated as a bra, otherwise as a ket.
    For each digit in ``box._digits``, a triangle is drawn on the matching
    wire (``"dom"`` ports for a bra, ``"cod"`` ports for a ket). The triangle
    has base from x - 0.25 to x + 0.25 and apex at y - 0.5 (bra) or y + 0.5
    (ket). The digit is written at y - 0.25 (bra) or y + 0.2 (ket), using
    the optional ``fontsize`` from ``params``.
    """
    box, depth = node.box, node.depth
    is_bra = len(box.dom) > 0
    if is_bra:
        kind, wire_types, apex_dy, text_dy = "dom", box.dom, -.5, -.25
    else:
        kind, wire_types, apex_dy, text_dy = "cod", box.cod, .5, .2
    fontsize = params.get('fontsize', None)
    for index, digit in enumerate(box._digits):
        wire = Node(kind, obj=wire_types[index], depth=depth, i=index)
        center = positions[wire]
        x, y = center[0], center[1]
        backend.draw_polygon((x - .25, y),
                             (x + .25, y),
                             (x, y + apex_dy),
                             color=box.color)
        backend.draw_text(digit,
                          x,
                          y + text_dy,
                          ha='center',
                          va='center',
                          fontsize=fontsize)
    return backend
Example #7
0
 def spring_layout(self, seed=None, k=None):
     """
     Computes planar positions using a force-directed layout.

     Builds a graph with one node per spider, per box, per box port and per
     input/output wire. Inputs are fixed at ``(i, height)`` and outputs at
     ``(i, 0)``. Boxes (with their ports) and spiders start at random points
     with x in ``[-width / 2, width / 2]`` and y in ``[0, height]``. The
     layout is then relaxed with ``spring_layout``, keeping inputs and
     outputs fixed.

     Parameters
     ----------
     seed : int, optional
         When given, the initial placement is drawn from a private
         ``random.Random(seed)``, so the global ``random`` state is left
         untouched. When ``None``, the module-level ``random`` functions are
         used. The seed is also forwarded to ``spring_layout``.
     k : float, optional
         Forwarded to ``spring_layout``.

     Returns
     -------
     graph, pos
         The layout graph and a dict mapping its nodes to positions.
     """
     # random.Random(seed) yields the same sequence as random.seed(seed)
     # followed by module-level calls, without reseeding the shared RNG.
     rng = random if seed is None else random.Random(seed)
     height = len(self.boxes) + self.n_spiders
     width = max(len(self.dom), len(self.cod))
     graph, pos = Graph(), {}
     graph.add_nodes_from(
         Node("spider", i=i) for i in range(self.n_spiders))
     graph.add_edges_from((Node("input", i=i), Node("spider", i=j))
                          for i, j in enumerate(self.wires[:len(self.dom)]))
     for i, (dom_wires, cod_wires) in enumerate(self.box_wires):
         box_node = Node("box", i=i)
         graph.add_node(box_node)
         for kind, wires in [("dom", dom_wires), ("cod", cod_wires)]:
             for j, spider in enumerate(wires):
                 port_node = Node(kind, i=i, j=j)
                 graph.add_edge(box_node, port_node)
                 graph.add_edge(port_node, Node("spider", i=spider))
     # Output wires are the trailing len(self.cod) entries of self.wires; an
     # explicit start index keeps this empty when cod is empty (wires[-0:]
     # would select every wire).
     graph.add_edges_from((Node("output", i=i), Node("spider", i=j))
                          for i, j in enumerate(self.wires[len(self.wires) -
                                                           len(self.cod):]))
     for i in range(len(self.dom)):
         pos[Node("input", i=i)] = (i, height)
     for i, (dom_wires, cod_wires) in enumerate(self.box_wires):
         box_node = Node("box", i=i)
         pos[box_node] = (rng.uniform(-width / 2, width / 2),
                          rng.uniform(0, height))
         # Ports start on top of their box.
         for kind, wires in [("dom", dom_wires), ("cod", cod_wires)]:
             for j, _ in enumerate(wires):
                 pos[Node(kind, i=i, j=j)] = pos[box_node]
     for i in range(self.n_spiders):
         pos[Node("spider", i=i)] = (rng.uniform(-width / 2, width / 2),
                                     rng.uniform(0, height))
     for i in range(len(self.cod)):
         pos[Node("output", i=i)] = (i, 0)
     # None (rather than an empty list) when there are no inputs or outputs.
     fixed = ([Node("input", i=i) for i in range(len(self.dom))]
              + [Node("output", i=i) for i in range(len(self.cod))]) or None
     pos = spring_layout(graph, pos=pos, fixed=fixed, k=k, seed=seed)
     return graph, pos