Example #1
def test_slicing_sequential(idx):
    # Slicing an nnt.Sequential should return another Sequential whose input
    # shape matches that of the first layer covered by the slice.
    input_shape = (None, 3, 256, 256)

    a = nnt.Sequential(input_shape=input_shape)
    a.conv1 = nnt.Conv2d(a.output_shape, 64, 3)
    a.conv2 = nnt.Conv2d(a.output_shape, 128, 3)
    a.conv3 = nnt.Conv2d(a.output_shape, 256, 3)
    a.conv4 = nnt.Conv2d(a.output_shape, 512, 3)

    b = a[idx]
    start = 0 if idx.start is None else idx.start
    assert b.input_shape == a[start].input_shape

    class Foo(nnt.Sequential):
        def __init__(self, input_shape):
            super().__init__(input_shape=input_shape)

            self.conv1 = nnt.Conv2d(self.output_shape, 64, 3)
            self.conv2 = nnt.Conv2d(self.output_shape, 128, 3)
            self.conv3 = nnt.Conv2d(self.output_shape, 256, 3)
            self.conv4 = nnt.Conv2d(self.output_shape, 512, 3)

    foo = Foo(input_shape)
    b = foo[idx]
    start = 0 if idx.start is None else idx.start
    assert isinstance(b, nnt.Sequential)
    assert b.input_shape == a[start].input_shape
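A minimal sketch of driving this test by hand, assuming nnt is the neuralnet_pytorch package and that idx is a slice object (in the test suite it would normally be supplied by a pytest parametrization):

# Hypothetical direct invocation; the slice values are illustrative only.
import neuralnet_pytorch as nnt  # assumption: the nnt alias used above

test_slicing_sequential(slice(None, 2))  # keep conv1 and conv2
test_slicing_sequential(slice(1, 3))     # keep conv2 and conv3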
Example #2
    def __init__(self,
                 input_shape,
                 out_features,
                 num_instances=None,
                 bias=True,
                 activation=None):
        super().__init__(input_shape)
        self.out_features = out_features
        self.activation = nnt.utils.function(activation)

        self.main = nnt.Sequential(input_shape=input_shape)
        self.main.add_module(
            'conv_r_1',
            nnt.FC(self.main.output_shape,
                   out_features,
                   bias=bias,
                   activation=activation))
        self.main.add_module(
            'convr_2',
            GraphXConv(self.main.output_shape, out_features, num_instances,
                       bias, None))
        # Shortcut branch: identity when the feature width is unchanged,
        # otherwise a projection (FC or GraphXConv) so the residual sum
        # lines up with the main branch.
        if num_instances is None:
            self.res = (lambda x: x) if (out_features == input_shape[-1]) \
                else nnt.FC(input_shape, out_features, bias=bias, activation=None)
        else:
            self.res = GraphXConv(input_shape,
                                  out_features,
                                  num_instances,
                                  activation=None)
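The shortcut self.res above only matters at summation time; the forward pass itself is not part of this excerpt. A plausible wiring, shown here purely as an assumption about how such a residual block is usually combined:

    def forward(self, input):
        # Hypothetical forward pass (not in the excerpt): add the two-layer
        # main branch to the (possibly projected) shortcut, then activate.
        return self.activation(self.main(input) + self.res(input))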
Example #3
def test_spectral_norm(device):
    # Assumed module-level context for this excerpt: import torch as T,
    # import neuralnet_pytorch as nnt, from torch import testing, and a
    # cuda_available flag defined elsewhere in the test module.
    from copy import deepcopy
    import torch.nn as nn

    seed = 48931
    input = T.rand(10, 3, 5, 5).to(device)

    net = nnt.Sequential(
        nnt.Sequential(nnt.Conv2d(3, 16, 3), nnt.Conv2d(16, 32, 3)),
        nnt.Sequential(
            nnt.Conv2d(32, 64, 3),
            nnt.Conv2d(64, 128, 3),
        ), nnt.BatchNorm2d(128), nnt.GroupNorm(128, 4),
        nnt.LayerNorm((None, 128, 5, 5)), nnt.GlobalAvgPool2D(),
        nnt.FC(128, 1)).to(device)

    net_pt_sn = deepcopy(net)
    T.manual_seed(seed)
    if cuda_available:
        T.cuda.manual_seed_all(seed)

    net_pt_sn[0][0] = nn.utils.spectral_norm(net_pt_sn[0][0])
    net_pt_sn[0][1] = nn.utils.spectral_norm(net_pt_sn[0][1])
    net_pt_sn[1][0] = nn.utils.spectral_norm(net_pt_sn[1][0])
    net_pt_sn[1][1] = nn.utils.spectral_norm(net_pt_sn[1][1])
    net_pt_sn[6] = nn.utils.spectral_norm(net_pt_sn[6])

    T.manual_seed(seed)
    if cuda_available:
        T.cuda.manual_seed_all(seed)

    net_nnt_sn = nnt.spectral_norm(net)

    net_pt_sn(input)
    net_nnt_sn(input)

    assert not hasattr(net_nnt_sn[2], 'weight_u')
    assert not hasattr(net_nnt_sn[3], 'weight_u')
    assert not hasattr(net_nnt_sn[4], 'weight_u')

    testing.assert_allclose(net_pt_sn[0][0].weight, net_nnt_sn[0][0].weight)
    testing.assert_allclose(net_pt_sn[0][1].weight, net_nnt_sn[0][1].weight)
    testing.assert_allclose(net_pt_sn[1][0].weight, net_nnt_sn[1][0].weight)
    testing.assert_allclose(net_pt_sn[1][1].weight, net_nnt_sn[1][1].weight)
    testing.assert_allclose(net_pt_sn[6].weight, net_nnt_sn[6].weight)
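The hasattr checks above rely on spectral norm registering extra attributes (weight_orig, weight_u, ...) on every module it wraps, so their absence on the BatchNorm/GroupNorm/LayerNorm layers shows those were skipped. The same property in plain PyTorch, as a standalone sketch:

import torch.nn as nn

conv = nn.utils.spectral_norm(nn.Conv2d(3, 16, 3))
bn = nn.BatchNorm2d(16)  # normalization layers are not wrapped here
assert hasattr(conv, 'weight_u') and hasattr(conv, 'weight_orig')
assert not hasattr(bn, 'weight_u')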
Example #4
    def __init__(self, input_shape, out_features, bias=True, activation=None):
        super().__init__(input_shape,
                         out_features,
                         bias=bias,
                         activation=activation)
        self.main = nnt.Sequential(input_shape=input_shape)
        self.main.add_module(
            'fc1',
            nnt.FC(self.main.output_shape, out_features,
                   activation=activation))
        self.main.add_module(
            'fc2', nnt.FC(self.main.output_shape,
                          out_features,
                          activation=None))
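The same shape-chaining idiom in isolation: every add_module call reads the container's current output_shape, so the next layer's input shape never has to be written out by hand. A minimal sketch, assuming nnt is neuralnet_pytorch and that FC propagates output_shape as (None, out_features):

import neuralnet_pytorch as nnt  # assumption: the nnt alias used above

seq = nnt.Sequential(input_shape=(None, 32))
seq.add_module('fc1', nnt.FC(seq.output_shape, 64, activation=None))
seq.add_module('fc2', nnt.FC(seq.output_shape, 16, activation=None))
# seq.output_shape is expected to be (None, 16) after the two additions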