def __init__(self, input_shape):
    """Stack four convolution layers (filter size 3) on ``input_shape``.

    The layers are registered as ``conv1`` .. ``conv4`` with 64, 128, 256
    and 512 filters respectively.
    """
    super().__init__(input_shape=input_shape)
    # ``self.output_shape`` is re-read before each layer is created, so the
    # layers are added strictly in this order.
    for position, n_filters in enumerate((64, 128, 256, 512), start=1):
        layer = nnt.Conv2d(self.output_shape, n_filters, 3)
        setattr(self, 'conv{}'.format(position), layer)
def test_slicing_sequential(idx):
    """Check ``input_shape`` of a sliced ``nnt.Sequential``.

    The slice ``idx`` is applied to a container populated attribute by
    attribute and to an instance of a locally defined ``Sequential``
    subclass. In both cases the result must report the same ``input_shape``
    as the layer of the first container located at the slice start; the
    subclass slice must additionally be an ``nnt.Sequential``.
    """
    input_shape = (None, 3, 256, 256)
    widths = (64, 128, 256, 512)

    a = nnt.Sequential(input_shape=input_shape)
    for position, n_filters in enumerate(widths, start=1):
        setattr(a, 'conv{}'.format(position),
                nnt.Conv2d(a.output_shape, n_filters, 3))

    sliced = a[idx]
    # A missing slice start is treated as index 0.
    first = idx.start
    if first is None:
        first = 0
    assert sliced.input_shape == a[first].input_shape

    class Foo(nnt.Sequential):
        def __init__(self, input_shape):
            super().__init__(input_shape=input_shape)
            for position, n_filters in enumerate(widths, start=1):
                setattr(self, 'conv{}'.format(position),
                        nnt.Conv2d(self.output_shape, n_filters, 3))

    foo = Foo(input_shape)
    sliced = foo[idx]
    assert isinstance(sliced, nnt.Sequential)
    # The reference layer is still taken from ``a``, as in the first check.
    assert sliced.input_shape == a[first].input_shape
def test_spectral_norm(device):
    """Compare ``nnt.spectral_norm`` with ``torch.nn.utils.spectral_norm``.

    A deep copy of the network gets the PyTorch wrapper applied layer by
    layer, while the original network is passed to ``nnt.spectral_norm`` as a
    whole. Both are seeded identically beforehand. After one forward pass the
    weights of the wrapped layers must agree, and the modules at indices 2, 3
    and 4 must not carry a ``weight_u`` attribute.
    """
    from copy import deepcopy
    import torch.nn as nn

    seed = 48931

    def _reseed():
        # Same seed on CPU and, when available, on every CUDA device.
        T.manual_seed(seed)
        if cuda_available:
            T.cuda.manual_seed_all(seed)

    x = T.rand(10, 3, 5, 5).to(device)
    net = nnt.Sequential(
        nnt.Sequential(nnt.Conv2d(3, 16, 3), nnt.Conv2d(16, 32, 3)),
        nnt.Sequential(
            nnt.Conv2d(32, 64, 3),
            nnt.Conv2d(64, 128, 3),
        ),
        nnt.BatchNorm2d(128),
        nnt.GroupNorm(128, 4),
        nnt.LayerNorm((None, 128, 5, 5)),
        nnt.GlobalAvgPool2D(),
        nnt.FC(128, 1)).to(device)
    net_pt_sn = deepcopy(net)

    nested_layers = ((0, 0), (0, 1), (1, 0), (1, 1))

    # Reference: wrap each layer with the PyTorch implementation.
    _reseed()
    for outer, inner in nested_layers:
        net_pt_sn[outer][inner] = nn.utils.spectral_norm(net_pt_sn[outer][inner])
    net_pt_sn[6] = nn.utils.spectral_norm(net_pt_sn[6])

    # Under test: wrap the whole network in one call.
    _reseed()
    net_nnt_sn = nnt.spectral_norm(net)

    net_pt_sn(x)
    net_nnt_sn(x)

    for position in (2, 3, 4):
        assert not hasattr(net_nnt_sn[position], 'weight_u')

    for outer, inner in nested_layers:
        testing.assert_allclose(net_pt_sn[outer][inner].weight,
                                net_nnt_sn[outer][inner].weight)
    testing.assert_allclose(net_pt_sn[6].weight, net_nnt_sn[6].weight)
def test_track(device):
    """Check values and gradients recorded through ``nnt.track``.

    A module is tracked under the name ``'op'`` and a tensor under
    ``'conv2_output'``. After backpropagating a sum-of-squares loss, the
    entries returned by ``nnt.eval_tracked_variables()`` are compared with
    the tensors and with gradients obtained from ``T.autograd.grad``.
    """
    shape = (2, 3, 5, 5)
    x = T.rand(*shape).to(device)

    conv1 = nnt.track('op', nnt.Conv2d(shape, 4, 3), 'all').to(device)
    conv2 = nnt.Conv2d(conv1.output_shape, 5, 3).to(device)

    hidden = conv1(x)
    output = nnt.track('conv2_output', conv2(hidden), 'all').to(device)

    loss = T.sum(output ** 2)
    # The graph is retained because gradients are queried again below.
    loss.backward(retain_graph=True)
    grad_hidden = T.autograd.grad(loss, hidden, retain_graph=True)
    grad_output = T.autograd.grad(loss, output)

    tracked = nnt.eval_tracked_variables()

    testing.assert_allclose(tracked['conv2_output'],
                            nnt.utils.to_numpy(output))
    testing.assert_allclose(np.stack(tracked['grad_conv2_output']),
                            nnt.utils.to_numpy(grad_output[0]))
    testing.assert_allclose(tracked['op'], nnt.utils.to_numpy(hidden))

    recorded_pairs = zip(grad_hidden, tracked['grad_op_output'])
    for reference_grad, recorded_grad in recorded_pairs:
        testing.assert_allclose(recorded_grad,
                                nnt.utils.to_numpy(reference_grad))
def test_adain(device, dim1, dim2):
    """Check the three AdaIN wrappers against a hand-written reference.

    The reference normalises the first module's output by its own mean and
    standard deviation (taken over ``dim1``) and rescales it with the mean
    and standard deviation of the second module's output (taken over
    ``dim2``). Output values and ``output_shape`` are compared for
    ``nnt.AdaIN``, ``nnt.MultiInputAdaIN`` and ``nnt.MultiModuleAdaIN``.
    """

    def _moments(tensor, dim):
        # Mean and standard deviation along ``dim``, with a small epsilon
        # added to the variance before the square root.
        mean = T.mean(tensor, dim, keepdim=True)
        std = T.sqrt(T.var(tensor, dim, keepdim=True) + 1e-8)
        return mean, std

    def _expected(module1, module2, input1, input2, dim1, dim2):
        output1 = module1(input1)
        output2 = module2(input2)
        mean1, std1 = _moments(output1, dim1)
        mean2, std2 = _moments(output2, dim2)
        return std2 * (output1 - mean1) / std1 + mean2

    shape = (2, 3, 4, 5)
    a = T.rand(*shape).to(device)
    b = T.rand(*shape).to(device)

    module1 = nnt.Conv2d(shape, 6, 3).to(device)
    module2 = nnt.Conv2d(shape, 6, 3).to(device)

    adain = nnt.AdaIN(module1, dim1).to(device)
    mi_adain = nnt.MultiInputAdaIN(
        module1, module2, dim1=dim1, dim2=dim2).to(device)
    mm_adain = nnt.MultiModuleAdaIN(
        module1, module2, dim1=dim1, dim2=dim2).to(device)

    # Single module applied to both inputs, same reduction dim for both.
    result = adain(a, b)
    reference = _expected(module1, module1, a, b, dim1, dim1)
    testing.assert_allclose(result, reference)
    testing.assert_allclose(adain.output_shape, reference.shape)

    # Two modules, two inputs.
    result = mi_adain(a, b)
    reference = _expected(module1, module2, a, b, dim1, dim2)
    testing.assert_allclose(result, reference)
    testing.assert_allclose(mi_adain.output_shape, reference.shape)

    # Two modules fed with the same input.
    result = mm_adain(a)
    reference = _expected(module1, module2, a, a, dim1, dim2)
    testing.assert_allclose(result, reference)
    testing.assert_allclose(mm_adain.output_shape, reference.shape)
def test_conv2d_layer(device, filter_size, stride, padding, dilation, output_shape):
    """Check ``nnt.Conv2d`` against ``T.nn.Conv2d`` for a concrete input shape.

    Both layers are built with the same filter size, stride, padding and
    dilation. The declared ``output_shape`` of the ``nnt`` layer must match
    the shape produced by the PyTorch layer and the expected ``output_shape``
    argument, as must the shape of the ``nnt`` layer's actual output.

    NOTE(review): a second function named ``test_conv2d_layer`` appears later
    in this file. If both are defined in the same module scope, the later
    definition replaces this one and this test is not collected -- confirm
    and consider giving them distinct names.
    """
    shape = (2, 3, 10, 10)
    n_filters = 5
    conv_args = (n_filters, filter_size, stride, padding, dilation)

    # The nnt layer takes the full input shape, the PyTorch layer only the
    # channel count.
    conv_nnt = nnt.Conv2d(shape, *conv_args).to(device)
    conv_pt = T.nn.Conv2d(shape[1], *conv_args).to(device)
    sanity_check(conv_nnt, conv_pt, shape, device=device)

    x = T.arange(np.prod(shape)).view(*shape).float().to(device)
    out_nnt = conv_nnt(x)
    out_pt = conv_pt(x)

    testing.assert_allclose(conv_nnt.output_shape, out_pt.shape)
    testing.assert_allclose(out_nnt.shape, output_shape)
    testing.assert_allclose(conv_nnt.output_shape, output_shape)
def test_conv2d_layer(device, filter_size, stride, padding, dilation):
    """Check ``nnt.Conv2d`` shape inference with a symbolic input shape.

    The layer is built from ``('b', 3, 'h', 'w')``. The height and width
    entries of its ``output_shape`` are evaluated by substituting the concrete
    sizes for the corresponding ``input_shape`` entries, and must equal the
    spatial sizes produced by an equivalent ``T.nn.Conv2d``.

    NOTE(review): an earlier function in this file is also named
    ``test_conv2d_layer``. If both are defined in the same module scope, this
    definition replaces the earlier one, which is then never collected --
    confirm and consider giving them distinct names.
    """
    shape_sym = ('b', 3, 'h', 'w')
    shape = (2, 3, 10, 10)
    n_filters = 5
    conv_args = (n_filters, filter_size, stride, padding, dilation)

    conv_nnt = nnt.Conv2d(shape_sym, *conv_args).to(device)
    conv_pt = T.nn.Conv2d(shape[1], *conv_args).to(device)
    sanity_check(conv_nnt, conv_pt, shape, device=device)

    x = T.arange(np.prod(shape)).view(*shape).float().to(device)
    out_pt = conv_pt(x)

    # Resolve both spatial axes first, then compare them with the reference.
    spatial_axes = (2, 3)
    resolved = [
        conv_nnt.output_shape[axis].subs(conv_nnt.input_shape[axis], shape[axis])
        for axis in spatial_axes
    ]
    for axis, size in zip(spatial_axes, resolved):
        assert size == out_pt.shape[axis]
def test_depthwise_sepconv(device, depth_mul):
    """Compare ``nnt.DepthwiseSepConv2D`` with an equivalent plain convolution.

    The plain ``nnt.Conv2d`` receives a weight composed from the depthwise and
    pointwise kernels of the separable layer; both layers (built without
    bias) must then produce the same output for the same input.
    """
    shape = (2, 3, 5, 5)
    n_filters = 4
    filter_size = 3
    in_channels = shape[1]

    a = T.arange(np.prod(shape)).view(*shape).float().to(device)
    conv_dw = nnt.DepthwiseSepConv2D(
        shape, n_filters, filter_size, depth_mul=depth_mul, bias=False).to(device)
    conv = nnt.Conv2d(shape, n_filters, filter_size, bias=False).to(device)

    # Combine each depthwise kernel with its slice of the pointwise kernel.
    combined = []
    for i in range(in_channels * depth_mul):
        depthwise_kernel = conv_dw.depthwise.weight[i:i + 1].transpose(0, 1)
        pointwise_kernel = conv_dw.pointwise.weight[:, i:i + 1]
        combined.append(F.conv2d(depthwise_kernel, pointwise_kernel).squeeze())
    weight = T.stack(combined)

    # Sum over the depth-multiplier axis, then swap the first two axes.
    weight = weight.view(in_channels, depth_mul, n_filters, filter_size, filter_size)
    weight = weight.sum(1).transpose(0, 1)
    conv.weight.data = weight

    testing.assert_allclose(conv_dw(a), conv(a))
def test_sum(device):
    """Check ``nnt.Sum``, ``nnt.ConcurrentSum`` and ``nnt.SequentialSum``.

    Each container is built from ``nnt.Lambda`` modules, raw tensors and/or
    convolution layers. Its output and its ``output_shape`` are compared with
    a reference expression written out by hand.
    """
    shape = (3, 2, 4, 4)
    out_channels = 5
    a = T.rand(*shape).to(device)
    b = T.rand(*shape).to(device)

    def add_one():
        # Fresh module on every call, as each container gets its own.
        return nnt.Lambda(lambda x: x + 1., output_shape=shape, input_shape=shape)

    def double():
        return nnt.Lambda(lambda x: 2. * x, output_shape=shape, input_shape=shape)

    def check(module, inputs, expected):
        # Compare the forward result and the declared output shape.
        testing.assert_allclose(module(*inputs), expected)
        testing.assert_allclose(module.output_shape, expected.shape)

    # Sum: every module receives the same input.
    check(nnt.Sum(add_one(), double()),
          (a,),
          (a + 1.) + (a * 2.))
    check(nnt.Sum(a, b, double()),
          (a,),
          a + b + (a * 2.))

    # ConcurrentSum: one input per module.
    check(nnt.ConcurrentSum(add_one(), double()),
          (a, b),
          (a + 1.) + (b * 2.))
    check(nnt.ConcurrentSum(a, b, add_one(), double()),
          (a, b),
          a + b + (a + 1.) + (b * 2.))

    # SequentialSum: modules are chained and every stage is added.
    check(nnt.SequentialSum(add_one(), double()),
          (a,),
          (a + 1.) + (a + 1.) * 2.)
    check(nnt.SequentialSum(a, add_one(), double()),
          (a,),
          a + (a + 1.) + (a + 1.) * 2.)

    # The same containers with convolution layers.
    m1 = nnt.Conv2d(a.shape, out_channels, 3).to(device)
    m2 = nnt.Conv2d(b.shape, out_channels, 3).to(device)
    check(nnt.ConcurrentSum(m1, m2),
          (a, b),
          m1(a) + m2(b))

    m1 = nnt.Conv2d(a.shape[1], a.shape[1], 3).to(device)
    m2 = nnt.Conv2d(a.shape[1], a.shape[1], 3).to(device)
    check(nnt.SequentialSum(a, m1, m2, b),
          (a,),
          a + m1(a) + m2(m1(a)) + b)
def test_cat(device):
    """Check ``nnt.Cat``, ``nnt.ConcurrentCat`` and ``nnt.SequentialCat``.

    Each container is built from ``nnt.Lambda`` modules, raw tensors and/or
    convolution layers. Its output and its ``output_shape`` are compared with
    a reference built with ``T.cat`` along the same dimension.
    """
    shape1 = (3, 2, 4, 4)
    shape2 = (3, 5, 4, 4)
    out_channels = 5
    a = T.rand(*shape1).to(device)
    b = T.rand(*shape2).to(device)

    def add_one(shape):
        # Fresh module on every call, as each container gets its own.
        return nnt.Lambda(lambda x: x + 1., output_shape=shape, input_shape=shape)

    def double(shape):
        return nnt.Lambda(lambda x: 2. * x, output_shape=shape, input_shape=shape)

    def check(module, inputs, expected):
        # Compare the forward result and the declared output shape.
        testing.assert_allclose(module(*inputs), expected)
        testing.assert_allclose(module.output_shape, expected.shape)

    # Cat: every module receives the same input.
    check(nnt.Cat(1, add_one(shape1), double(shape1)),
          (a,),
          T.cat((a + 1., a * 2.), 1))
    check(nnt.Cat(1, a, b, double(shape1)),
          (a,),
          T.cat((a, b, a * 2.), 1))

    # ConcurrentCat: one input per module.
    check(nnt.ConcurrentCat(1, add_one(shape1), double(shape2)),
          (a, b),
          T.cat((a + 1., b * 2.), 1))
    check(nnt.ConcurrentCat(1, a, b, add_one(shape1), double(shape2)),
          (a, b),
          T.cat((a, b, a + 1., b * 2.), 1))

    # SequentialCat: modules are chained and every stage is concatenated.
    check(nnt.SequentialCat(2, add_one(shape1), double(shape1)),
          (a,),
          T.cat((a + 1., (a + 1.) * 2.), 2))
    check(nnt.SequentialCat(2, a, add_one(shape1), double(shape1)),
          (a,),
          T.cat((a, a + 1., (a + 1.) * 2.), 2))

    # The same containers with convolution layers.
    m1 = nnt.Conv2d(a.shape, out_channels, 3).to(device)
    m2 = nnt.Conv2d(b.shape, out_channels, 3).to(device)
    check(nnt.ConcurrentCat(1, a, m1, b, m2),
          (a, b),
          T.cat((a, m1(a), b, m2(b)), 1))

    m1 = nnt.Conv2d(a.shape, out_channels, 3).to(device)
    m2 = nnt.Conv2d(out_channels, out_channels, 3).to(device)
    check(nnt.SequentialCat(1, a, m1, m2, b),
          (a,),
          T.cat((a, m1(a), m2(m1(a)), b), 1))