Example #1
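All six excerpts in this listing come from PyTorch's FX gradual-typing tests and omit their file-level imports. A hedged reconstruction of what they assume (module paths follow the torch.fx layout and may shift between releases):

    import sympy
    import torch
    from torch.fx import GraphModule, symbolic_trace
    from torch.fx.tensor_type import Dyn, TensorType, is_consistent
    from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker, Refine
    from torch.fx.experimental.refinement_types import Equality
    from torch.fx.experimental.rewriter import RewritingTracer
    from torch.fx.experimental.unify_refinements import infer_symbolic_types
    from torch.fx.passes.shape_prop import ShapeProp
    from torchvision.models import resnet50
    from torchvision.models.resnet import conv3x3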
    def test_symbolic_add_with_broadcast(self):
        class M(torch.nn.Module):
            def forward(self, x: TensorType((1, 2, 3, Dyn)),
                        y: TensorType((2, 3, 4))):
                return torch.add(x, y)

        module = M()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        tc = GraphTypeChecker({}, symbolic_traced)
        tc.type_check()
        infer_symbolic_types(symbolic_traced)
        r = Refine(symbolic_traced)
        r.refine()

        assert r.constraints == [
            Equality(1, 1), Equality(2, 2),
            Equality(3, 3)
        ]
        # note that there is no equality constraint between Dyn and 4 because
        # under broadcasting Dyn could be either 4 or 1

        infer_symbolic_types(symbolic_traced)

        expected_ph_types = [
            TensorType((1, 2, 3, sympy.symbols('~0'))),
            TensorType((2, 3, 4)),
            TensorType((1, 2, 3, sympy.symbols('~1'))),
            TensorType((1, 2, 3, sympy.symbols('~1')))
        ]
        expected_iter = iter(expected_ph_types)

        for n in symbolic_traced.graph.nodes:
            assert n.type == next(expected_iter)
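Why the refinement above emits no constraint for the last dimension: under broadcasting, the Dyn slot could be either 4 or 1 and the add still succeeds. A minimal runtime check (illustrative only, not part of the test):

    import torch

    y = torch.randn(2, 3, 4)
    for d in (4, 1):  # the two values Dyn may take
        x = torch.randn(1, 2, 3, d)
        # both choices broadcast to the same output shape
        assert torch.add(x, y).shape == (1, 2, 3, 4)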
Example #2
    def test_type_check_batch_norm_symbolic(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes):
                super(BasicBlock, self).__init__()
                norm_layer = torch.nn.BatchNorm2d
                self.bn1 = norm_layer(planes)

            def forward(self, x: Dyn):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        tc.type_check()

        infer_symbolic_types(traced)

        my_types = iter([
            TensorType((2, 2, sympy.symbols('~7'), 4)),
            TensorType((2, 2, sympy.symbols('~7'), 4)),
            TensorType((2, 2, sympy.symbols('~7'), 4)),
            TensorType((2, 2, sympy.symbols('~7'), 4))
        ])

        for n in graph.nodes:
            assert n.type == next(my_types)
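The assertions above compare types containing sympy symbols; this works because sympy symbols compare equal by name. A short sanity check (illustrative):

    import sympy

    # symbols with the same name compare equal, so '~7' in the inferred
    # type matches '~7' in the expected type
    assert sympy.symbols('~7') == sympy.symbols('~7')
    assert sympy.symbols('~7') != sympy.symbols('~8')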
Example #3
    def test_type_check_conv2D_types(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes, stride=1):
                super(BasicBlock, self).__init__()
                norm_layer = torch.nn.BatchNorm2d
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = norm_layer(planes)

            def forward(self, x: Dyn):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        tc.type_check()
        infer_symbolic_types(traced)

        for n in traced.graph.nodes:
            if n.op == 'call_module':
                assert isinstance(n.type.__args__[2], sympy.floor)
                assert isinstance(n.type.__args__[3], sympy.floor)
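The floor expressions asserted above come from the standard Conv2d output-size arithmetic, which the checker evaluates symbolically over Dyn dimensions. A plain-Python restatement of that rule (a sketch, with illustrative numbers):

    def conv_out_dim(in_dim, kernel, stride=1, padding=0, dilation=1):
        # Conv2d rule: floor((in + 2*p - d*(k-1) - 1) / s) + 1
        return (in_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    # conv3x3 uses kernel 3 and padding 1, so spatial dims are preserved at stride 1
    assert conv_out_dim(10, kernel=3, padding=1) == 10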
Example #4
    def test_resnet50(self):
        gm_run = symbolic_trace(resnet50())
        sample_input = torch.randn(1, 3, 224, 224)

        # run shape propagation to record a runtime shape on every node
        ShapeProp(gm_run).propagate(sample_input)

        gm_static = symbolic_trace(resnet50())

        for n in gm_static.graph.nodes:
            n.type = None

        g = GraphTypeChecker({}, gm_static)
        g.type_check()
        gm_static.graph.eliminate_dead_code()
        gm_run.graph.eliminate_dead_code()
        # here we are checking for consistency with fully dynamic nodes
        for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
            assert is_consistent(n1.type,
                                 TensorType(n2.meta['tensor_meta'].shape))

        # here we give the type checker the same input shape as the runtime run
        gm_static_with_types = symbolic_trace(resnet50())

        # we initialize our placeholder
        for n in gm_static_with_types.graph.nodes:
            if n.op == 'placeholder':
                n.type = TensorType((1, 3, 224, 224))

        g = GraphTypeChecker({}, gm_static_with_types)
        g.type_check()
        for n1, n2 in zip(gm_static_with_types.graph.nodes,
                          gm_run.graph.nodes):
            assert n1.type == TensorType(n2.meta['tensor_meta'].shape)

        # apply shape inference to graph and check
        # that the batch size is equal across all layers
        infer_symbolic_types(gm_static)

        batch_sizes = set()
        gm_static.graph.eliminate_dead_code()
        for n in gm_static.graph.nodes:
            assert isinstance(n.type, TensorType)
            batch_sizes.add(n.type.__args__[0])
        assert len(batch_sizes) == 1
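The first consistency loop above passes even though gm_static's placeholders were reset to None, because Dyn is consistent with any dimension. A small illustration (assuming is_consistent lives in torch.fx.tensor_type):

    from torch.fx.tensor_type import Dyn, TensorType, is_consistent

    # Dyn is consistent with any type, and dims are compared element-wise
    assert is_consistent(Dyn, TensorType((1, 3, 224, 224)))
    assert is_consistent(TensorType((1, Dyn)), TensorType((1, 3)))
    assert not is_consistent(TensorType((2, 3)), TensorType((1, 3)))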
Example #5
    def test_type_check_symbolic_inference_conv2D_maxpool2d_flatten(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self):
                super(BasicBlock, self).__init__()

                self.conv1 = torch.nn.Conv2d(3, 6, 5)
                self.pool = torch.nn.MaxPool2d(2, 2)
                self.conv2 = torch.nn.Conv2d(6, 16, 5)
                self.fc1 = torch.nn.Linear(5, 120)
                self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))

            def forward(self, x: TensorType((4, 3, Dyn, Dyn))):
                out = self.conv1(x)
                out = self.pool(out)
                out = self.conv2(out)
                out = self.pool(out)
                out = self.fc1(out)
                out = self.pool2(out)
                out = torch.flatten(out, 1)
                return out

        B = BasicBlock()
        traced = symbolic_trace(B)
        tc = GraphTypeChecker({}, traced)
        tc.type_check()
        infer_symbolic_types(traced)

        for n in traced.graph.nodes:
            if n.target == 'conv1':
                assert n.type == TensorType(
                    (4, 6, sympy.floor((sympy.symbols('~0') - 4)),
                     sympy.floor((sympy.symbols('~1') - 4))))

            elif n.target == 'conv2':
                assert n.type == TensorType(
                    (4, 16, sympy.floor((sympy.symbols('~4') - 4)),
                     sympy.floor((sympy.symbols('~5') - 4))))
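As a concrete cross-check of the symbolic result, conv1 (kernel 5, stride 1, no padding) shrinks each spatial dim by 4, matching floor(~0 - 4). A runtime spot check with illustrative input sizes:

    import torch

    conv1 = torch.nn.Conv2d(3, 6, 5)
    out = conv1(torch.randn(4, 3, 32, 28))
    # 32 - 4 == 28 and 28 - 4 == 24, mirroring floor(~0 - 4) and floor(~1 - 4)
    assert out.shape == (4, 6, 28, 24)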
Example #6
    def test_symbolic_add_with_broadcast_2(self):
        class M(torch.nn.Module):
            def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))):
                return torch.add(x, y)

        module = M()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        tc = GraphTypeChecker({}, symbolic_traced)
        tc.type_check()
        infer_symbolic_types(symbolic_traced)
        r = Refine(symbolic_traced)
        r.refine()

        expected_ph_types = [
            TensorType((1, 2)),
            TensorType((sympy.symbols('~1'), 2)),
            TensorType((sympy.symbols('~1'), 2)),
            TensorType((sympy.symbols('~1'), 2))
        ]
        expected_iter = iter(expected_ph_types)

        for n in symbolic_traced.graph.nodes:
            assert n.type == next(expected_iter)
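The refined placeholder type (~1, 2) reflects that broadcasting (1, 2) against (n, 2) yields (n, 2) for every n, so the output unifies with y's first dimension. A quick illustration (not part of the test):

    import torch

    x = torch.randn(1, 2)
    for n in (1, 3, 5):  # any value of Dyn
        assert torch.add(x, torch.randn(n, 2)).shape == (n, 2)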