def test_transpose_network(self, batch_size, channels, height, width, seed,
                           kernel):
    """Run transpose_network on a small conv net and verify that it adds
    exactly the expected layout-conversion ops/tensors while leaving the
    numerical results of the net unchanged."""
    net = core.Net("net")
    # Two independent convolutions reading the shared input X.
    net.Conv(["X", "w1", "b1"], ["c1"], stride=1, pad=0, kernel=kernel)
    net.Conv(["X", "w2", "b2"], ["c2"], stride=1, pad=0, kernel=kernel)
    # c1, c2: (batch_size, 2*channels, height - kernel + 1, width - kernel + 1)
    net.Conv(["c1", "w3", "b3"], ["c3"], stride=1, pad=0, kernel=kernel)
    net.Conv(["c1", "w4", "b4"], ["c4"], stride=1, pad=0, kernel=kernel)
    # c3, c4: (batch_size, 4*channels, height - 2*kernel + 2, width - 2*kernel + 2)
    net.Flatten(["c3"], "c3f")
    net.Flatten(["c4"], "c4f")
    net.Flatten(["X"], "Xf")
    net.Concat(["c3f", "c4f", "Xf"], ["out", "split_info"],
               axis=1,
               add_axis=0)

    # Populate all the inputs with reproducible random data.
    np.random.seed(seed)
    workspace.ResetWorkspace()
    tu.randBlobFloat32("X", batch_size, channels, height, width)
    tu.randBlobsFloat32(["w1", "w2"], 2 * channels, channels, kernel, kernel)
    tu.randBlobsFloat32(["b1", "b2"], 2 * channels)
    tu.randBlobsFloat32(["w3", "w4"], 4 * channels, 2 * channels, kernel,
                        kernel)
    tu.randBlobsFloat32(["b3", "b4"], 4 * channels)
    input_names = ["X", "w1", "w2", "b1", "b2", "w3", "w4", "b3", "b4"]
    input_values = workspace.FetchBlobs(input_names)

    # Baseline run on the untransformed net.
    workspace.RunNetOnce(net)
    checked_blobs = ("c1", "c3", "out")
    baseline = {name: workspace.FetchBlob(name) for name in checked_blobs}

    nn = ng.NNModule(net)
    ops_before = len(nn.operators)
    tensors_before = len(nn.tensors)
    transpose_network(nn)
    new_netdef = nn.convertToCaffe2Proto()
    ops_after = len(nn.operators)
    tensors_after = len(nn.tensors)
    # The minimal growth is one NCHW2NHWC op + tensor per channel-based
    # input (X, w1, w2, w3, w4) and one NHWC2NCHW op + tensor per
    # convolution output (c1, c2, c3, c4): 9 of each.
    self.assertEqual(ops_after, ops_before + 9,
                     "expected 9 additional operators")
    self.assertEqual(tensors_after, tensors_before + 9,
                     "expected 9 additional tensors")

    # Replay the transformed net on identical inputs and compare results.
    workspace.ResetWorkspace()
    for name, val in zip(input_names, input_values):
        workspace.FeedBlob(name, val)
    workspace.RunNetOnce(new_netdef)
    for name in checked_blobs:
        np.testing.assert_almost_equal(workspace.FetchBlob(name),
                                       baseline[name], 1)
# Ejemplo n.º 2 ("Example no. 2") / "0" — residue from the code-example
# site this file was scraped from; kept as a comment so the file parses.
    def test_transformer_FuseConvBN(self, size, input_channels, seed, order,
                                    epsilon):
        """Check that FuseConvBN folds a Conv + SpatialBN pair into a
        single operator without changing the numerical output.

        Builds Conv -> SpatialBN, runs it to capture a baseline, fuses,
        asserts the net shrank to one op, and re-runs to compare outputs.
        """
        workspace.ResetWorkspace()
        net = core.Net("net")
        c = input_channels
        h = size
        w = size
        k = 3
        net.Conv(["X", "w", "b"], ["Y"],
                 stride=1,
                 pad=0,
                 kernel=k,
                 order=order)
        net.SpatialBN(
            ["Y", "scale", "bias", "mean", "var"],
            ["Y2"],
            is_test=True,
            order=order,
            epsilon=epsilon,
        )

        np.random.seed(seed)
        # Blob layouts depend on the storage order under test.
        if order == "NCHW":
            tu.randBlobFloat32("X", 1, c, h, w)
            tu.randBlobFloat32("w", c, c, k, k)
        else:
            tu.randBlobFloat32("X", 1, h, w, c)
            tu.randBlobFloat32("w", c, k, k, c)
        tu.randBlobsFloat32(["b", "scale", "bias", "mean"], c)

        # This is necessary because 1/sqrt(var) is used and if var is too small
        # we get floating point artifacts that cause test failures
        tu.randBlobFloat32("var", c, offset=0.5)
        workspace.RunNetOnce(net)
        preTransformOutput = workspace.FetchBlob("Y2").flatten()
        # Clobber Y2 so a stale blob cannot mask a broken fused net.
        workspace.FeedBlob("Y2", np.zeros((1, 1)))
        transformer.FuseConvBN(net)

        # Ensure fusion: Conv + SpatialBN collapsed into a single op.
        # Use tu.numOps for consistency with test_transformer_FuseConv3DBN.
        assert tu.numOps(net) == 1
        workspace.RunNetOnce(net)
        postTransformOutput = workspace.FetchBlob("Y2").flatten()
        # Check that there is no numerical difference
        assert np.allclose(preTransformOutput,
                           postTransformOutput,
                           rtol=5e-02,
                           atol=1e-03)
    def test_transformer_FuseConv3DBN(
        self, size, input_channels, kt, kh, kw, seed, epsilon
    ):
        """Same check as the 2D FuseConvBN test, but for a 3D convolution
        (kernels=[kt, kh, kw]): fusing Conv + SpatialBN must leave a single
        operator and a numerically equivalent output."""
        workspace.ResetWorkspace()
        net = core.Net("net")
        c = input_channels
        t = size
        h = size
        w = size
        net.Conv(["X", "w", "b"], ["Y"], kernels=[kt, kh, kw])
        net.SpatialBN(
            ["Y", "scale", "bias", "mean", "var"],
            ["Y2"],
            is_test=True,
            epsilon=epsilon,
        )

        # Reproducible random inputs for the 3D conv and the BN parameters.
        np.random.seed(seed)
        tu.randBlobFloat32("X", 1, c, t, h, w)
        tu.randBlobFloat32("w", c, c, kt, kh, kw)
        tu.randBlobsFloat32(["b", "scale", "bias", "mean"], c)
        # 1/sqrt(var) amplifies floating point noise when var is tiny, so
        # keep var bounded away from zero to avoid spurious failures.
        tu.randBlobFloat32("var", c, offset=0.5)

        # Baseline output from the unfused net.
        workspace.RunNetOnce(net)
        expected = workspace.FetchBlob("Y2").flatten()
        # Overwrite Y2 so the post-fusion run cannot reuse a stale blob.
        workspace.FeedBlob("Y2", np.zeros((1, 1)))

        transformer.FuseConvBN(net)
        # Conv and SpatialBN must have been merged into one operator.
        assert tu.numOps(net) == 1

        workspace.RunNetOnce(net)
        actual = workspace.FetchBlob("Y2").flatten()
        # Fused and unfused nets must agree numerically.
        assert np.allclose(
            expected,
            actual,
            rtol=1e-02,
            atol=1e-04
        )