Ejemplo n.º 1
    def test_type_check_conv2D_types(self):
        """Conv2d spatial output dims should become ``sympy.floor`` terms.

        The input is annotated ``Dyn`` and the conv result ``(2, 2, Dyn, 4)``.
        After gradual type checking and symbolic inference, dimensions 2 and 3
        of every ``call_module`` node are expected to be ``sympy.floor``
        expressions.
        """
        # Local import: symbolic inference is not referenced anywhere else in
        # this view, so do not rely on a module-level import being present.
        from torch.fx.experimental.unify_refinements import infer_symbolic_types

        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes, stride=1):
                super(BasicBlock, self).__init__()
                norm_layer = torch.nn.BatchNorm2d
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = norm_layer(planes)

            def forward(self, x: Dyn):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        # The checker has to be run explicitly; merely constructing it (as the
        # previous version did) leaves the types asserted below uncomputed.
        tc = GraphTypeChecker({}, traced)
        tc.type_check()
        infer_symbolic_types(traced)

        for n in traced.graph.nodes:
            if n.op == 'call_module':
                assert isinstance(n.type.__args__[2], sympy.floor)
                assert isinstance(n.type.__args__[3], sympy.floor)
Ejemplo n.º 2
    def test_type_check_conv2D(self):
        """A ``Dyn`` input through an annotated conv3x3 block broadcasts to 4-D.

        After type checking, the placeholder, call_function and output nodes
        are expected to be ``(Dyn, Dyn, Dyn, Dyn)`` and the annotated conv call
        to keep ``(2, 2, Dyn, 4)``.
        """
        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes, stride=1, norm_layer=None):
                super(BasicBlock, self).__init__()
                if norm_layer is None:
                    norm_layer = torch.nn.BatchNorm2d
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = norm_layer(planes)

            def forward(self, x: Dyn):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        # Run the checker; without this call no node types are inferred and the
        # assertions below would be checking unannotated nodes.
        tc = GraphTypeChecker({}, traced)
        tc.type_check()

        dyn4 = TensorType((Dyn, Dyn, Dyn, Dyn))
        expected = {
            'placeholder': dyn4,
            'call_function': dyn4,
            'output': dyn4,
            'call_module': TensorType((2, 2, Dyn, 4)),
        }
        for n in traced.graph.nodes:
            if n.op in expected:
                assert n.type == expected[n.op]
Ejemplo n.º 3
    def test_type_check_batch_norm_symbolic(self):
        """Symbolic inference unifies every node of a BatchNorm2d block.

        The input is ``Dyn`` and the bn1 result is annotated
        ``(2, 2, Dyn, 4)``; after type checking and symbolic inference all
        four graph nodes are expected to share ``(2, 2, ~7, 4)``.
        """
        # Local import: symbolic inference is not referenced anywhere else in
        # this view, so do not rely on a module-level import being present.
        from torch.fx.experimental.unify_refinements import infer_symbolic_types

        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes):
                super(BasicBlock, self).__init__()
                norm_layer = torch.nn.BatchNorm2d
                self.bn1 = norm_layer(planes)

            def forward(self, x: Dyn):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        tc = GraphTypeChecker({}, traced)
        tc.type_check()
        infer_symbolic_types(traced)

        expected_types = [TensorType[(2, 2, sympy.symbols('~7'), 4)]] * 4
        nodes = list(traced.graph.nodes)
        # Pin the node count so a missing or extra node cannot slip through.
        assert len(nodes) == len(expected_types)
        for n, expected in zip(nodes, expected_types):
            assert n.type == expected
Ejemplo n.º 4
    def test_conv_reshape_add(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self, in_planes, out_planes, kernel_size, stride,
                         padding, groups, dilation):
                super(BasicBlock, self).__init__()
                self.conv1 = torch.nn.Conv2d(in_channels=in_planes,

            def forward(self, x: Dyn, y: Dyn):
                return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))),

        B = BasicBlock(2, 2, 2, 3, 2, 2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        generator = ConstraintGenerator(traced)
        new_constraints, counter = generator.generate_constraints(0)
        assert len(new_constraints.conjucts) == 16
Ejemplo n.º 5
    def test_flatten_fully_static(self):
        annotation_list = [
            TensorType((2, 5, 6, 9)),
            TensorType((10, 15, 13, 14)),
            TensorType((10, Dyn, 13, 14)),
            TensorType((Dyn, Dyn, Dyn, 10))
        input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                      (10, 15, 13, 14), (2, 2, 10, 10)]

        intermediate_list = [
            Dyn, (2, 5, 6, 9), (10, 15, 13, 14), (10, 15, 13, 14),
            (2, 2, 10, 10)

        start_dim = [1, 2, 1, 2, 0]
        end_dim = [1, 3, 3, 3, -2]

        for i in range(5):
            annotation = annotation_list[i]
            input = input_list[i]

            # intermediate_type = intermediate_list[i]

            class BasicBlock(torch.nn.Module):
                def __init__(self, start, end):
                    super(BasicBlock, self).__init__()
                    self.start = start
                    self.end = end

                def forward(self, x):
                    out = torch.flatten(x, self.start, self.end)
                    return out

            B = BasicBlock(start_dim[i], end_dim[i])
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = annotation

            b = B.forward(torch.rand(input))
            tc = GraphTypeChecker({}, traced)

            for n in graph.nodes:
                if n.op == 'output':
                    assert is_consistent(n.type, TensorType(b.size()))
Ejemplo n.º 6
    def test_add_reshape(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self):
                super(BasicBlock, self).__init__()

            def forward(self, x: Dyn, y: Dyn):
                return torch.add(torch.reshape(x, (1, 2)),
                                 torch.reshape(y, (2, 2)))

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(BasicBlock())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        generator = ConstraintGenerator(traced)
        new_constraints, counter = generator.generate_constraints(0)
        assert len(new_constraints.conjucts) == 11
Ejemplo n.º 7
    def test_type_check_conv2D_maxpool2d_flatten(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self):
                super(BasicBlock, self).__init__()

                self.conv1 = torch.nn.Conv2d(3, 6, 5)
                self.pool = torch.nn.MaxPool2d(2, 2)
                self.conv2 = torch.nn.Conv2d(6, 16, 5)
                self.fc1 = torch.nn.Linear(5, 120)
                self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))

            def forward(self, x: TensorType((4, 3, 32, 32))):
                out = self.conv1(x)
                out = self.pool(out)
                out = self.conv2(out)
                out = self.pool(out)
                out = self.fc1(out)
                out = self.pool2(out)
                out = torch.flatten(out, 1)
                return out

        B = BasicBlock()
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)

        expected_ph_types = [
            TensorType((4, 3, 32, 32)),
            TensorType((4, 6, 28, 28)),
            TensorType((4, 6, 14, 14)),
            TensorType((4, 16, 10, 10)),
            TensorType((4, 16, 5, 5)),
            TensorType((4, 16, 5, 120)),
            TensorType((4, 16, 6, 7)),
            TensorType((4, 672)),
            TensorType((4, 672))

        expected_iter = iter(expected_ph_types)

        for n in traced.graph.nodes:
            assert n.type == next(expected_iter)
Ejemplo n.º 8
    def test_subgraph_rewriter_annotations_int(self):
        class M1(torch.nn.Module):
            def forward(self, x):
                y: int = x
                return torch.add(x, y)

        class M2(torch.nn.Module):
            def forward(self, x):
                y = annotate(x, int)
                return torch.add(x, y)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M1())

        module = M2()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
            if n.op == 'placeholder':
                assert n.type == int
                assert m.type == int
Ejemplo n.º 9
    def test_type_typechecl_maxpool2d_3dinput(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self):
                super(BasicBlock, self).__init__()
                self.pool = torch.nn.MaxPool2d(5, 8)

            def forward(self, x: TensorType((64, 8, 8))):
                out = self.pool(x)
                return out

        B = BasicBlock()
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)

        for n in traced.graph.nodes:
            if n.target == 'output':
                assert n.type == TensorType((64, 1, 1))
Ejemplo n.º 10
    def test_type_check_batch_norm_2D_false(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes):
                super(BasicBlock, self).__init__()
                norm_layer = torch.nn.BatchNorm2d
                self.bn1 = norm_layer(planes)

            def forward(self, x: TensorType((2, 2, 5))):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        with self.assertRaises(TypeError):
Ejemplo n.º 11
    def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self):
        """Symbolic inference gives floor expressions for dynamic conv dims.

        Input height and width are ``Dyn``. After type checking and symbolic
        inference, conv1 and conv2 are expected to carry spatial dims of the
        form ``floor(~k - 4)`` over the fresh symbols named below.
        """
        # Local import: symbolic inference is not referenced anywhere else in
        # this view, so do not rely on a module-level import being present.
        from torch.fx.experimental.unify_refinements import infer_symbolic_types

        class BasicBlock(torch.nn.Module):
            def __init__(self):
                super(BasicBlock, self).__init__()

                self.conv1 = torch.nn.Conv2d(3, 6, 5)
                self.pool = torch.nn.MaxPool2d(2, 2)
                self.conv2 = torch.nn.Conv2d(6, 16, 5)
                self.fc1 = torch.nn.Linear(5, 120)
                self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))

            def forward(self, x: TensorType((4, 3, Dyn, Dyn))):
                out = self.conv1(x)
                out = self.pool(out)
                out = self.conv2(out)
                out = self.pool(out)
                out = self.fc1(out)
                out = self.pool2(out)
                out = torch.flatten(out, 1)
                return out

        B = BasicBlock()
        traced = symbolic_trace(B)
        # Run the checker and the symbolic pass; neither happens implicitly.
        tc = GraphTypeChecker({}, traced)
        tc.type_check()
        infer_symbolic_types(traced)

        for n in traced.graph.nodes:
            if n.target == 'conv1':
                assert n.type == TensorType(
                    (4, 6, sympy.floor(sympy.symbols('~0') - 4),
                     sympy.floor(sympy.symbols('~1') - 4)))

            elif n.target == 'conv2':
                assert n.type == TensorType(
                    (4, 16, sympy.floor(sympy.symbols('~4') - 4),
                     sympy.floor(sympy.symbols('~5') - 4)))
Ejemplo n.º 12
    def test_type_check_conv2D_2(self):
        """Static conv3x3 types match eager shapes; a channel mismatch fails.

        First, a ``(5, 2, 3, 4)`` input through a 2->2 block must keep that
        type on placeholder, call_function and call_module nodes, and the
        output type must equal the eager-mode result shape. Second, a 1->2
        block given the same annotated input must raise ``TypeError`` during
        type checking.
        """
        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes, stride=1, norm_layer=None):
                super(BasicBlock, self).__init__()
                if norm_layer is None:
                    norm_layer = torch.nn.BatchNorm2d
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = norm_layer(planes)

            def forward(self, x: TensorType((5, 2, 3, 4))):
                identity = x
                out = self.conv1(x)
                out += identity
                return out

        def build_checker(block):
            """Trace ``block`` with RewritingTracer and wrap a checker around it."""
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(block)
            traced = GraphModule(ast_rewriter.root, graph, "gm")
            return traced, GraphTypeChecker({}, traced)

        B = BasicBlock(2, 2)
        b = B.forward(torch.rand(5, 2, 3, 4))

        traced, tc = build_checker(B)
        tc.type_check()
        t = TensorType((5, 2, 3, 4))
        for n in traced.graph.nodes:
            if n.op in ('placeholder', 'call_function', 'call_module'):
                assert n.type == t
            elif n.op == 'output':
                assert torch.Size(n.type.__args__) == b.shape

        # One input channel is expected to conflict with the 2-channel input
        # annotation on forward.
        _, tc = build_checker(BasicBlock(1, 2))
        with self.assertRaises(TypeError):
            tc.type_check()
Ejemplo n.º 13
    def test_type_check_batch_norm_2D_broadcast(self):
        class BasicBlock(torch.nn.Module):
            def __init__(self, inplanes, planes):
                super(BasicBlock, self).__init__()
                norm_layer = torch.nn.BatchNorm2d
                self.bn1 = norm_layer(planes)

            def forward(self, x: Dyn):
                identity = x
                out: TensorType((2, 2, Dyn, 4)) = self.bn1(x)
                out += identity
                return out

        B = BasicBlock(2, 2)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        for n in graph.nodes:
            if n.op == 'placeholder':
                assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
            if n.op == 'call_function':
                assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
            if n.op == 'output':
                assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
            if n.op == 'call_module':
                assert n.type == TensorType((2, 2, Dyn, 4))

        B = BasicBlock(1, 1)
        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")
        tc = GraphTypeChecker({}, traced)
        with self.assertRaises(TypeError):
Ejemplo n.º 14
    def test_type_check_conv2D_2_fully_static(self):
        annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                           (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)]
        input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                      (10, 15, 13, 14), (1, 2, 2, 3)]
        intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6),
                              (10, 15, Dyn, 5), (10, 15, 7, 7),
                              (1, Dyn, Dyn, Dyn)]
        in_planes_list = [2, 5, 15, 15, 2]
        stride_list = [1, 2, 3, 2, 2]
        out_planes_list = [2, 5, 15, 15, 2]
        groups_list = [1, 5, 5, 5, 2]
        dilation_list = [1, 2, 3, 3, 3]
        padding_list = [1, 2, 3, 3, 3]
        kernel_size_list = [1, 2, 3, 3, 3]
        output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5),
                        (10, 15, 7, 7), (1, 2, Dyn, Dyn)]

        for i in range(5):
            annotation = annotation_list[i]
            input = input_list[i]
            in_planes = in_planes_list[i]
            stride = stride_list[i]
            out_planes = out_planes_list[i]
            groups = groups_list[i]
            dilation = dilation_list[i]
            padding = padding_list[i]
            kernel_size = kernel_size_list[i]
            intermediate_type = intermediate_types[i]

            class BasicBlock(torch.nn.Module):
                def __init__(self, in_planes, out_planes, kernel_size, stride,
                             padding, groups, dilation):
                    super(BasicBlock, self).__init__()
                    self.conv1 = torch.nn.Conv2d(in_channels=in_planes,

                def forward(self, x):
                    out = self.conv1(x)
                    return out

            B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding,
                           groups, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = TensorType(annotation)

            b = B.forward(torch.rand(input))
            tc = GraphTypeChecker({}, traced)

            for n in graph.nodes:
                if n.op == 'output':
                    assert is_consistent(n.type, TensorType(b.size()))

            # test with intermediate annotations
            class BasicBlock(torch.nn.Module):
                def __init__(self, in_planes, out_planes, kernel_size, stride,
                             padding, groups, dilation):
                    super(BasicBlock, self).__init__()
                    self.conv1 = torch.nn.Conv2d(in_channels=in_planes,

                def forward(self, x):
                    out = self.conv1(x)
                    return out

            B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding,
                           groups, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # populate our intermediate notes
            for n in traced.graph.nodes:
                if n.op == 'call_module':
                    n.type = TensorType(intermediate_type)

            tc = GraphTypeChecker({}, traced)

            for n in traced.graph.nodes:
                if n.op == 'output':
                    assert n.type == TensorType(output_types[i])
                    assert is_consistent(n.type, TensorType(b.size()))
Ejemplo n.º 15
    def test_typecheck_basicblock(self):
        """A ResNet-style basic block type-checks to its eager output shape.

        With a ``(2, 2, 4, 5)`` input, the output node must be a TensorType
        whose dimensions equal the shape eager mode produces.
        """
        class BasicBlock(torch.nn.Module):
            expansion = 1

            # NOTE(review): the parameter list was truncated after ``self`` in
            # the source. It is reconstructed from the names the body reads,
            # with the defaults its own guards require (groups=1,
            # base_width=64, dilation=1, optional norm_layer/downsample).
            def __init__(self, inplanes, planes, stride=1, downsample=None,
                         groups=1, base_width=64, dilation=1, norm_layer=None):
                super(BasicBlock, self).__init__()
                if norm_layer is None:
                    norm_layer = torch.nn.BatchNorm2d
                if groups != 1 or base_width != 64:
                    raise ValueError(
                        'BasicBlock only supports groups=1 and base_width=64')
                if dilation > 1:
                    raise NotImplementedError(
                        "Dilation > 1 not supported in BasicBlock")
                # Both self.conv1 and self.downsample layers downsample the input when stride != 1
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = norm_layer(planes)
                self.relu = torch.nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = norm_layer(planes)
                self.downsample = downsample
                self.stride = stride

            def forward(self, x: TensorType((2, 2, 4, 5))):
                identity = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    identity = self.downsample(x)

                out += identity
                out = self.relu(out)

                return out

        B = BasicBlock(2, 2)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(B)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        # Run the checker; constructing it alone does not infer node types.
        tc = GraphTypeChecker({}, traced)
        tc.type_check()

        for n in traced.graph.nodes:
            if n.op == 'output':
                assert isinstance(n.type, TensorType)
                assert torch.Size(n.type.__args__) == B.forward(
                    torch.rand(2, 2, 4, 5)).size()
Ejemplo n.º 16
def symbolic_trace_with_rewrite(
        root: Union[torch.nn.Module, Callable]) -> GraphModule:
    """Trace ``root`` with :class:`RewritingTracer` and wrap the result.

    Args:
        root: A module, or a free function, to trace.

    Returns:
        A ``GraphModule`` owning the traced graph. A module ``root`` is used
        as the owning module; a plain callable gets an empty
        ``torch.nn.Module`` as its owner.
    """
    # NOTE(review): the return expression was truncated after its first
    # argument in the source; the graph argument is reconstructed from the
    # function's name and the RewritingTracer usage elsewhere in this file.
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )
Ejemplo n.º 17
    def test_type_maxpool2d_fully_static(self):
        annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                           (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)]
        input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                      (10, 15, 13, 14), (2, 2, 10, 10)]
        intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4),
                              (10, 15, Dyn, 2), (10, 15, 2, 3),
                              (2, Dyn, Dyn, Dyn)]
        stride_list = [1, 2, 3, 2, 1]
        dilation_list = [1, 2, 3, 3, 2]
        padding_list = [1, 2, 3, 3, 1]
        kernel_size_list = [2, 4, 6, 6, 3]
        output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2),
                        (10, 15, 2, 3), (2, Dyn, Dyn, 8)]

        for i in range(5):
            annotation = annotation_list[i]
            input = input_list[i]
            stride = stride_list[i]
            dilation = dilation_list[i]
            padding = padding_list[i]
            kernel_size = kernel_size_list[i]
            intermediate_type = intermediate_types[i]

            class BasicBlock(torch.nn.Module):
                def __init__(self, kernel_size, stride, padding, dilation):
                    super(BasicBlock, self).__init__()
                    self.pool = torch.nn.MaxPool2d(kernel_size,

                def forward(self, x):
                    out = self.pool(x)
                    return out

            B = BasicBlock(kernel_size, stride, padding, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = TensorType(annotation)

            b = B.forward(torch.rand(input))
            tc = GraphTypeChecker({}, traced)

            for n in graph.nodes:
                if n.op == 'output':
                    assert is_consistent(n.type, TensorType(b.size()))

            # test with intermediate annotations
            class BasicBlock(torch.nn.Module):
                def __init__(self, kernel_size, stride, padding, dilation):
                    super(BasicBlock, self).__init__()
                    self.pool = torch.nn.MaxPool2d(kernel_size,

                def forward(self, x):
                    out = self.pool(x)
                    return out

            B = BasicBlock(kernel_size, stride, padding, dilation)
            ast_rewriter = RewritingTracer()
            graph = ast_rewriter.trace(B)
            traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our argument
            for n in graph.nodes:
                if n.op == 'placeholder':
                    n.type = TensorType(annotation)

            # populate our intermediate notes
            for n in traced.graph.nodes:
                if n.op == 'call_module':
                    n.type = TensorType(intermediate_type)

            tc = GraphTypeChecker({}, traced)

            for n in traced.graph.nodes:
                if n.op == 'output':
                    assert n.type == TensorType(output_types[i])
                    assert is_consistent(n.type, TensorType(b.size()))