Example #1
0
    def __init__(self,
                 input_shape,
                 out_features,
                 num_instances=None,
                 bias=True,
                 activation=None):
        """Build a residual block: FC -> GraphXConv main branch plus a shortcut.

        Args:
            input_shape: shape of the incoming tensor. Its last entry is
                compared with ``out_features`` to decide whether the shortcut
                needs a projection (when ``num_instances`` is ``None``).
            out_features: number of output features of both branches.
            num_instances: forwarded to ``GraphXConv``. When ``None`` the
                shortcut is an identity (matching feature count) or an
                ``nnt.FC`` projection; otherwise it is a ``GraphXConv``.
            bias: whether the learnable layers of both the main branch and
                the shortcut use a bias term.
            activation: activation spec. It is passed to the first FC layer
                and also resolved with ``nnt.utils.function`` into
                ``self.activation``.
        """
        super().__init__(input_shape)
        self.out_features = out_features
        # Resolved activation callable; presumably applied in forward()
        # (not visible here) -- TODO confirm.
        self.activation = nnt.utils.function(activation)

        # Main branch: FC (with activation) followed by GraphXConv (no
        # activation). Sub-module names are kept unchanged so existing
        # state_dict keys still load.
        self.main = nnt.Sequential(input_shape=input_shape)
        self.main.add_module(
            'conv_r_1',
            nnt.FC(self.main.output_shape,
                   out_features,
                   bias=bias,
                   activation=activation))
        self.main.add_module(
            'convr_2',
            GraphXConv(self.main.output_shape, out_features, num_instances,
                       bias, None))

        # Shortcut branch: maps input_shape to the same number of features.
        if num_instances is None:
            if out_features == input_shape[-1]:
                self.res = lambda x: x
            else:
                self.res = nnt.FC(input_shape, out_features, bias=bias,
                                  activation=None)
        else:
            # Forward ``bias`` so bias=False also removes the bias from this
            # shortcut, matching the FC projection and the main branch.
            self.res = GraphXConv(input_shape,
                                  out_features,
                                  num_instances,
                                  bias=bias,
                                  activation=None)
Example #2
0
    def __init__(self,
                 input_shape,
                 out_features,
                 num_instances=None,
                 bias=True,
                 activation=None,
                 weights_init=None,
                 bias_init=None):
        """Build the GraphX convolution stack.

        Registers, in order: ``dimshuffle1`` (swap the last two axes),
        ``conv_l`` (FC with ``self.num_instances`` outputs), ``dimshuffle2``
        (swap the last two axes back) and ``conv_r`` (FC with
        ``out_features`` outputs and ``activation``).

        Args:
            input_shape: shape of the incoming tensor. ``input_shape[1]`` is
                the fallback for ``num_instances``.
            out_features: number of output features of ``conv_r``.
            num_instances: output size of ``conv_l``. Falls back to
                ``input_shape[1]`` when not given (any falsy value).
            bias: whether ``conv_l`` and ``conv_r`` use a bias term.
            activation: activation of ``conv_r``; also stored on
                ``self.activation``.
            weights_init: weight initializer forwarded to ``conv_l``.
            bias_init: bias initializer forwarded to ``conv_l``.
        """
        super().__init__(input_shape=input_shape)
        self.out_features = out_features
        self.num_instances = num_instances if num_instances else input_shape[1]
        self.activation = activation
        # Permutation swapping the last two axes; applying it twice restores
        # the original axis order.
        pattern = list(range(len(input_shape)))
        pattern[-1], pattern[-2] = pattern[-2], pattern[-1]

        # NOTE(review): self.output_shape is read after each add_module,
        # presumably reflecting the most recently added layer -- confirm.
        self.add_module('dimshuffle1',
                        nnt.DimShuffle(pattern, input_shape=self.output_shape))
        self.add_module(
            'conv_l',
            nnt.FC(self.output_shape,
                   # Use the resolved value: the raw argument may be None.
                   self.num_instances,
                   bias=bias,
                   activation=None,
                   weights_init=weights_init,
                   bias_init=bias_init))
        self.add_module('dimshuffle2',
                        nnt.DimShuffle(pattern, input_shape=self.output_shape))
        self.add_module(
            'conv_r',
            nnt.FC(self.output_shape,
                   out_features,
                   bias=bias,
                   activation=activation))
Example #3
0
def test_fc_layer(device, shape):
    """Compare ``nnt.FC`` against ``torch.nn.Linear`` on a (2, 3) input.

    ``shape`` is the parametrized input-shape spec (a tuple or a plain
    feature count). It is used to smoke-test the ``nnt.FC`` constructor and
    to size the reference ``Linear`` layer.
    """
    n_out = 4

    # Constructor smoke tests: building from the parametrized spec must not
    # raise; the layer actually compared is built from an explicit tuple.
    nnt.FC(shape, n_out)
    fc_nnt = nnt.FC((2, 3), n_out)

    # The reference layer needs only the input feature count.
    n_in = shape[1] if isinstance(shape, tuple) else shape
    # NOTE(review): fc_nnt always sees 3 input features while fc_pt is sized
    # from ``shape``; these only agree when the parametrized feature count
    # is 3 -- confirm against the parametrization.
    fc_pt = T.nn.Linear(n_in, n_out)

    sanity_check(fc_nnt, fc_pt, shape=(2, 3), device=device)
Example #4
0
 def __init__(self, input_shape, out_features, bias=True, activation=None):
     """Build a main branch of two stacked FC layers.

     Args:
         input_shape: shape of the incoming tensor.
         out_features: output features of both ``fc1`` and ``fc2``.
         bias: whether ``fc1`` and ``fc2`` use a bias term; also forwarded
             to the parent constructor.
         activation: activation of ``fc1`` (``fc2`` has none); also
             forwarded to the parent constructor.
     """
     super().__init__(input_shape,
                      out_features,
                      bias=bias,
                      activation=activation)
     self.main = nnt.Sequential(input_shape=input_shape)
     # ``bias`` is now forwarded to both layers, so bias=False actually
     # removes the bias terms from the main branch.
     self.main.add_module(
         'fc1',
         nnt.FC(self.main.output_shape,
                out_features,
                bias=bias,
                activation=activation))
     self.main.add_module(
         'fc2',
         nnt.FC(self.main.output_shape,
                out_features,
                bias=bias,
                activation=None))
Example #5
0
    def __init__(self,
                 input_shape,
                 out_features,
                 num_instances=None,
                 rank=None,
                 bias=True,
                 activation=None,
                 weights_init=None,
                 bias_init=None):
        """Build a low-rank GraphX convolution stack.

        The instance-mixing step is factorised into two FC layers:
        ``conv_l1`` (``self.rank`` outputs, no bias) and ``conv_l2``
        (``self.num_instances`` outputs). Both sit between two swaps of the
        last two axes (``dimshuffle1`` / ``dimshuffle2``) and are followed by
        ``conv_r``, which produces ``out_features`` with ``activation``.

        Args:
            input_shape: shape of the incoming tensor. ``input_shape[1]`` is
                the fallback for ``num_instances``.
            out_features: number of output features of ``conv_r``.
            num_instances: output size of ``conv_l2``. Falls back to
                ``input_shape[1]`` when falsy.
            rank: output size of ``conv_l1``. Defaults to
                ``num_instances // 2``.
            bias: bias flag for ``conv_l2`` and ``conv_r``.
            activation: activation of ``conv_r``; also kept on
                ``self.activation``.
            weights_init: weight initializer for ``conv_l1`` and ``conv_l2``.
            bias_init: bias initializer for ``conv_l2``.

        Raises:
            AssertionError: if the resolved rank is not strictly smaller than
                the resolved number of instances.
        """
        super().__init__(input_shape=input_shape)
        self.out_features = out_features

        # Resolve defaults that depend on other arguments.
        if num_instances:
            self.num_instances = num_instances
        else:
            self.num_instances = input_shape[1]
        if rank is None:
            self.rank = self.num_instances // 2
        else:
            self.rank = rank
        assert self.rank < self.num_instances, 'rank should be smaller than num_instances'

        self.activation = activation

        # Axis permutation exchanging the last two dimensions.
        swap_axes = list(range(len(input_shape)))
        swap_axes[-1], swap_axes[-2] = swap_axes[-2], swap_axes[-1]

        def make_shuffle():
            # Read output_shape lazily so each shuffle sees the shape left
            # by the previously registered layer.
            return nnt.DimShuffle(swap_axes, input_shape=self.output_shape)

        self.add_module('dimshuffle1', make_shuffle())
        self.add_module('conv_l1',
                        nnt.FC(self.output_shape, self.rank,
                               bias=False, activation=None,
                               weights_init=weights_init))
        self.add_module('conv_l2',
                        nnt.FC(self.output_shape, self.num_instances,
                               bias=bias, activation=None,
                               weights_init=weights_init,
                               bias_init=bias_init))
        self.add_module('dimshuffle2', make_shuffle())
        self.add_module('conv_r',
                        nnt.FC(self.output_shape, out_features,
                               bias=bias, activation=activation))
Example #6
0
def test_spectral_norm(device):
    """``nnt.spectral_norm`` must agree with ``torch.nn.utils.spectral_norm``.

    Setup:
    - A small conv/norm/FC network is built and deep-copied.
    - In the copy, the conv and FC layers are wrapped one by one with
      PyTorch's reference implementation.
    - The original network is wrapped with ``nnt.spectral_norm``.
    - The torch (and, if available, CUDA) RNGs are reseeded identically
      before each wrapping step.

    After one forward pass through each network, the normalized weights
    must match. The normalization layers must not gain a ``weight_u``
    attribute.
    """
    from copy import deepcopy
    import torch.nn as nn

    seed = 48931
    inputs = T.rand(10, 3, 5, 5).to(device)

    def reset_rng():
        # Wrapping presumably draws random vectors, so both implementations
        # start from the same RNG state.
        T.manual_seed(seed)
        if cuda_available:
            T.cuda.manual_seed_all(seed)

    net = nnt.Sequential(
        nnt.Sequential(nnt.Conv2d(3, 16, 3), nnt.Conv2d(16, 32, 3)),
        nnt.Sequential(
            nnt.Conv2d(32, 64, 3),
            nnt.Conv2d(64, 128, 3),
        ),
        nnt.BatchNorm2d(128),
        nnt.GroupNorm(128, 4),
        nnt.LayerNorm((None, 128, 5, 5)),
        nnt.GlobalAvgPool2D(),
        nnt.FC(128, 1),
    ).to(device)

    # Nested-index paths of the layers that receive spectral norm, in the
    # order the reference wrapping is applied.
    normed_paths = ((0, 0), (0, 1), (1, 0), (1, 1), (6,))
    # Top-level indices of the normalization layers that must stay unwrapped.
    plain_indices = (2, 3, 4)

    net_pt_sn = deepcopy(net)
    reset_rng()
    for *outer, last in normed_paths:
        parent = net_pt_sn
        for idx in outer:
            parent = parent[idx]
        parent[last] = nn.utils.spectral_norm(parent[last])

    reset_rng()
    net_nnt_sn = nnt.spectral_norm(net)

    net_pt_sn(inputs)
    net_nnt_sn(inputs)

    for idx in plain_indices:
        assert not hasattr(net_nnt_sn[idx], 'weight_u')

    for path in normed_paths:
        ref_layer, nnt_layer = net_pt_sn, net_nnt_sn
        for idx in path:
            ref_layer, nnt_layer = ref_layer[idx], nnt_layer[idx]
        testing.assert_allclose(ref_layer.weight, nnt_layer.weight)