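# Assumed imports for the snippets below (not shown in the original source).
# `PowerSpherical` matches the `power_spherical` package
# (nicola-decao/power_spherical); `softplus_inv`, `ml_kappa`, and
# `KL_Powerspherical` are project-local helpers referenced but not defined here.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import LogNormal, Normal, kl_divergence
from torch.nn.functional import softplus
from torch.nn.init import _calculate_fan_in_and_fan_out

from power_spherical import PowerSpherical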
def forward(self, input, sample=False):
    """Factorized weight: PowerSpherical direction times LogNormal radius.

    `sample=True` draws reparameterized samples; otherwise the
    distribution means give a deterministic forward pass.
    """
    # self.dir_loc.data /= torch.sum(self.dir_loc.data ** 2, dim=-1, keepdim=True) ** 0.5
    # direction_sample = self.dir_rsampler(1, sample)[0]
    if sample:
        direction_sample = PowerSpherical(
            self.dir_loc,
            softplus(self.dir_softplus_inv_concentration)).rsample()
        radius_sample = LogNormal(self.rad_mu, softplus(self.rad_rho)).rsample()
    else:
        direction_sample = PowerSpherical(
            self.dir_loc,
            softplus(self.dir_softplus_inv_concentration)).mean
        radius_sample = LogNormal(self.rad_mu, softplus(self.rad_rho)).mean
    weight = direction_sample * radius_sample  # .unsqueeze(-1)
    return F.linear(input, weight, self.bias)
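# Hypothetical usage of the forward above; the class name and sizes are
# assumptions, not part of the original source.
#
#   layer = VariationalLinear(784, 256)
#   x = torch.randn(32, 784)
#   y_train = layer(x, sample=True)    # reparameterized draw of direction and radius
#   y_eval = layer(x, sample=False)    # deterministic pass using distribution means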
def forward(self, noise):
    x = self.nonlinear(self.fc1(noise))
    # x = self.nonlinear(self.fc2(x))
    loc = self.fc_loc(x)
    loc = loc / loc.norm(dim=-1, keepdim=True)
    # The `+ 1` keeps the concentration bounded away from zero and
    # prevents the distribution from collapsing.
    concentration = softplus(self.fc_concentration(x).squeeze()) + 1
    sample = PowerSpherical(loc, concentration).rsample(torch.Size([1]))
    return sample
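# A minimal sketch of how this generator might be driven; the class name
# and layer sizes are assumptions, not part of the original source.
#
#   gen = DirectionGenerator(noise_dim=64, out_dim=128)
#   noise = torch.randn(8, 64)
#   dirs = gen(noise)            # shape (1, 8, 128); rsample adds a leading dim
#   print(dirs.norm(dim=-1))     # ~1.0: samples lie on the unit sphere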
def forward(self, input, sample=False):
    if sample:
        direction_sample = PowerSpherical(
            self.dir_loc,
            softplus(self.dir_softplus_inv_concentration)).rsample()
    else:
        direction_sample = PowerSpherical(
            self.dir_loc,
            softplus(self.dir_softplus_inv_concentration)).mean
    # Input-dependent radius: one scale per (batch, out_feature) pair,
    # so every sample gets its own effective weight matrix.
    radius = self.gate(self.rad_layer(input))
    weight = direction_sample.unsqueeze(0) * radius.unsqueeze(-1)
    output = (input.unsqueeze(1) * weight).sum(-1)
    if self.bias is not None:
        output = output + self.bias
    return output
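# The broadcasted product above is equivalent to a batched matrix-vector
# product; a small shape check under assumed sizes:
def _shape_check():
    B, in_f, out_f = 4, 16, 32
    inp = torch.randn(B, in_f)
    direction = torch.randn(out_f, in_f)              # one direction per output unit
    radius = torch.randn(B, out_f)                    # input-dependent gate output
    w = direction.unsqueeze(0) * radius.unsqueeze(-1)  # (B, out_f, in_f)
    out = (inp.unsqueeze(1) * w).sum(-1)               # (B, out_f)
    ref = torch.einsum('boi,bi->bo', w, inp)
    assert torch.allclose(out, ref, atol=1e-6)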
class NewLinear(nn.Module):
    """Linear layer whose weight is factorized into a PowerSpherical
    direction and an input-dependent LogNormal radius."""

    def __init__(self, in_features, out_features, bias=True, noise_shape=1):
        super(NewLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias

        # Direction: one PowerSpherical location per output unit.
        self.dir_concentration = nn.Parameter(torch.Tensor(out_features))
        self.dir_loc = nn.Parameter(torch.Tensor(out_features, in_features))
        nn.init.kaiming_normal_(self.dir_loc)
        nn.init.normal_(self.dir_concentration, out_features * 10, 1)

        # Radius: amortized LogNormal parameters predicted from the input.
        self.rad_mu = nn.Linear(in_features, in_features)
        self.rad_scale = nn.Linear(in_features, in_features)
        self.rad_mu1 = nn.Linear(in_features, in_features)
        self.rad_scale1 = nn.Linear(in_features, in_features)
        self.embed1 = nn.Linear(in_features, in_features)
        self.embed2 = nn.Linear(in_features, in_features)
        self.nonlinear = nn.ReLU()

        # Alternative (non-amortized) radius parameterization:
        # self.rad_mu = nn.Parameter(torch.Tensor(in_features))
        # self.rad_scale = nn.Parameter(torch.Tensor(in_features))
        # self.rad_mu1 = nn.Parameter(torch.Tensor(in_features))
        # self.rad_scale1 = nn.Parameter(torch.Tensor(in_features))
        # nn.init.normal_(self.rad_mu, math.log(2.0), 0.0001)
        # nn.init.normal_(self.rad_scale, softplus_inv(0.0001), 0.0001)
        # nn.init.normal_(self.rad_mu1, math.log(2.0), 0.0001)
        # nn.init.normal_(self.rad_scale1, softplus_inv(0.0001), 0.0001)

        self.bias_mu = nn.Parameter(torch.Tensor(out_features))
        self.bias_scale = nn.Parameter(torch.Tensor(out_features))
        nn.init.normal_(self.bias_mu, 0.0, 0.0001)
        nn.init.normal_(self.bias_scale, softplus_inv(0.0001), 0.0001)

    def forward(self, x):
        concentration = softplus(self.dir_concentration)
        loc = self.dir_loc / self.dir_loc.norm(dim=-1, keepdim=True)
        self.dir_sampler = PowerSpherical(loc, concentration)

        e = self.nonlinear(self.embed1(x))
        e = self.nonlinear(self.embed2(e))
        self.rad_sampler = LogNormal(self.rad_mu(e), softplus(self.rad_scale(e)))
        self.rad_sampler1 = LogNormal(self.rad_mu1(e), softplus(self.rad_scale1(e)))
        self.bias_sampler = Normal(self.bias_mu, softplus(self.bias_scale))

        direction_sample = self.dir_sampler.rsample()
        # Radius is the fourth root of the product of two LogNormal draws.
        radius_sample = self.rad_sampler.rsample()
        radius_sample = (radius_sample * self.rad_sampler1.rsample()) ** 0.5
        radius_sample = radius_sample ** 0.5
        # radius_sample = LogNormal(self.rad_mu, softplus(self.rad_scale)).rsample()
        # radius_sample = (radius_sample * LogNormal(self.rad_mu1, softplus(self.rad_scale1)).rsample()) ** 0.5
        bias = self.bias_sampler.rsample() if self.bias else None
        # weight = direction_sample * radius_sample.unsqueeze(0) ** 0.5
        # The radius scales the input; the weight is the unit direction.
        weight = direction_sample
        output = F.linear(x * radius_sample, weight, bias)
        return output

    def kl_divergence(self):
        pass
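# `softplus_inv` is used in the initializers above but is not defined in this
# snippet; a minimal sketch of the usual inverse softplus (assuming scalar
# float inputs, as used here):
def softplus_inv(y):
    """Inverse of softplus for y > 0: softplus(softplus_inv(y)) == y."""
    # log(exp(y) - 1), rewritten as y + log(1 - exp(-y)) for stability.
    return y + math.log(-math.expm1(-y))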
def custom_regularization(self, saver_net, trainer_net, mini_batch_size, loss=None):
    dir_loc_reg_sum = mu_bias_reg_sum = rad_mu_reg_sum = 0
    L1_rad_mu_reg_sum = L1_mu_bias_reg_sum = 0
    rad_sigma_reg_sum = rad_sigma_normal_reg_sum = 0
    kld_dir_sum = 0
    out_features_max = 512
    alpha = self.alpha
    if self.saved:
        alpha = 1

    if 'conv' in self.model_name:
        if self.data_name == 'omniglot':
            prev_weight_strength = nn.Parameter(
                torch.Tensor(1, 1, 1, 1).uniform_(0, 0)).cuda()
        elif self.data_name == 'cifa':
            prev_weight_strength = nn.Parameter(
                torch.Tensor(3, 1, 1, 1).uniform_(0, 0)).cuda()
    else:
        prev_weight_strength = nn.Parameter(
            torch.Tensor(28 * 28, 1).uniform_(0, 0)).cuda()

    for (saver_name, saver_layer), (trainer_name, trainer_layer) in zip(
            saver_net.items(), trainer_net.items()):
        # calculate mu regularization
        trainer_dir_loc = trainer_layer['dir_loc']
        trainer_dir_concentration = F.softplus(
            trainer_layer['dir_softplus_inv_concentration'])
        # trainer_dir_loc = trainer_layer['dir_rsampler.loc']
        # trainer_dir_concentration = F.softplus(trainer_layer['dir_rsampler.softplus_inv_concentration'])
        trainer_rad_mu = trainer_layer['rad_mu']
        trainer_rad_sigma = F.softplus(trainer_layer['rad_rho'])
        trainer_bias = trainer_layer['bias']

        saver_dir_loc = saver_layer['dir_loc']
        saver_dir_concentration = F.softplus(
            saver_layer['dir_softplus_inv_concentration'])
        # saver_dir_loc = saver_layer['dir_rsampler.loc']
        # saver_dir_concentration = F.softplus(saver_layer['dir_rsampler.softplus_inv_concentration'])
        saver_rad_mu = saver_layer['rad_mu']
        saver_rad_sigma = F.softplus(saver_layer['rad_rho'])
        saver_bias = saver_layer['bias']

        fan_in, fan_out = _calculate_fan_in_and_fan_out(trainer_dir_loc)
        concentration_init = ml_kappa(dim=fan_in, eps=self.model.eps)
        if 'fc' in trainer_name:
            std_init = math.sqrt((2 / fan_in) * self.model.ratio)
        if 'conv' in trainer_name:
            std_init = math.sqrt((2 / fan_out) * self.model.ratio)

        saver_weight_strength = std_init / saver_rad_sigma
        if len(saver_dir_loc.shape) == 4:
            out_features, in_features, _, _ = saver_dir_loc.shape
            curr_strength = saver_weight_strength.expand(
                out_features, in_features, 1, 1)
            prev_strength = prev_weight_strength.permute(
                1, 0, 2, 3).expand(out_features, in_features, 1, 1)
        else:
            out_features, in_features = saver_dir_loc.shape
            curr_strength = saver_weight_strength.expand(
                out_features, in_features)
            if len(prev_weight_strength.shape) == 4:
                feature_size = in_features // prev_weight_strength.shape[0]
                prev_weight_strength = prev_weight_strength.reshape(
                    prev_weight_strength.shape[0], -1)
                prev_weight_strength = prev_weight_strength.expand(
                    prev_weight_strength.shape[0], feature_size)
                prev_weight_strength = prev_weight_strength.reshape(-1, 1)
            prev_strength = prev_weight_strength.permute(1, 0).expand(
                out_features, in_features)

        L2_strength = torch.max(curr_strength, prev_strength)  # (4)
        # L2_strength = (1.0 / saver_weight_sigma)  # (3a)
        bias_strength = torch.squeeze(saver_weight_strength)
        rad_mu_strength = torch.squeeze(saver_weight_strength)
        L1_sigma = saver_rad_sigma
        bias_sigma = torch.squeeze(saver_rad_sigma)
        prev_weight_strength = saver_weight_strength

        dir_loc_reg = (L2_strength * (trainer_dir_loc - saver_dir_loc)).norm(2) ** 2
        mu_bias_reg = (bias_strength * (trainer_bias - saver_bias)).norm(2) ** 2
        rad_mu_reg = (rad_mu_strength * (trainer_rad_mu - saver_rad_mu)).norm(2) ** 2  # (5)

        L1_rad_mu_reg = (torch.div(saver_rad_mu ** 2, L1_sigma ** 2)
                         * (trainer_rad_mu - saver_rad_mu)).norm(1)
        L1_mu_bias_reg = (torch.div(saver_bias ** 2, bias_sigma ** 2)
                          * (trainer_bias - saver_bias)).norm(1)
        L1_rad_mu_reg = L1_rad_mu_reg * (std_init ** 2)
        L1_mu_bias_reg = L1_mu_bias_reg * (std_init ** 2)

        rad_sigma = trainer_rad_sigma ** 2 / saver_rad_sigma ** 2
        normal_rad_sigma = trainer_rad_sigma ** 2
        rad_sigma_reg_sum = rad_sigma_reg_sum + (
            rad_sigma - torch.log(rad_sigma)).sum()  # (3b)
        # rad_sigma_normal_reg_sum = rad_sigma_normal_reg_sum + (normal_rad_sigma - torch.log(normal_rad_sigma)).sum()  # (6)
        # dir_loc_reg_sum = dir_loc_reg_sum + dir_loc_reg
        mu_bias_reg_sum = mu_bias_reg_sum + mu_bias_reg
        rad_mu_reg_sum = rad_mu_reg_sum + rad_mu_reg
        L1_rad_mu_reg_sum = L1_rad_mu_reg_sum + L1_rad_mu_reg
        L1_mu_bias_reg_sum = L1_mu_bias_reg_sum + L1_mu_bias_reg

        # Direction KL between current and saved posteriors, accumulated
        # per layer inside the loop.
        q_dist = PowerSpherical(trainer_dir_loc, trainer_dir_concentration)
        p_dist = PowerSpherical(saver_dir_loc, saver_dir_concentration)
        # reg_strength = L2_strength if self.saved else 1
        # kld_dir = KL_vMF_kappa_full(trainer_dir_loc, trainer_dir_concentration, saver_dir_loc, saver_dir_concentration, 1)
        kld_dir_sum = kld_dir_sum + KL_Powerspherical(q_dist, p_dist).sum()

    # elbo loss
    loss = loss / mini_batch_size
    # L2 loss
    loss = loss + alpha * (mu_bias_reg_sum + rad_mu_reg_sum) / (2 * mini_batch_size)
    # loss = loss + self.saved * dir_loc_reg_sum / mini_batch_size
    # L1 loss
    loss = loss + self.saved * (L1_rad_mu_reg_sum + L1_mu_bias_reg_sum) / mini_batch_size
    # sigma regularization
    loss = loss + alpha * rad_sigma_reg_sum / mini_batch_size
    # direction KL
    loss = loss + alpha * kld_dir_sum / mini_batch_size
    return loss
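# For reference, the (rad_sigma - log(rad_sigma)) term in the sigma
# regularization above is, up to affine constants, the variance part of the
# Gaussian KL between the current and saved radius posteriors:
#
#   KL(N(mu, s_q^2) || N(mu, s_p^2))
#       = 1/2 * (s_q^2 / s_p^2 - log(s_q^2 / s_p^2) - 1)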
def kl_divergence(self, saver_net, trainer_net):
    kld = 0
    prev_weight_strength = nn.Parameter(
        torch.Tensor(28 * 28, 1).uniform_(0, 0)).cuda()
    alpha = self.alpha
    if self.saved:
        alpha = 1

    for (saver_name, saver_layer), (trainer_name, trainer_layer) in zip(
            saver_net.items(), trainer_net.items()):
        trainer_dir_loc = trainer_layer['dir_loc']
        trainer_dir_concentration = F.softplus(
            trainer_layer['dir_softplus_inv_concentration'])
        trainer_rad_mu = trainer_layer['rad_mu']
        trainer_rad_sigma = F.softplus(trainer_layer['rad_rho'])
        trainer_bias = trainer_layer['bias']

        saver_dir_loc = saver_layer['dir_loc']
        saver_dir_concentration = F.softplus(
            saver_layer['dir_softplus_inv_concentration'])
        saver_rad_mu = saver_layer['rad_mu']
        saver_rad_sigma = F.softplus(saver_layer['rad_rho'])
        saver_bias = saver_layer['bias']

        fan_in, fan_out = _calculate_fan_in_and_fan_out(trainer_dir_loc)
        concentration_init = ml_kappa(dim=fan_in, eps=self.model.eps)
        if 'fc' in trainer_name:
            std_init = math.sqrt((2 / fan_in) * self.model.ratio)
        if 'conv' in trainer_name:
            std_init = math.sqrt((2 / fan_out) * self.model.ratio)

        out_features, in_features = saver_dir_loc.shape
        saver_weight_strength = std_init / saver_rad_sigma
        curr_strength = saver_weight_strength.expand(out_features, in_features)
        prev_strength = prev_weight_strength.permute(1, 0).expand(
            out_features, in_features)
        L2_strength = torch.max(curr_strength, prev_strength)
        prev_weight_strength = saver_weight_strength

        # Strength-weighted cosine similarity between current and saved directions.
        dir_loc_reg = (
            (L2_strength * trainer_dir_loc * saver_dir_loc)
            / (trainer_dir_loc.norm(2, dim=-1)
               * saver_dir_loc.norm(2, dim=-1)).unsqueeze(-1)).sum()

        q_dir = PowerSpherical(trainer_dir_loc, trainer_dir_concentration)
        p_dir = PowerSpherical(saver_dir_loc, saver_dir_concentration)
        kld_dir = KL_Powerspherical(q_dir, p_dir)

        # This is torch.distributions.kl_divergence (module-level import),
        # not this method; for two LogNormals it equals the Normal KL of
        # the underlying (mu, sigma) parameters.
        q_rad = LogNormal(trainer_rad_mu, trainer_rad_sigma)
        p_rad = LogNormal(saver_rad_mu, saver_rad_sigma)
        kld_rad = kl_divergence(q_rad, p_rad)

        mu_bias_reg = ((trainer_bias - saver_bias)
                       / saver_rad_sigma.squeeze()).norm(2) ** 2

        kld += (kld_dir.sum() + 100 * kld_rad.sum()
                + 100 * mu_bias_reg + 100 * dir_loc_reg)
    return kld