# Imports required by the classes below.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Module, Parameter, init

from scipy.stats import truncnorm


class WN_Conv2d(nn.Conv2d):
    """Conv2d with weight normalization and optional data-dependent init."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 train_scale=False, init_stdv=1.0):
        super(WN_Conv2d, self).__init__(in_channels, out_channels, kernel_size,
                                        stride, padding, dilation, groups, bias)
        if train_scale:
            self.weight_scale = Parameter(torch.Tensor(out_channels))
        else:
            self.register_buffer('weight_scale', torch.Tensor(out_channels))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv

        self._reset_parameters()

    def _reset_parameters(self):
        self.weight.data.normal_(std=0.05)
        if self.bias is not None:
            self.bias.data.zero_()
        if self.train_scale:
            self.weight_scale.data.fill_(1.)
        else:
            self.weight_scale.fill_(1.)

    def forward(self, input):
        if self.train_scale:
            weight_scale = self.weight_scale
        else:
            weight_scale = Variable(self.weight_scale)

        # Normalize the weight matrix [out x in x h x w]: for each output
        # channel, divide by the L2 norm over the (in, h, w) = (1, 2, 3) dims,
        # then apply the per-channel scale.
        norm_weight = self.weight * (
            weight_scale[:, None, None, None] /
            torch.sqrt((self.weight ** 2).sum(3, keepdim=True)
                       .sum(2, keepdim=True)
                       .sum(1, keepdim=True) + 1e-6)).expand_as(self.weight)
        activation = F.conv2d(input, norm_weight, bias=None,
                              stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups)

        if self.init_mode:
            # Data-dependent initialization: shift and rescale so the
            # pre-activations have zero mean and stdv init_stdv per channel,
            # taken over the batch and spatial dims.
            mean_act = activation.mean(3).mean(2).mean(0).squeeze()
            activation = activation - mean_act[None, :, None, None].expand_as(activation)

            inv_stdv = self.init_stdv / torch.sqrt(
                (activation ** 2).mean(3, keepdim=True)
                .mean(2, keepdim=True)
                .mean(0, keepdim=True) + 1e-6).squeeze()
            activation = activation * inv_stdv[None, :, None, None].expand_as(activation)

            # Fold the calibration into the persistent scale and bias.
            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.data
            else:
                self.weight_scale = self.weight_scale * inv_stdv.data
            if self.bias is not None:
                self.bias.data = -mean_act.data * inv_stdv.data
        else:
            if self.bias is not None:
                activation = activation + self.bias[None, :, None, None].expand_as(activation)

        return activation
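
# Usage sketch (illustrative, not from the original source): run one forward
# pass with init_mode enabled to perform the data-dependent initialization,
# then disable it and train as usual. Layer sizes and batch shape are
# assumptions chosen for the demo.
def _demo_wn_conv2d():
    layer = WN_Conv2d(3, 16, kernel_size=3, padding=1, train_scale=True)
    x = torch.randn(8, 3, 32, 32)
    layer.init_mode = True   # first batch calibrates weight_scale and bias
    layer(x)
    layer.init_mode = False  # later batches reuse the calibrated scales
    out = layer(x)
    # Per-channel activations should now have roughly zero mean, unit stdv.
    print(out.mean().item(), out.std().item())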


class NoisyLinear(nn.Module):
    """Linear layer with independent Gaussian weight noise (NoisyNet-style):
    weight = mu + sigma * epsilon."""

    def __init__(self, in_features, out_features, bias=True):
        super(NoisyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Plain tensors holding the sampled weights and the noise draws.
        self.weight = torch.Tensor(out_features, in_features)
        self.weight_epsilon = torch.Tensor(out_features, in_features)
        # Learnable mean and scale of the weight distribution.
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_sigma = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = torch.Tensor(out_features)
            self.bias_epsilon = torch.Tensor(out_features)
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_sigma = Parameter(torch.Tensor(out_features))
        else:
            self.bias = None
            self.bias_epsilon = None
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_sigma', None)
        self.reset_parameters()
        self.sampled = False

    def sample(self):
        if self.training:
            # Draw fresh noise and form weight = mu + sigma * epsilon.
            self.weight_epsilon.normal_()
            self.weight = self.weight_epsilon.mul(self.weight_sigma).add_(self.weight_mu)
            if self.bias is not None:
                self.bias_epsilon.normal_()
                self.bias = self.bias_epsilon.mul(self.bias_sigma).add_(self.bias_mu)
        else:
            # At evaluation time, fall back to the noise-free means.
            self.weight = self.weight_mu.detach()
            if self.bias is not None:
                self.bias = self.bias_mu.detach()
        self.sampled = True

    def reset_parameters(self):
        stdv = math.sqrt(3.0 / self.weight.size(1))
        with torch.no_grad():
            self.weight_mu.uniform_(-stdv, stdv)
            self.weight_sigma.fill_(0.017)
            if self.bias is not None:
                self.bias_mu.uniform_(-stdv, stdv)
                self.bias_sigma.fill_(0.017)

    def forward(self, input):
        if not self.sampled:
            self.sample()
        return F.linear(input, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
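
# Usage sketch (illustrative, not from the original source): the first forward
# call in training mode draws the noise; call sample() explicitly to redraw it
# between episodes, and switch to eval() for the noise-free means.
def _demo_noisy_linear():
    layer = NoisyLinear(4, 2)
    x = torch.randn(3, 4)
    y_noisy = layer(x)   # samples weight = mu + sigma * eps on first use
    layer.sample()       # redraw noise, e.g. at the start of a new episode
    layer.eval()
    layer.sample()       # eval mode: weight/bias fall back to mu
    y_mean = layer(x)
    print(y_noisy.shape, y_mean.shape)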


class WN_Linear(nn.Linear):
    """Linear layer with weight normalization and optional data-dependent init."""

    def __init__(self, in_features, out_features, bias=True,
                 train_scale=False, init_stdv=1.0):
        super(WN_Linear, self).__init__(in_features, out_features, bias=bias)
        if train_scale:
            self.weight_scale = Parameter(torch.ones(self.out_features))
        else:
            self.register_buffer('weight_scale', torch.Tensor(out_features))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv

        self._reset_parameters()

    def _reset_parameters(self):
        self.weight.data.normal_(0, std=0.05)
        if self.bias is not None:
            self.bias.data.zero_()
        if self.train_scale:
            self.weight_scale.data.fill_(1.)
        else:
            self.weight_scale.fill_(1.)

    def forward(self, input):
        if self.train_scale:
            weight_scale = self.weight_scale
        else:
            weight_scale = Variable(self.weight_scale)

        # Normalize the weight matrix row-wise, then apply the linear projection.
        norm_weight = self.weight * (
            weight_scale.unsqueeze(1) /
            torch.sqrt((self.weight ** 2).sum(1, keepdim=True) + 1e-6)).expand_as(self.weight)
        activation = F.linear(input, norm_weight)

        if self.init_mode:
            # Data-dependent initialization: shift and rescale so the
            # pre-activations have zero mean and stdv init_stdv over the batch.
            mean_act = activation.mean(0).squeeze(0)
            activation = activation - mean_act.expand_as(activation)

            inv_stdv = self.init_stdv / torch.sqrt(
                (activation ** 2).mean(0, keepdim=True) + 1e-6).squeeze(0)
            activation = activation * inv_stdv.expand_as(activation)

            # Fold the calibration into the persistent scale and bias.
            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.data
            else:
                self.weight_scale = self.weight_scale * inv_stdv.data
            if self.bias is not None:
                self.bias.data = -mean_act.data * inv_stdv.data
        else:
            if self.bias is not None:
                activation = activation + self.bias.expand_as(activation)

        return activation
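
# Usage sketch (illustrative, not from the original source): the same
# init-then-train protocol as WN_Conv2d, on a toy batch with assumed sizes.
def _demo_wn_linear():
    layer = WN_Linear(10, 5, train_scale=True)
    x = torch.randn(16, 10)
    layer.init_mode = True   # one calibration pass sets weight_scale and bias
    layer(x)
    layer.init_mode = False
    out = layer(x)
    print(out.mean().item(), out.std().item())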


class TestModule(Module):

    def __init__(self, input_num):
        super(TestModule, self).__init__()
        self.param = Parameter(torch.Tensor(1, input_num))
        self.reset_parameters()

    def reset_parameters(self):
        # fake data
        with torch.no_grad():
            self.param.fill_(0)
            self.param[0, 0] = 0.25111
            self.param[0, 1] = 0.5

    def forward(self, input):
        return F.linear(input, self.param, None)


class Embedding(nn.Module):
    """Embedding table with truncated-normal or zero initialization and an
    optional padding row that is kept at zero."""

    def __init__(self, n_tokens, latent_dim, padding_idx=None, init='truncnorm'):
        super(Embedding, self).__init__()
        self.n_tokens = n_tokens
        self.latent_dim = latent_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.n_tokens, \
                    'padding_idx must be within n_tokens'
            elif padding_idx < 0:
                assert padding_idx >= -self.n_tokens, \
                    'padding_idx must be within n_tokens'
        self.padding_idx = padding_idx
        self.init = init
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            if self.init == 'truncnorm':
                # Truncated normal on [-t, t] with t = 1 / sqrt(n_tokens).
                t = 1. / (self.n_tokens ** (1 / 2))
                weights = truncnorm.rvs(-t, t,
                                        size=[self.n_tokens, self.latent_dim])
                self.weights = Parameter(torch.tensor(weights).float())
            elif self.init == 'zeros':
                self.weights = Parameter(
                    torch.Tensor(self.n_tokens, self.latent_dim))
                self.weights.fill_(0.0)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weights[self.padding_idx].zero_()

    def forward(self, x):
        x = F.embedding(x, self.weights, padding_idx=self.padding_idx)
        return x
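
# Usage sketch (illustrative, not from the original source): a truncnorm-
# initialized table whose padding row stays zero; token ids are made up.
def _demo_embedding():
    emb = Embedding(n_tokens=100, latent_dim=8, padding_idx=0)
    tokens = torch.tensor([[0, 5, 7], [2, 0, 99]])
    vectors = emb(tokens)  # shape (2, 3, 8)
    print(vectors.shape, emb.weights[0].abs().sum().item())  # padding row is 0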


class _GBN(nn.Module):
    """Ghost batch normalization base class: the macro-batch is split into
    opt.micro_in_macro micro-batches, each normalized with its own statistics."""

    __constants__ = [
        'track_running_stats', 'momentum', 'eps', 'weight', 'bias',
        'running_mean', 'running_var', 'num_batches_tracked',
        'num_features', 'affine'
    ]

    def __init__(self, opt, num_features, eps=1e-5, momentum=0.01,
                 affine=True, track_running_stats=True):
        super(_GBN, self).__init__()
        self.opt = opt
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            # One affine transform per micro-batch group.
            self.weight = Parameter(
                torch.Tensor(self.opt.micro_in_macro, 1, num_features, 1, 1))
            self.bias = Parameter(
                torch.Tensor(self.opt.micro_in_macro, 1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.running_mean = Parameter(
                torch.Tensor(self.opt.micro_in_macro, 1, num_features, 1, 1),
                requires_grad=False)
            self.running_var = Parameter(
                torch.Tensor(self.opt.micro_in_macro, 1, num_features, 1, 1),
                requires_grad=False)
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def forward(self, input):
        self._check_input_dim(input)
        return self.g_b_n(input, self.running_mean, self.running_var,
                          self.weight, self.bias)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super(_GBN, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def g_b_n(self, input, running_mean, running_var, weight, beta):
        N, C, H, W = input.size()
        G = self.opt.micro_in_macro
        # Split the macro-batch into G micro-batches and compute per-group
        # statistics over the (batch, height, width) dims.
        input = input.view(G, N // G, C, H, W)
        mean = torch.mean(input, (1, 3, 4), keepdim=True)
        var = torch.var(input, (1, 3, 4), keepdim=True)
        if self.training:
            # Update running stats with detached values so no graph is kept.
            running_mean.data = running_mean.data * (1 - self.momentum) \
                + mean.detach() * self.momentum
            running_var.data = running_var.data * (1 - self.momentum) \
                + var.detach() * self.momentum
            X_hat = (input - mean) / torch.sqrt(var + self.eps)
        else:
            X_hat = (input - running_mean) / torch.sqrt(running_var + self.eps)
        if self.affine:
            X_hat = X_hat * weight + beta
        output = X_hat.view(N, C, H, W)
        return output
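
# Usage sketch (illustrative, not from the original source): _GBN is abstract,
# so a concrete 2d subclass must supply _check_input_dim. `opt` only needs a
# micro_in_macro attribute here; the SimpleNamespace stand-in is an assumption.
class _GBN2d(_GBN):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input, got {}D'.format(input.dim()))


def _demo_gbn():
    from types import SimpleNamespace
    opt = SimpleNamespace(micro_in_macro=4)
    gbn = _GBN2d(opt, num_features=8)
    x = torch.randn(16, 8, 7, 7)  # macro-batch of 16 -> 4 micro-batches of 4
    out = gbn(x)
    print(out.shape)              # torch.Size([16, 8, 7, 7])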