示例#1
0
def test_lambda(device):
    """Check ``nnt.Lambda`` with keyword arguments bound at construction.

    Two cases are exercised on random tensors of shape ``(3, 10, 5, 5)``:

    * an element-wise power, with the exponent bound as ``y=2.``;
    * a slice ``[:, fr:to]`` along dimension 1, with ``fr``/``to`` bound as
      keyword arguments.

    In both cases the module output and the ``output_shape`` it reports are
    compared with the result of applying the same operation directly to the
    tensor.

    Args:
        device: target passed to ``Tensor.to`` for the input tensors
            (presumably supplied by a pytest fixture or parametrization --
            confirm against the test configuration).
    """
    shape = (3, 10, 5, 5)
    a = T.rand(*shape).to(device)

    def foo1(x, y):
        return x**y

    sqr = nnt.Lambda(foo1, y=2., input_shape=shape)
    expected = a**2.
    testing.assert_allclose(sqr(a), expected)
    testing.assert_allclose(sqr.output_shape, expected.shape)

    def foo2(x, fr, to):
        return x[:, fr:to]

    fr = 3
    to = 7
    # Keep the tensor on the requested ``device`` only. Forcing it onto CUDA
    # whenever a GPU happens to be present would silently skip the CPU run.
    a = T.rand(*shape).to(device)

    # Named ``slicer`` rather than ``slice`` to avoid shadowing the builtin.
    slicer = nnt.Lambda(foo2, fr=fr, to=to, input_shape=shape)
    expected = a[:, fr:to]
    testing.assert_allclose(slicer(a), expected)
    testing.assert_allclose(slicer.output_shape, expected.shape)
示例#2
0
def test_sum(device):
    """Check the output and ``output_shape`` of the ``nnt`` sum containers.

    ``nnt.Sum``, ``nnt.ConcurrentSum`` and ``nnt.SequentialSum`` are built
    from ``nnt.Lambda`` modules, raw tensors and ``nnt.Conv2d`` layers. Each
    container is called and its result is compared with the same sum written
    out by hand; the reported ``output_shape`` is compared with the shape of
    that hand-written result.

    Args:
        device: target passed to ``.to`` for the input tensors and the
            convolution layers.
    """
    dims = (3, 2, 4, 4)
    n_filters = 5

    a = T.rand(*dims).to(device)
    b = T.rand(*dims).to(device)

    def add_one():
        # A fresh ``x + 1`` module for every container.
        return nnt.Lambda(lambda x: x + 1.,
                          output_shape=dims, input_shape=dims)

    def double():
        # A fresh ``2 * x`` module for every container.
        return nnt.Lambda(lambda x: 2. * x,
                          output_shape=dims, input_shape=dims)

    def check(module, inputs, expected):
        # Compare both the forward result and the advertised output shape.
        testing.assert_allclose(module(*inputs), expected)
        testing.assert_allclose(module.output_shape, expected.shape)

    # --- Sum: the single input is expected to reach every module ---
    check(nnt.Sum(add_one(), double()), (a,), (a + 1.) + (a * 2.))
    check(nnt.Sum(a, b, double()), (a,), a + b + (a * 2.))

    # --- ConcurrentSum: the i-th input is expected to reach the i-th module ---
    check(nnt.ConcurrentSum(add_one(), double()),
          (a, b),
          (a + 1.) + (b * 2.))
    check(nnt.ConcurrentSum(a, b, add_one(), double()),
          (a, b),
          a + b + (a + 1.) + (b * 2.))

    # --- SequentialSum: modules are expected to be chained, with every
    # intermediate result added ---
    check(nnt.SequentialSum(add_one(), double()),
          (a,),
          (a + 1.) + (a + 1.) * 2.)
    check(nnt.SequentialSum(a, add_one(), double()),
          (a,),
          a + (a + 1.) + (a + 1.) * 2.)

    # --- Same containers with convolution layers ---
    conv_a = nnt.Conv2d(a.shape, n_filters, 3).to(device)
    conv_b = nnt.Conv2d(b.shape, n_filters, 3).to(device)
    check(nnt.ConcurrentSum(conv_a, conv_b), (a, b), conv_a(a) + conv_b(b))

    channels = a.shape[1]
    conv_a = nnt.Conv2d(channels, channels, 3).to(device)
    conv_b = nnt.Conv2d(channels, channels, 3).to(device)
    check(nnt.SequentialSum(a, conv_a, conv_b, b),
          (a,),
          a + conv_a(a) + conv_b(conv_a(a)) + b)
示例#3
0
def test_cat(device):
    """Check the output and ``output_shape`` of the ``nnt`` cat containers.

    ``nnt.Cat``, ``nnt.ConcurrentCat`` and ``nnt.SequentialCat`` are built
    from ``nnt.Lambda`` modules, raw tensors and ``nnt.Conv2d`` layers. Each
    container is called and its result is compared with an explicit
    ``T.cat`` of the individual results along the same dimension; the
    reported ``output_shape`` is compared with the shape of that tensor.

    Args:
        device: target passed to ``.to`` for the input tensors and the
            convolution layers.
    """
    dims_a = (3, 2, 4, 4)
    dims_b = (3, 5, 4, 4)
    n_filters = 5

    a = T.rand(*dims_a).to(device)
    b = T.rand(*dims_b).to(device)

    def add_one(dims):
        # A fresh ``x + 1`` module declared for inputs of shape ``dims``.
        return nnt.Lambda(lambda x: x + 1.,
                          output_shape=dims, input_shape=dims)

    def double(dims):
        # A fresh ``2 * x`` module declared for inputs of shape ``dims``.
        return nnt.Lambda(lambda x: 2. * x,
                          output_shape=dims, input_shape=dims)

    def check(module, inputs, expected):
        # Compare both the forward result and the advertised output shape.
        testing.assert_allclose(module(*inputs), expected)
        testing.assert_allclose(module.output_shape, expected.shape)

    # --- Cat: the single input is expected to reach every module ---
    check(nnt.Cat(1, add_one(dims_a), double(dims_a)),
          (a,),
          T.cat((a + 1., a * 2.), 1))
    check(nnt.Cat(1, a, b, double(dims_a)),
          (a,),
          T.cat((a, b, a * 2.), 1))

    # --- ConcurrentCat: the i-th input is expected to reach the i-th module ---
    check(nnt.ConcurrentCat(1, add_one(dims_a), double(dims_b)),
          (a, b),
          T.cat((a + 1., b * 2.), 1))
    check(nnt.ConcurrentCat(1, a, b, add_one(dims_a), double(dims_b)),
          (a, b),
          T.cat((a, b, a + 1., b * 2.), 1))

    # --- SequentialCat: modules are expected to be chained, with every
    # intermediate result concatenated (here along dimension 2) ---
    check(nnt.SequentialCat(2, add_one(dims_a), double(dims_a)),
          (a,),
          T.cat((a + 1., (a + 1.) * 2.), 2))
    check(nnt.SequentialCat(2, a, add_one(dims_a), double(dims_a)),
          (a,),
          T.cat((a, a + 1., (a + 1.) * 2.), 2))

    # --- Same containers with convolution layers ---
    conv_a = nnt.Conv2d(a.shape, n_filters, 3).to(device)
    conv_b = nnt.Conv2d(b.shape, n_filters, 3).to(device)
    check(nnt.ConcurrentCat(1, a, conv_a, b, conv_b),
          (a, b),
          T.cat((a, conv_a(a), b, conv_b(b)), 1))

    conv_a = nnt.Conv2d(a.shape, n_filters, 3).to(device)
    conv_b = nnt.Conv2d(n_filters, n_filters, 3).to(device)
    check(nnt.SequentialCat(1, a, conv_a, conv_b, b),
          (a,),
          T.cat((a, conv_a(a), conv_b(conv_a(a)), b), 1))