Beispiel #1
0
    def test_resnet_blocks(self):
        """Render ResNet-101 with block-level folding transforms applied.

        Builds ``torchvision.models.resnet101()`` and traces it with a single
        1x3x224x224 zero tensor. It folds Conv/BatchNorm/ReLU runs and
        bottleneck/residual blocks into single nodes, collapses repeated
        blocks, and saves the result as a PDF under ``OUTPUT_DIR``.

        There are no explicit assertions: the test passes as long as graph
        construction and saving do not raise. ``OUTPUT_DIR`` is removed
        afterwards, even when building or saving the graph fails.
        """
        # Resnet101
        model = torchvision.models.resnet101()

        # Transforms are applied in list order: later patterns match op names
        # ("ConvBnRelu", "ConvBn") produced by earlier folds. The longer
        # Conv > BN > ReLU pattern therefore comes before Conv > BN, so
        # ReLU-terminated runs are folded as a whole.
        transforms = [
            # Fold Conv, BN, RELU layers into one
            ht.Fold("Conv > BatchNormalization > Relu", "ConvBnRelu"),
            # Fold Conv, BN layers together
            ht.Fold("Conv > BatchNormalization", "ConvBn"),
            # Fold bottleneck blocks: main branch in parallel with a ConvBn
            # shortcut (presumably the projection/downsample path — confirm),
            # joined by Add then Relu.
            ht.Fold(
                """
                ((ConvBnRelu > ConvBnRelu > ConvBn) | ConvBn) > Add > Relu
                """, "BottleneckBlock", "Bottleneck Block"),
            # Fold residual blocks: the same main branch without a ConvBn
            # shortcut, i.e. whatever Add/Relu chains remain after the
            # bottleneck fold above.
            ht.Fold("""ConvBnRelu > ConvBnRelu > ConvBn > Add > Relu""",
                    "ResBlock", "Residual Block"),
            # Fold repeated blocks
            ht.FoldDuplicates(),
        ]

        try:
            # Display graph using the transforms above
            g = hl.build_graph(model,
                               torch.zeros([1, 3, 224, 224]),
                               transforms=transforms)
            g.save(os.path.join(OUTPUT_DIR, "pytorch_resnet_blocks.pdf"))
        finally:
            # Clean up even if building/saving failed. ignore_errors keeps a
            # missing OUTPUT_DIR (e.g. failure before anything was written)
            # from masking the original exception.
            shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
Beispiel #2
0
    def test_fold(self):
        """Fold the chain ``a > b`` into one node and check its consumer.

        Graph under test: ``a -> b``, with ``b`` fanning out to ``c`` and
        ``d``. After applying ``Fold("a > b", "ab")``, the first incoming
        node of ``c`` must have op "ab".
        """
        graph = hl.Graph()
        node_ids = "abcd"
        # Each node uses its single-letter id as uid, name and op.
        nodes = {uid: hl.Node(uid=uid, name=uid, op=uid) for uid in node_ids}
        for uid in node_ids:
            graph.add_node(nodes[uid])
        for src, dst in (("a", "b"), ("b", "c"), ("b", "d")):
            graph.add_edge(nodes[src], nodes[dst])

        fold = ht.Fold("a > b", "ab")
        graph = fold.apply(graph)
        predecessors = graph.incoming(graph["c"])
        self.assertEqual(predecessors[0].op, "ab")
Beispiel #3
0
    def test_parallel_fold(self):
        """Fold a pattern with two parallel branches into a single node.

        Graph under test: ``a -> b -> c -> e`` and ``a -> d -> e``. The
        pattern ``((b > c) | d) > e`` covers both branches plus their join
        at ``e``. After folding, the first outgoing node of ``a`` must have
        op "bcde".
        """
        graph = hl.Graph()
        node_ids = "abcde"
        # Each node uses its single-letter id as uid, name and op.
        nodes = {uid: hl.Node(uid=uid, name=uid, op=uid) for uid in node_ids}
        for uid in node_ids:
            graph.add_node(nodes[uid])
        edges = [("a", "b"), ("b", "c"), ("a", "d"), ("c", "e"), ("d", "e")]
        for src, dst in edges:
            graph.add_edge(nodes[src], nodes[dst])

        fold = ht.Fold("((b > c) | d) > e", "bcde")
        graph = fold.apply(graph)
        successors = graph.outgoing(graph["a"])
        self.assertEqual(successors[0].op, "bcde")