def __init__(self, input_shape, out_features, num_instances=None, bias=True, activation=None):
    """Residual GraphX block: a main branch (FC -> GraphXConv) plus a shortcut.

    :param input_shape: input tensor shape; the last dim is the feature size.
    :param out_features: number of output features.
    :param num_instances: number of output points; ``None`` keeps the input count.
    :param bias: whether the linear layers use bias terms.
    :param activation: activation used inside the main branch (the final op of
        each branch is linear; the combined activation is applied elsewhere).
    """
    super().__init__(input_shape)
    self.out_features = out_features
    self.activation = nnt.utils.function(activation)

    # Main branch: FC followed by a GraphX convolution (no activation on the last op).
    self.main = nnt.Sequential(input_shape=input_shape)
    self.main.add_module(
        'conv_r_1',
        nnt.FC(self.main.output_shape, out_features, bias=bias, activation=activation))
    self.main.add_module(
        'convr_2',
        GraphXConv(self.main.output_shape, out_features, num_instances, bias, None))

    # Shortcut branch: identity when the shape is already right, else a projection.
    if num_instances is None:
        self.res = (lambda x: x) if (out_features == input_shape[-1]) \
            else nnt.FC(input_shape, out_features, bias=bias, activation=None)
    else:
        # BUG FIX: forward `bias` to the shortcut as well. Previously it was not
        # passed, so the shortcut always used GraphXConv's default (bias=True)
        # even when the caller requested bias=False.
        self.res = GraphXConv(input_shape, out_features, num_instances,
                              bias=bias, activation=None)
def __init__(self, input_shape, out_features, num_instances=None, bias=True, activation=None,
             weights_init=None, bias_init=None):
    """GraphX convolution: an FC across the point axis (mixing points), then an FC
    across the feature axis.

    :param input_shape: input tensor shape; axis 1 is the point/instance axis,
        the last axis is the feature axis.
    :param out_features: number of output features.
    :param num_instances: number of output points; ``None``/falsy keeps
        ``input_shape[1]``.
    :param bias: whether the FC layers use bias terms.
    :param activation: activation applied by the final feature FC.
    :param weights_init: optional weight initializer for ``conv_l``.
    :param bias_init: optional bias initializer for ``conv_l``.
    """
    super().__init__(input_shape=input_shape)
    self.out_features = out_features
    self.num_instances = num_instances if num_instances else input_shape[1]
    self.activation = activation

    # Permutation swapping the last two axes (points <-> features).
    pattern = list(range(len(input_shape)))
    pattern[-1], pattern[-2] = pattern[-2], pattern[-1]

    self.add_module('dimshuffle1', nnt.DimShuffle(pattern, input_shape=self.output_shape))
    # BUG FIX: use the resolved `self.num_instances` here. Previously the raw
    # `num_instances` argument was passed, which is None when defaulted and
    # would hand nnt.FC an invalid output size.
    self.add_module(
        'conv_l',
        nnt.FC(self.output_shape, self.num_instances, bias=bias, activation=None,
               weights_init=weights_init, bias_init=bias_init))
    self.add_module('dimshuffle2', nnt.DimShuffle(pattern, input_shape=self.output_shape))
    self.add_module(
        'conv_r',
        nnt.FC(self.output_shape, out_features, bias=bias, activation=activation))
def test_fc_layer(device, shape):
    """FC must accept both an int and a tuple as input shape and agree with nn.Linear."""
    out_features = 4

    # Both constructor forms must be accepted; only the tuple-built layer is checked.
    nnt.FC(shape, out_features)
    fc_nnt = nnt.FC((2, 3), out_features)

    in_features = shape[1] if isinstance(shape, tuple) else shape
    fc_pt = T.nn.Linear(in_features, out_features)
    sanity_check(fc_nnt, fc_pt, shape=(2, 3), device=device)
def __init__(self, input_shape, out_features, bias=True, activation=None):
    """Main branch of a residual unit: two stacked FC layers, the second linear.

    :param input_shape: input tensor shape.
    :param out_features: number of output features for both FC layers.
    :param bias: whether the FC layers use bias terms.
    :param activation: activation applied after the first FC only.
    """
    super().__init__(input_shape, out_features, bias=bias, activation=activation)
    self.main = nnt.Sequential(input_shape=input_shape)
    # CONSISTENCY FIX: forward `bias` to the submodules. Previously it was only
    # passed to super().__init__ and both FCs always used their default bias,
    # unlike the sibling residual block which forwards bias=bias.
    self.main.add_module(
        'fc1', nnt.FC(self.main.output_shape, out_features, bias=bias, activation=activation))
    self.main.add_module(
        'fc2', nnt.FC(self.main.output_shape, out_features, bias=bias, activation=None))
def __init__(self, input_shape, out_features, num_instances=None, rank=None, bias=True,
             activation=None, weights_init=None, bias_init=None):
    """Low-rank GraphX convolution: the point-mixing FC is factored into two FCs
    of rank ``rank``, followed by a feature FC.

    :param input_shape: input tensor shape; axis 1 is the point/instance axis,
        the last axis is the feature axis.
    :param out_features: number of output features.
    :param num_instances: number of output points; ``None``/falsy keeps
        ``input_shape[1]``.
    :param rank: rank of the factorization; defaults to ``num_instances // 2``.
        Must be strictly smaller than ``num_instances``.
    :param bias: whether the FC layers use bias terms (``conv_l1`` never does).
    :param activation: activation applied by the final feature FC.
    :param weights_init: optional weight initializer for the point-mixing FCs.
    :param bias_init: optional bias initializer for ``conv_l2``.
    :raises ValueError: if ``rank`` is not smaller than ``num_instances``.
    """
    super().__init__(input_shape=input_shape)
    self.out_features = out_features
    self.num_instances = num_instances if num_instances else input_shape[1]
    self.rank = rank if rank is not None else self.num_instances // 2
    # Validate with an explicit exception: `assert` is stripped under `python -O`.
    if self.rank >= self.num_instances:
        raise ValueError('rank should be smaller than num_instances')
    self.activation = activation

    # Permutation swapping the last two axes (points <-> features).
    pattern = list(range(len(input_shape)))
    pattern[-1], pattern[-2] = pattern[-2], pattern[-1]

    self.add_module('dimshuffle1', nnt.DimShuffle(pattern, input_shape=self.output_shape))
    # Low-rank factorization of the point-mixing map: project down to `rank`
    # (no bias) and back up to `num_instances`.
    self.add_module(
        'conv_l1',
        nnt.FC(self.output_shape, self.rank, bias=False, activation=None,
               weights_init=weights_init))
    self.add_module(
        'conv_l2',
        nnt.FC(self.output_shape, self.num_instances, bias=bias, activation=None,
               weights_init=weights_init, bias_init=bias_init))
    self.add_module('dimshuffle2', nnt.DimShuffle(pattern, input_shape=self.output_shape))
    self.add_module(
        'conv_r',
        nnt.FC(self.output_shape, out_features, bias=bias, activation=activation))
def test_spectral_norm(device):
    """nnt.spectral_norm must match torch.nn.utils.spectral_norm on the
    weight-bearing layers and leave normalization layers untouched."""
    from copy import deepcopy
    import torch.nn as nn

    seed = 48931
    x = T.rand(10, 3, 5, 5).to(device)
    net = nnt.Sequential(
        nnt.Sequential(nnt.Conv2d(3, 16, 3), nnt.Conv2d(16, 32, 3)),
        nnt.Sequential(
            nnt.Conv2d(32, 64, 3),
            nnt.Conv2d(64, 128, 3),
        ),
        nnt.BatchNorm2d(128),
        nnt.GroupNorm(128, 4),
        nnt.LayerNorm((None, 128, 5, 5)),
        nnt.GlobalAvgPool2D(),
        nnt.FC(128, 1)).to(device)

    conv_positions = ((0, 0), (0, 1), (1, 0), (1, 1))

    # Reference: wrap the conv layers and the FC with PyTorch's spectral_norm.
    net_pt_sn = deepcopy(net)
    T.manual_seed(seed)
    if cuda_available:
        T.cuda.manual_seed_all(seed)
    for i, j in conv_positions:
        net_pt_sn[i][j] = nn.utils.spectral_norm(net_pt_sn[i][j])
    net_pt_sn[6] = nn.utils.spectral_norm(net_pt_sn[6])

    # Re-seed so both implementations draw identical power-iteration vectors.
    T.manual_seed(seed)
    if cuda_available:
        T.cuda.manual_seed_all(seed)
    net_nnt_sn = nnt.spectral_norm(net)

    net_pt_sn(x)
    net_nnt_sn(x)

    # Normalization layers (indices 2-4) must not have been wrapped.
    for idx in (2, 3, 4):
        assert not hasattr(net_nnt_sn[idx], 'weight_u')

    for i, j in conv_positions:
        testing.assert_allclose(net_pt_sn[i][j].weight, net_nnt_sn[i][j].weight)
    testing.assert_allclose(net_pt_sn[6].weight, net_nnt_sn[6].weight)