Example #1
0
def test_pixelshuffle2d():
    """Check PixelShuffle2D's output shape and element layout.

    An input of shape (1, C*up_x*up_y, nx, ny) must become
    (1, C, nx*up_x, ny*up_y). The input is ``arange`` over all elements, so
    each output value identifies the input position it was taken from.
    """
    nchan = 2
    up_x = 2
    up_y = 3
    nx = 2
    ny = 3
    shape_before = (1, nchan * up_x * up_y, nx, ny)
    shape_after = (1, nchan, nx * up_x, ny * up_y)
    layer = PixelShuffle2D((up_x, up_y))
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
    # Expected layout:
    # - Channels are reshaped to form 2x3 (up_x x up_y) blocks
    # - Within each block, the increment is `nx * ny` when increasing the
    #   column index by 1, and `up_y * nx * ny` when increasing the row index
    # - Increasing the (row-major) block index by 1 adds an offset of 1
    # - Increasing the channel index adds an offset of `nx * up_x * ny * up_y`
    # Compare as a NumPy array so the check does not rely on assert_allclose
    # converting an NDArray implicitly.
    assert_allclose(
        y.asnumpy(),
        [[[[ 0,  6, 12,  1,  7, 13,  2,  8, 14],
           [18, 24, 30, 19, 25, 31, 20, 26, 32],
           [ 3,  9, 15,  4, 10, 16,  5, 11, 17],
           [21, 27, 33, 22, 28, 34, 23, 29, 35]],

          [[36, 42, 48, 37, 43, 49, 38, 44, 50],
           [54, 60, 66, 55, 61, 67, 56, 62, 68],
           [39, 45, 51, 40, 46, 52, 41, 47, 53],
           [57, 63, 69, 58, 64, 70, 59, 65, 71]]]]
    )


def test_pixelshuffle3d():
    """Check PixelShuffle3D's output shape and element layout.

    An input of shape (1, C*up_x*up_y*up_z, nx, ny, nz) must become
    (1, C, nx*up_x, ny*up_y, nz*up_z). The input is ``arange`` over all
    elements, so each output value identifies the input position it came
    from.
    """
    nchan = 1
    up_x = 2
    up_y = 1
    up_z = 2
    nx = 2
    ny = 3
    nz = 4
    shape_before = (1, nchan * up_x * up_y * up_z, nx, ny, nz)
    shape_after = (1, nchan, nx * up_x, ny * up_y, nz * up_z)
    layer = PixelShuffle3D((up_x, up_y, up_z))
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
    # Expected layout:
    # - Channels are reshaped to form 2x1x2 (up_x x up_y x up_z) blocks
    # - Within each block, the increment is `nx * ny * nz` when increasing the
    #   last (column) index by 1 and `up_z * nx * ny * nz` when increasing the
    #   first index by 1, e.g. the block [[[ 0, 24]], [[48, 72]]]
    # - Increasing the (row-major) block index by 1 adds an offset of 1
    # Compare as a NumPy array so the check does not rely on assert_allclose
    # converting an NDArray implicitly.
    assert_allclose(
        y.asnumpy(),
        [[[[[0, 24, 1, 25, 2, 26, 3, 27], [4, 28, 5, 29, 6, 30, 7, 31],
            [8, 32, 9, 33, 10, 34, 11, 35]],
           [[48, 72, 49, 73, 50, 74, 51, 75], [52, 76, 53, 77, 54, 78, 55, 79],
            [56, 80, 57, 81, 58, 82, 59, 83]],
           [[12, 36, 13, 37, 14, 38, 15, 39], [16, 40, 17, 41, 18, 42, 19, 43],
            [20, 44, 21, 45, 22, 46, 23, 47]],
           [[60, 84, 61, 85, 62, 86, 63, 87], [64, 88, 65, 89, 66, 90, 67, 91],
            [68, 92, 69, 93, 70, 94, 71, 95]]]]])
Example #3
0
def test_optimize_layout(np_shape_array, amp_init, model, ndim):
    """Outputs and gradients agree with and without ``optimize_layout()``.

    ``model(ndim)`` is built twice on the GPU. The first instance is built
    normally. The second is built inside the ``optimize_layout()`` context
    and loaded with the first instance's initial parameters. Both run one
    recorded forward pass plus backward pass on the same input. The outputs
    and every gradient of a parameter whose ``grad_req`` is not ``'null'``
    are then compared.

    ``np_shape_array`` and ``amp_init`` are not referenced in the body;
    presumably they are fixtures requested for their setup side effects
    (TODO confirm).
    """
    device = mx.gpu()
    rtol = atol = 1e-2

    def run_forward_backward(net, data):
        # Attach gradient buffers to every parameter array, record one
        # forward pass, backpropagate from its output and return the output.
        for param in net.collect_params().values():
            param.data().attach_grad()
        with mx.autograd.record():
            out = net(data)
        out.backward()
        return out

    # Reference model. Keep initialization before sampling `x` so the random
    # draws happen in the same order.
    reference = model(ndim)
    reference.initialize(ctx=device)
    reference.hybridize()
    data = mx.np.random.uniform(low=0,
                                high=10,
                                size=(32, 2, 17, 15, 12)[:ndim + 2],
                                ctx=device)
    # Unrecorded call before snapshotting parameters; presumably this
    # completes deferred initialization (TODO confirm).
    reference(data)
    initial_params = {
        name: param.data().copy()
        for name, param in reference.collect_params().items()
    }
    out_reference = run_forward_backward(reference, data)

    # Same architecture and starting weights, built under optimize_layout().
    with optimize_layout():
        optimized = model(ndim)
        optimized.initialize(ctx=device)
        optimized.load_dict(initial_params, device=device)
        optimized.hybridize()
        out_optimized = run_forward_backward(optimized, data)

    assert_allclose(out_optimized, out_reference, rtol=rtol, atol=atol)
    optimized_params = optimized.collect_params()
    for name, param in reference.collect_params().items():
        if param.grad_req != 'null':
            assert_allclose(optimized_params[name].grad(),
                            param.grad(),
                            rtol=rtol,
                            atol=atol)
Example #4
0
def test_scalar():
    """A hybridized block applied to a 0-d array yields the expected value."""

    class Foo(HybridBlock):
        """Block with no parameters that returns ``x * x * 2``."""

        def forward(self, x):
            return x * x * 2

    block = Foo()
    block.hybridize()
    block.initialize()
    result = block(mx.np.array(1.0))
    assert_allclose(result.asnumpy(), np.array(2.0))


def test_pixelshuffle1d():
    """Check PixelShuffle1D's output shape and element layout.

    An input of shape (1, C*up_x, nx) must become (1, C, nx*up_x). The input
    is ``arange`` over all elements, so each output value identifies the
    input position it was taken from.
    """
    nchan = 2
    up_x = 2
    nx = 3
    shape_before = (1, nchan * up_x, nx)
    shape_after = (1, nchan, nx * up_x)
    layer = PixelShuffle1D(up_x)
    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
    y = layer(x)
    assert y.shape == shape_after
    # Expected layout:
    # - Channels are reshaped to form blocks of length up_x
    # - Within each block, the increment is `nx` per position
    # - Increasing the block index by 1 adds an offset of 1
    # - Increasing the channel index adds an offset of `nx * up_x`
    # Compare as a NumPy array so the check does not rely on assert_allclose
    # converting an NDArray implicitly.
    assert_allclose(y.asnumpy(), [[[0, 3, 1, 4, 2, 5], [6, 9, 7, 10, 8, 11]]])