def test_lambda(device):
    """Check that ``nnt.Lambda`` applies a wrapped function and infers its shape.

    Two cases are covered:

    * keyword arguments bound at construction (``y=2.``) are passed to the
      wrapped function on every call;
    * the reported ``output_shape`` follows the function's result even when
      it differs from ``input_shape`` (channel slicing).

    Tensors are placed only via ``.to(device)``. ``device`` is presumably a
    pytest fixture/parametrization (TODO confirm), so it alone decides where
    the test runs.
    """
    shape = (3, 10, 5, 5)
    a = T.rand(*shape).to(device)

    def foo1(x, y):
        return x ** y

    sqr = nnt.Lambda(foo1, y=2., input_shape=shape)
    expected = a ** 2.
    testing.assert_allclose(sqr(a), expected)
    testing.assert_allclose(sqr.output_shape, expected.shape)

    def foo2(x, fr, to):
        return x[:, fr:to]

    fr = 3
    to = 7
    a = T.rand(*shape).to(device)
    # Intentionally no extra ``.cuda()``: forcing CUDA whenever a GPU is
    # present would override ``device`` for the CPU run.
    # ``slicer`` (not ``slice``) avoids shadowing the builtin.
    slicer = nnt.Lambda(foo2, fr=fr, to=to, input_shape=shape)
    expected = a[:, fr:to]
    testing.assert_allclose(slicer(a), expected)
    testing.assert_allclose(slicer.output_shape, expected.shape)
def test_sum(device):
    """Check ``nnt.Sum``, ``nnt.ConcurrentSum`` and ``nnt.SequentialSum``.

    For each combinator, both the forward result and ``output_shape`` are
    compared against a reference computed with plain tensor arithmetic. The
    references encode the expected semantics:

    * ``Sum``: every branch is applied to the same input and the results
      are added;
    * ``ConcurrentSum``: the i-th module is applied to the i-th input;
    * ``SequentialSum``: modules are chained and every intermediate output
      is added.

    Tensors passed in place of modules are expected to be added as constants.
    """
    shape = (3, 2, 4, 4)
    out_channels = 5
    a = T.rand(*shape).to(device)
    b = T.rand(*shape).to(device)

    # ``sum_mod`` instead of ``sum`` so the builtin is not shadowed.
    sum_mod = nnt.Sum(
        nnt.Lambda(lambda x: x + 1., output_shape=shape, input_shape=shape),
        nnt.Lambda(lambda x: 2. * x, output_shape=shape, input_shape=shape))
    expected = (a + 1.) + (a * 2.)
    testing.assert_allclose(sum_mod(a), expected)
    testing.assert_allclose(sum_mod.output_shape, expected.shape)

    sum_mod = nnt.Sum(
        a, b,
        nnt.Lambda(lambda x: 2. * x, output_shape=shape, input_shape=shape))
    expected = a + b + (a * 2.)
    testing.assert_allclose(sum_mod(a), expected)
    testing.assert_allclose(sum_mod.output_shape, expected.shape)

    con_sum = nnt.ConcurrentSum(
        nnt.Lambda(lambda x: x + 1., output_shape=shape, input_shape=shape),
        nnt.Lambda(lambda x: 2. * x, output_shape=shape, input_shape=shape))
    expected = (a + 1.) + (b * 2.)
    testing.assert_allclose(con_sum(a, b), expected)
    testing.assert_allclose(con_sum.output_shape, expected.shape)

    con_sum = nnt.ConcurrentSum(
        a, b,
        nnt.Lambda(lambda x: x + 1., output_shape=shape, input_shape=shape),
        nnt.Lambda(lambda x: 2. * x, output_shape=shape, input_shape=shape))
    expected = a + b + (a + 1.) + (b * 2.)
    testing.assert_allclose(con_sum(a, b), expected)
    testing.assert_allclose(con_sum.output_shape, expected.shape)

    seq_sum = nnt.SequentialSum(
        nnt.Lambda(lambda x: x + 1., output_shape=shape, input_shape=shape),
        nnt.Lambda(lambda x: 2. * x, output_shape=shape, input_shape=shape))
    expected = (a + 1.) + (a + 1.) * 2.
    testing.assert_allclose(seq_sum(a), expected)
    testing.assert_allclose(seq_sum.output_shape, expected.shape)

    seq_sum = nnt.SequentialSum(
        a,
        nnt.Lambda(lambda x: x + 1., output_shape=shape, input_shape=shape),
        nnt.Lambda(lambda x: 2. * x, output_shape=shape, input_shape=shape))
    expected = a + (a + 1.) + (a + 1.) * 2.
    testing.assert_allclose(seq_sum(a), expected)
    testing.assert_allclose(seq_sum.output_shape, expected.shape)

    # Real parameterized layers: one convolution per input.
    m1 = nnt.Conv2d(a.shape, out_channels, 3).to(device)
    m2 = nnt.Conv2d(b.shape, out_channels, 3).to(device)
    con_sum = nnt.ConcurrentSum(m1, m2)
    expected = m1(a) + m2(b)
    testing.assert_allclose(con_sum(a, b), expected)
    testing.assert_allclose(con_sum.output_shape, expected.shape)

    # Chained convolutions that keep the channel count so the outputs can be
    # summed with ``a`` and ``b``.
    m1 = nnt.Conv2d(a.shape[1], a.shape[1], 3).to(device)
    m2 = nnt.Conv2d(a.shape[1], a.shape[1], 3).to(device)
    seq_sum = nnt.SequentialSum(a, m1, m2, b)
    expected = a + m1(a) + m2(m1(a)) + b
    testing.assert_allclose(seq_sum(a), expected)
    testing.assert_allclose(seq_sum.output_shape, expected.shape)
def test_cat(device):
    """Exercise ``nnt.Cat``, ``nnt.ConcurrentCat`` and ``nnt.SequentialCat``.

    Every combinator's forward output and reported ``output_shape`` are
    compared with a reference built directly via ``T.cat``. Constant tensors
    mixed in with modules are expected to appear in the concatenation at
    their argument position.
    """
    shape1 = (3, 2, 4, 4)
    shape2 = (3, 5, 4, 4)
    out_channels = 5
    a = T.rand(*shape1).to(device)
    b = T.rand(*shape2).to(device)

    def lam(fn, shp):
        # Shape-preserving Lambda: input and output shapes are the same.
        return nnt.Lambda(fn, output_shape=shp, input_shape=shp)

    def check(module, result, ref):
        # Compare both the computed values and the statically inferred shape.
        testing.assert_allclose(result, ref)
        testing.assert_allclose(module.output_shape, ref.shape)

    def plus_one(x):
        return x + 1.

    def double(x):
        return 2. * x

    # Cat: every branch sees the same input.
    cat = nnt.Cat(1, lam(plus_one, shape1), lam(double, shape1))
    ref = T.cat((a + 1., a * 2.), 1)
    check(cat, cat(a), ref)

    cat = nnt.Cat(1, a, b, lam(double, shape1))
    ref = T.cat((a, b, a * 2.), 1)
    check(cat, cat(a), ref)

    # ConcurrentCat: the i-th module consumes the i-th input.
    con_cat = nnt.ConcurrentCat(1, lam(plus_one, shape1), lam(double, shape2))
    ref = T.cat((a + 1., b * 2.), 1)
    check(con_cat, con_cat(a, b), ref)

    con_cat = nnt.ConcurrentCat(
        1, a, b, lam(plus_one, shape1), lam(double, shape2))
    ref = T.cat((a, b, a + 1., b * 2.), 1)
    check(con_cat, con_cat(a, b), ref)

    # SequentialCat: modules are chained; each intermediate is concatenated.
    seq_cat = nnt.SequentialCat(2, lam(plus_one, shape1), lam(double, shape1))
    ref = T.cat((a + 1., (a + 1.) * 2.), 2)
    check(seq_cat, seq_cat(a), ref)

    seq_cat = nnt.SequentialCat(
        2, a, lam(plus_one, shape1), lam(double, shape1))
    ref = T.cat((a, a + 1., (a + 1.) * 2.), 2)
    check(seq_cat, seq_cat(a), ref)

    # Same combinators with real convolutional layers.
    m1 = nnt.Conv2d(a.shape, out_channels, 3).to(device)
    m2 = nnt.Conv2d(b.shape, out_channels, 3).to(device)
    con_cat = nnt.ConcurrentCat(1, a, m1, b, m2)
    ref = T.cat((a, m1(a), b, m2(b)), 1)
    check(con_cat, con_cat(a, b), ref)

    m1 = nnt.Conv2d(a.shape, out_channels, 3).to(device)
    m2 = nnt.Conv2d(out_channels, out_channels, 3).to(device)
    seq_cat = nnt.SequentialCat(1, a, m1, m2, b)
    ref = T.cat((a, m1(a), m2(m1(a)), b), 1)
    check(seq_cat, seq_cat(a), ref)