Example #1
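All of the examples below share the same scaffolding: script a small nn.Module, run a few JIT passes to shrink its graph, check that the output carries no shape annotation yet, then ask helpers.propagateInputShapes to infer one and compare it against an eager-mode run. The snippets also rely on two module-level helpers that the excerpts do not show. The block below collects the imports they need plus a plausible sketch of those helpers; the helper bodies are assumptions, not the originals.

import torch
import torch.nn as nn

import helpers  # test utility module providing propagateInputShapes
import poptorch


def _getOutputShape(graph):
    # Assumed helper: return the graph output's shape as a list,
    # or None if shape inference has not annotated it yet.
    out = next(graph.outputs())
    if isinstance(out.type(), torch._C.TensorType):
        sizes = out.type().sizes()
        if sizes is not None:
            return list(sizes)
    return None


def _has_node(graph, kind):
    # Assumed helper: True if a node of the given kind
    # (e.g. 'aten::batch_norm') appears anywhere in the graph.
    return any(node.kind() == kind for node in graph.nodes())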
def test_batchnorm():
    class X(nn.Module):
        def __init__(self, features):
            super().__init__()
            self.bnorm = nn.BatchNorm1d(features)

        def forward(self, x):
            return self.bnorm(x)

    # Create a module, run it, and get the shape of the returned output.
    m = X(100)
    dummyInput = torch.zeros(20, 100)
    actualOutputShape = list(m(dummyInput).size())

    # jit.script the module and run some passes to shrink the graph.
    # Shrinking the graph is mostly done for debuggability.
    m = torch.jit.script(m)
    graph = m.graph
    # pylint: disable=protected-access
    torch._C._jit_pass_inline(graph)
    poptorch.poptorch_core.peepholeOptimizations(graph, False)
    graph, _ = torch._C._jit_pass_lower_graph(graph, m._c)
    torch._C._jit_pass_constant_propagation(graph)
    # pylint: enable=protected-access
    # Observe the graph doesn't already have a shape for the output.
    assert _getOutputShape(graph) is None

    # Run shape analysis on the graph
    helpers.propagateInputShapes(graph, (dummyInput, ))
    inferredOutputShape = _getOutputShape(graph)

    assert _has_node(graph, 'aten::batch_norm')
    assert inferredOutputShape == actualOutputShape
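Batch norm preserves its input shape, so both the eager run and the inferred annotation should report [20, 100]. The inline, peephole, lower, and constant-propagation passes do not change that result; as the comment notes, they mainly shrink the graph so the lone aten::batch_norm node is easy to inspect.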
Example #2
def test_maxpool2d():
    class X(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.pool = nn.MaxPool2d(*args, **kwargs)

        def forward(self, x):
            return self.pool(x)

    # Create a module, run it, and get the shape of the returned output.
    m = X(3, stride=2)
    dummyInput = torch.zeros(20, 16, 50, 32)
    actualOutputShape = list(m(dummyInput).size())

    # jit.script the module and run some passes to shrink the graph.
    # Shrinking the graph is mostly done for debuggability.
    m = torch.jit.script(m)
    graph = m.graph
    # pylint: disable=protected-access
    torch._C._jit_pass_inline(graph)
    torch._C._jit_pass_constant_propagation(graph)
    graph, _ = torch._C._jit_pass_lower_graph(graph, m._c)
    # pylint: enable=protected-access
    # Observe the graph doesn't already have a shape for the output.
    assert _getOutputShape(graph) is None

    # Run shape analysis on the graph
    helpers.propagateInputShapes(graph, (dummyInput, ))
    inferredOutputShape = _getOutputShape(graph)

    assert _has_node(graph, 'aten::max_pool2d')
    assert inferredOutputShape == actualOutputShape
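The expected result can be checked by hand: with kernel 3, stride 2, and no padding, each spatial dimension becomes floor((size - 3) / 2) + 1, so the 20x16x50x32 input yields floor((50 - 3) / 2) + 1 = 24 and floor((32 - 3) / 2) + 1 = 15, i.e. an output shape of [20, 16, 24, 15].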
Example #3
    def run_test(input_shape):
        class X(nn.Module):
            def forward(self, x):
                return torch.flatten(x)

        m = X()
        dummyInputs = (torch.zeros(*input_shape), )
        actualOutputShape = list(m(*dummyInputs).size())

        m = torch.jit.script(m)
        graph = m.graph
        # Input 0 of the scripted graph is the module itself; input 1 is x.
        list(graph.inputs())[1].inferTypeFrom(dummyInputs[0])
        # pylint: disable=protected-access
        graph, _ = torch._C._jit_pass_lower_graph(graph, m._c)
        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_constant_propagation(graph)
        # pylint: enable=protected-access
        assert _getOutputShape(graph) is None

        helpers.propagateInputShapes(graph, dummyInputs)
        inferredOutputShape = _getOutputShape(graph)
        print(graph)

        assert _has_node(graph, 'aten::flatten')
        assert inferredOutputShape == actualOutputShape
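Only the inner run_test helper appears in this excerpt (hence the extra indentation); the enclosing test is not shown. Since torch.flatten with default arguments collapses any input to a 1-D tensor, a plausible driver would simply vary the rank. The function name and shapes below are illustrative assumptions, not from the source:

def test_flatten():
    # Hypothetical driver: run_test (as above) would be nested here.
    # Any rank flattens to a single dimension holding prod(shape) elements.
    for shape in ((6,), (2, 3), (2, 3, 4), (2, 3, 4, 5)):
        run_test(shape)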
Example #4
    def run_test(input_shape):
        class X(nn.Module):
            def __init__(self):
                super().__init__()
                self.aap = nn.AdaptiveAvgPool2d((5, 7))

            def forward(self, x):
                return self.aap(x)

        m = X()
        dummyInputs = (torch.zeros(*input_shape), )
        actualOutputShape = list(m(*dummyInputs).size())

        m = torch.jit.script(m)
        graph = m.graph
        # Input 0 of the scripted graph is the module itself; input 1 is x.
        list(graph.inputs())[1].inferTypeFrom(dummyInputs[0])
        # pylint: disable=protected-access
        graph, _ = torch._C._jit_pass_lower_graph(graph, m._c)
        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_constant_propagation(graph)

        # The scripted pooling code builds its output-size list in a loop;
        # unrolling it and re-running constant propagation (together with
        # eliminateListConstructs below) turns the sizes into graph constants.
        torch._C._jit_pass_loop_unrolling(graph)
        torch._C._jit_pass_constant_propagation(graph)
        # pylint: enable=protected-access
        poptorch.poptorch_core.eliminateListConstructs(graph)

        assert _getOutputShape(graph) is None

        helpers.propagateInputShapes(graph, dummyInputs)
        inferredOutputShape = _getOutputShape(graph)
        print(graph)

        assert _has_node(graph, 'aten::adaptive_avg_pool2d')
        assert inferredOutputShape == actualOutputShape
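As in Example #3, only run_test is shown. Adaptive average pooling always produces the requested 5x7 plane, so the expected output shape is [N, C, 5, 7] regardless of the input's spatial size. A hypothetical driver (name and shapes assumed, not from the source):

def test_adaptive_avg_pool2d():
    # Hypothetical driver: run_test (as above) would be nested here.
    # Whatever H and W are, the pooled plane comes out as 5x7.
    for shape in ((1, 3, 10, 14), (2, 8, 64, 50)):
        run_test(shape)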
Example #5
def test_addmm():
    class X(nn.Module):
        def forward(self, x, y, z):
            return torch.addmm(x, y, z)

    m = X()
    dummyInputs = (torch.zeros(2, 4), torch.zeros(2, 3), torch.zeros(3, 4))
    actualOutputShape = list(m(*dummyInputs).size())

    m = torch.jit.script(m)
    graph = m.graph
    # pylint: disable=protected-access
    graph, _ = torch._C._jit_pass_lower_graph(graph, m._c)
    # pylint: enable=protected-access
    assert _getOutputShape(graph) is None

    helpers.propagateInputShapes(graph, dummyInputs)
    inferredOutputShape = _getOutputShape(graph)

    assert _has_node(graph, 'aten::addmm')
    assert inferredOutputShape == actualOutputShape
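torch.addmm(x, y, z) computes x + y @ z: the (2, 3) by (3, 4) matrix product has shape (2, 4), which matches the bias x, so the expected shape on both the eager and the inferred path is [2, 4].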
Example #6
def test_view():
    class X(nn.Module):
        def forward(self, x):
            return x.view(50, -1)

    m = X()
    dummyInput = torch.zeros(100, 100)
    actualOutputShape = list(m(dummyInput).size())

    print(actualOutputShape)
    m = torch.jit.script(m)
    graph = m.graph
    # pylint: disable=protected-access
    graph, _ = torch._C._jit_pass_lower_graph(graph, m._c)
    torch._C._jit_pass_constant_propagation(graph)
    # pylint: enable=protected-access
    assert _getOutputShape(graph) is None

    helpers.propagateInputShapes(graph, (dummyInput, ))
    inferredOutputShape = _getOutputShape(graph)

    assert _has_node(graph, 'aten::view')
    assert inferredOutputShape == actualOutputShape
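The -1 in x.view(50, -1) is resolved from the element count, 100 * 100 / 50 = 200, so the expected shape is [50, 200]. Constant propagation presumably helps here by folding the (50, -1) size list into a graph constant that the shape analysis can read.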
Example #7
    def run_test(shape_a, shape_b):
        print(f'run_test({shape_a}, {shape_b})')

        class X(nn.Module):
            def forward(self, a, b):
                return a + b

        m = X()
        dummyInputs = [torch.zeros(*shape) for shape in (shape_a, shape_b)]
        actualOutputShape = list(m(*dummyInputs).size())

        m = torch.jit.script(m)
        graph = m.graph
        # pylint: disable=protected-access
        graph, _ = torch._C._jit_pass_lower_graph(graph, m._c)
        # pylint: enable=protected-access
        assert _getOutputShape(graph) is None

        helpers.propagateInputShapes(graph, dummyInputs)
        inferredOutputShape = _getOutputShape(graph)
        print(graph)

        assert _has_node(graph, 'aten::add')
        assert inferredOutputShape == actualOutputShape
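This run_test is also an inner helper, and taking two shapes suggests the enclosing test exercises broadcasting, since a + b must infer the broadcast result shape. A hypothetical driver (name and shape pairs assumed, not from the source):

def test_add():
    # Hypothetical driver: run_test (as above) would be nested here.
    # The last two pairs rely on standard broadcasting rules.
    for shape_a, shape_b in (((2, 3), (2, 3)), ((4, 1), (1, 5)), ((3,), (2, 3))):
        run_test(shape_a, shape_b)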