def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
    """Check that FoldExplicitPadding folds an explicit ``nn.pad`` into a conv.

    Builds ``conv(pad(x), w)`` with ``convs[ndim - 1]``, runs
    ``FoldExplicitPadding`` on it and asserts the result is structurally equal
    to the expected graph. For a zero-valued constant pad the expected graph is
    a single conv on ``x`` whose padding is ``orig_padding`` plus the spatial
    pad widths. In every other case the graph is expected to be unchanged.
    Both the original and the rewritten modules are then executed on the same
    random inputs and their outputs are compared numerically.

    Args:
        ndim: number of spatial dimensions; kernel layouts are derived from
            "DHWIO", so only 1, 2 or 3 are meaningful.
        pad_width: per-axis ``(before, after)`` pad widths, ordered like
            ``layout``.
        pad_value: pad value passed to ``relay.nn.pad``.
        pad_mode: pad mode passed to ``relay.nn.pad``.
        orig_padding: padding given to the conv itself. It is indexed like the
            folded padding: every spatial "before" value, then every spatial
            "after" value.
        layout: data layout; must be channels-first (``NC*``) or
            channels-last (``N*C``).

    Raises:
        ValueError: if ``layout`` is neither ``NC*`` nor ``N*C``.
    """
    if layout[1] == "C":
        shape = [1, 3] + [10] * ndim
        # Default OI<spatial> kernel: 8 output channels, 3 input channels.
        wshape = [8, 3] + [3] * ndim
        layout_kwargs = {}
    elif layout[-1] == "C":
        shape = [1] + [10] * ndim + [3]
        kernel_layout = "DHWIO"[3 - ndim :]
        # Spatial dims first, then I, then O, to match ``kernel_layout``
        # (e.g. HWIO). The window, input and output channels are the same as
        # in the channels-first case.
        wshape = [3] * ndim + [3, 8]
        layout_kwargs = {"data_layout": layout, "kernel_layout": kernel_layout}
    else:
        raise ValueError("This test only supports NC* and N*C")

    conv_op = convs[ndim - 1]
    x = relay.var("x", shape=shape, dtype="float32")
    w = relay.var("w", shape=wshape, dtype="float32")
    pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
    conv = conv_op(pad, w, padding=orig_padding, **layout_kwargs)

    if pad_mode == "constant" and pad_value == 0:
        # Expected fold: spatial "before" widths, then spatial "after" widths,
        # each added element-wise onto the conv's own padding.
        spatial_pads = [pw for i, pw in enumerate(pad_width) if layout[i] in ("D", "H", "W")]
        folded = [pw[0] for pw in spatial_pads] + [pw[1] for pw in spatial_pads]
        new_padding = [p + orig_padding[k] for k, p in enumerate(folded)]
        after = conv_op(x, w, padding=new_padding, **layout_kwargs)
    else:
        after = conv

    zz = run_opt_pass(conv, transform.FoldExplicitPadding())
    expected = run_opt_pass(after, transform.InferType())
    assert tvm.ir.structural_equal(zz, expected)

    mod1 = tvm.IRModule.from_expr(conv)
    mod2 = tvm.IRModule.from_expr(zz)

    # Build both executors under the same PassContext so the numeric
    # comparison is not skewed by differing pass configurations.
    with tvm.transform.PassContext():
        func1 = relay.create_executor("vm", mod=mod1, device=tvm.cpu(), target="llvm").evaluate()
        func2 = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm").evaluate()

    x_np = np.random.rand(*shape).astype("float32")
    w_np = np.random.rand(*wshape).astype("float32")
    result1 = func1(x_np, w_np)
    result2 = func2(x_np, w_np)

    tvm.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-5, atol=1e-5)
Example #2
0
    def validate(
        pools,
        ndim,
        pad_width,
        pad_value,
        orig_padding,
        layout,
        pool_size,
        pad_mode="constant",
        dtype="float32",
        no_fold=False,
        **kwargs,
    ):
        """Check that FoldExplicitPadding folds an explicit ``nn.pad`` into a pool.

        Builds ``pool(pad(x))`` with ``pools[ndim - 1]``, runs
        ``FoldExplicitPadding`` and asserts the result is structurally equal to
        the expected graph. When the pad is a constant pad of the value this
        test treats as foldable, the expected graph is a single pool on ``x``
        with the pad widths added to ``orig_padding``. That value is
        ``get_min_value(dtype)`` for ``max_pools`` and 0 otherwise. In every
        other case the graph is expected to be unchanged. Both graphs are then
        executed on the same random input and compared numerically.

        Args:
            pools: sequence of pool constructors indexed by ``ndim - 1``
                (compared against ``max_pools`` / ``avg_pools``).
            ndim: number of spatial dimensions.
            pad_width: per-axis ``(before, after)`` pad widths, ordered like
                ``layout``.
            pad_value: pad value; wrapped in a ``relay.const`` of ``dtype``.
            orig_padding: padding given to the pool itself. It is indexed like
                the folded padding: every spatial "before" value, then every
                spatial "after" value.
            layout: data layout; must be channels-first (``NC*``) or
                channels-last (``N*C``).
            pool_size: forwarded to the pool constructors.
            pad_mode: pad mode for ``relay.nn.pad``.
            dtype: dtype of the input variable, the pad constant and the random
                test data.
            no_fold: when False, also assert that no ``nn.pad`` remains after
                the pass.
            **kwargs: extra attributes forwarded to the pool constructors.

        Raises:
            ValueError: if ``layout`` is neither ``NC*`` nor ``N*C``.
        """
        pad_value_const = relay.const(pad_value, dtype=dtype)

        if layout[1] == "C":
            shape = [1, 3] + [10] * ndim
            layout_kwargs = {}
        elif layout[-1] == "C":
            shape = [1] + [10] * ndim + [3]
            layout_kwargs = {"layout": layout}
        else:
            raise ValueError("This test only supports NC* and N*C")

        pool_op = pools[ndim - 1]
        x = relay.var("x", shape=shape, dtype=dtype)
        pad = relay.nn.pad(x, pad_width, pad_value_const, pad_mode)
        pool = pool_op(pad, padding=orig_padding, pool_size=pool_size, **layout_kwargs, **kwargs)

        # The pad value this test expects the pass to fold into the pool.
        foldable_pad_value = get_min_value(dtype) if pools == max_pools else 0

        if pad_mode == "constant" and pad_value == foldable_pad_value:
            # Expected fold: spatial "before" widths, then spatial "after"
            # widths, each added element-wise onto the pool's own padding.
            spatial_pads = [pw for i, pw in enumerate(pad_width) if layout[i] in ("D", "H", "W")]
            folded = [pw[0] for pw in spatial_pads] + [pw[1] for pw in spatial_pads]
            new_padding = [p + orig_padding[k] for k, p in enumerate(folded)]

            if pools == avg_pools and all(v == 0 for v in orig_padding):
                # If the orig padding for AvgPool is all zero and the pad op to fold
                # has non-zero pad width, the resultant folded AvgPool will have
                # count_include_pad=True so AvgPool's divisor is agnostic of pad boundaries
                kwargs["count_include_pad"] = True
            after = pool_op(
                x, padding=new_padding, pool_size=pool_size, **layout_kwargs, **kwargs
            )
        else:
            after = pool

        zz = run_opt_pass(pool, transform.FoldExplicitPadding())
        expected = run_opt_pass(after, transform.InferType())
        assert tvm.ir.structural_equal(zz, expected)

        mod1 = tvm.IRModule.from_expr(pool)
        mod2 = tvm.IRModule.from_expr(zz)

        if not no_fold:
            op_freqs = relay.analysis.list_op_freqs(mod2)
            assert "nn.pad" not in op_freqs

        # Build both executors under the same PassContext so the numeric
        # comparison is not skewed by differing pass configurations.
        with tvm.transform.PassContext():
            func1 = relay.create_executor(
                "vm", mod=mod1, device=tvm.cpu(), target="llvm"
            ).evaluate()
            func2 = relay.create_executor(
                "vm", mod=mod2, device=tvm.cpu(), target="llvm"
            ).evaluate()

        x_np = np.random.rand(*shape).astype(dtype)
        result1 = func1(x_np)
        result2 = func2(x_np)

        tvm.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-5, atol=1e-5)