Example #1
    def forward(self, x):
        concentration = softplus(self.dir_concentration)
        loc = self.dir_loc / self.dir_loc.norm(dim=-1, keepdim=True)
        self.dir_sampler = PowerSpherical(loc, concentration)
        e = self.nonlinear(self.embed1(x))
        e = self.nonlinear(self.embed2(e))
        self.rad_sampler = LogNormal(self.rad_mu(e), softplus(self.rad_scale(e)))
        self.rad_sampler1 = LogNormal(self.rad_mu1(e), softplus(self.rad_scale1(e)))
        self.bias_sampler = Normal(self.bias_mu, softplus(self.bias_scale))

        direction_sample = self.dir_sampler.rsample()

        # combine the two radius draws; the successive square roots give (r1 * r2) ** 0.25
        radius_sample = self.rad_sampler.rsample()
        radius_sample = (radius_sample * self.rad_sampler1.rsample()) ** 0.5
        radius_sample = radius_sample ** 0.5

        # radius_sample = LogNormal(self.rad_mu, softplus(self.rad_scale)).rsample()
        # radius_sample = (radius_sample * LogNormal(self.rad_mu1, softplus(self.rad_scale1)).rsample()) ** 0.5

        bias = self.bias_sampler.rsample() if self.bias else None

        # weight = direction_sample * radius_sample.unsqueeze(0) ** 0.5
        weight = direction_sample
        
        output = F.linear(x * radius_sample, weight, bias)
        
        return output
Example #2
    def forward(self, input, sample=False):
        # self.dir_loc.data /= torch.sum(self.dir_loc.data ** 2, dim=-1, keepdim=True) ** 0.5
        # direction_sample = self.dir_rsampler(1, sample)[0]
        if sample:
            direction_sample = PowerSpherical(
                self.dir_loc,
                softplus(self.dir_softplus_inv_concentration)).rsample()
            radius_sample = LogNormal(self.rad_mu,
                                      softplus(self.rad_rho)).rsample()
        else:
            direction_sample = PowerSpherical(
                self.dir_loc,
                softplus(self.dir_softplus_inv_concentration)).mean
            radius_sample = LogNormal(self.rad_mu, softplus(self.rad_rho)).mean

        weight = direction_sample * radius_sample  #.unsqueeze(-1)
        return F.linear(input, weight, self.bias)
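
A minimal usage sketch for a module with this forward (the name `layer` and the shapes are hypothetical): `sample=True` draws weights from the variational posterior, `sample=False` uses the distribution means.

    x = torch.randn(8, 16)          # (batch, in_features); 16 is illustrative
    y_mc = layer(x, sample=True)    # Monte Carlo forward with sampled weights
    y_map = layer(x, sample=False)  # deterministic forward with mean weights
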
Example #3
    def forward(self, noise):
        x = self.nonlinear(self.fc1(noise))
        # x = self.nonlinear(self.fc2(x))
        loc = self.fc_loc(x)
        loc = loc / loc.norm(dim=-1, keepdim=True)
        concentration = softplus(self.fc_concentration(x).squeeze()) + 1
        # the `+ 1` prevents collapsing behavior (keeps the concentration bounded away from 0)
        sample = PowerSpherical(loc, concentration).rsample(torch.Size([1]))
        return sample
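
For reference, a self-contained sketch of the `power_spherical` distribution these snippets rely on (shapes are illustrative):

    import torch
    import torch.nn.functional as F
    from power_spherical import PowerSpherical  # pip install power_spherical

    loc = F.normalize(torch.randn(4, 8), dim=-1)  # unit-norm mean directions
    concentration = torch.full((4,), 50.0)        # larger = tighter around loc
    dist = PowerSpherical(loc, concentration)
    sample = dist.rsample()                       # reparameterized draw, shape (4, 8)
    print(sample.norm(dim=-1))                    # ~1.0 everywhere: points on the unit sphere
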
Example #4
    def forward(self, input, sample=False):

        if sample:
            direction_sample = PowerSpherical(
                self.dir_loc,
                softplus(self.dir_softplus_inv_concentration)).rsample()
        else:
            direction_sample = PowerSpherical(
                self.dir_loc,
                softplus(self.dir_softplus_inv_concentration)).mean

        radius = self.gate(self.rad_layer(input))

        weight = direction_sample.unsqueeze(0) * radius.unsqueeze(-1)
        if self.bias is not None:
            output = (input.unsqueeze(1) * weight).sum(-1) + self.bias
        else:
            output = (input.unsqueeze(1) * weight).sum(-1)
        return output
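
Since the sampled weight above carries a leading batch dimension (one weight matrix per input example), plain `F.linear` does not apply; the product-and-sum is equivalent to this `einsum` sketch:

    # input: (batch, in_features), weight: (batch, out_features, in_features)
    output = torch.einsum('bi,boi->bo', input, weight)
    if self.bias is not None:
        output = output + self.bias
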
Example #5
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import LogNormal, Normal
from torch.nn.functional import softplus
from power_spherical import PowerSpherical  # pip install power_spherical


def softplus_inv(y):
    # Inverse of softplus(x) = log(1 + exp(x)), in a numerically stable form.
    # The project defines its own helper; this equivalent stand-in keeps the
    # snippet self-contained.
    return y + math.log(-math.expm1(-y))


class NewLinear(nn.Module):
    """Linear layer with a PowerSpherical posterior over weight directions
    and LogNormal, input-dependent radius factors."""
    def __init__(self, in_features, out_features, bias=True, noise_shape=1):
        super(NewLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.dir_concentration = nn.Parameter(torch.Tensor(out_features))
        self.dir_loc = nn.Parameter(torch.Tensor(out_features, in_features))
        nn.init.kaiming_normal_(self.dir_loc)
        nn.init.normal_(self.dir_concentration, out_features*10, 1)
        self.rad_mu = nn.Linear(in_features, in_features)
        self.rad_scale = nn.Linear(in_features, in_features)
        self.rad_mu1 = nn.Linear(in_features, in_features)
        self.rad_scale1 = nn.Linear(in_features, in_features)
        self.embed1 = nn.Linear(in_features, in_features)
        self.embed2 = nn.Linear(in_features, in_features)
        self.nonlinear = nn.ReLU()

        # self.rad_mu = nn.Parameter(torch.Tensor(in_features))
        # self.rad_scale = nn.Parameter(torch.Tensor(in_features))
        # self.rad_mu1 = nn.Parameter(torch.Tensor(in_features))
        # self.rad_scale1 = nn.Parameter(torch.Tensor(in_features))
        # nn.init.normal_(self.rad_mu, math.log(2.0), 0.0001)
        # nn.init.normal_(self.rad_scale, softplus_inv(0.0001), 0.0001)
        # nn.init.normal_(self.rad_mu1, math.log(2.0), 0.0001)
        # nn.init.normal_(self.rad_scale1, softplus_inv(0.0001), 0.0001)

        
        self.bias_mu = nn.Parameter(torch.Tensor(out_features))
        self.bias_scale = nn.Parameter(torch.Tensor(out_features))
        nn.init.normal_(self.bias_mu, 0.0, 0.0001)
        nn.init.normal_(self.bias_scale, softplus_inv(0.0001), 0.0001)

    def forward(self, x):
        concentration = softplus(self.dir_concentration)
        loc = self.dir_loc / self.dir_loc.norm(dim=-1, keepdim=True)
        self.dir_sampler = PowerSpherical(loc, concentration)
        e = self.nonlinear(self.embed1(x))
        e = self.nonlinear(self.embed2(e))
        self.rad_sampler = LogNormal(self.rad_mu(e), softplus(self.rad_scale(e)))
        self.rad_sampler1 = LogNormal(self.rad_mu1(e), softplus(self.rad_scale1(e)))
        self.bias_sampler = Normal(self.bias_mu, softplus(self.bias_scale))

        direction_sample = self.dir_sampler.rsample()

        # combine the two radius draws; the successive square roots give (r1 * r2) ** 0.25
        radius_sample = self.rad_sampler.rsample()
        radius_sample = (radius_sample * self.rad_sampler1.rsample()) ** 0.5
        radius_sample = radius_sample ** 0.5

        # radius_sample = LogNormal(self.rad_mu, softplus(self.rad_scale)).rsample()
        # radius_sample = (radius_sample * LogNormal(self.rad_mu1, softplus(self.rad_scale1)).rsample()) ** 0.5

        bias = self.bias_sampler.rsample() if self.bias else None

        # weight = direction_sample * radius_sample.unsqueeze(0) ** 0.5
        weight = direction_sample
        
        output = F.linear(x * radius_sample, weight, bias)
        
        return output

    def kl_divergence(self):
        pass
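
A quick usage check for the class above (shapes are illustrative):

    layer = NewLinear(in_features=16, out_features=32)
    x = torch.randn(8, 16)
    y = layer(x)    # stochastic: each call redraws direction, radius, and bias
    print(y.shape)  # torch.Size([8, 32])
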
Example #6
    def custom_regularization(self,
                              saver_net,
                              trainer_net,
                              mini_batch_size,
                              loss=None):

        dir_loc_reg_sum = mu_bias_reg_sum = rad_mu_reg_sum = 0
        L1_rad_mu_reg_sum = L1_mu_bias_reg_sum = 0
        rad_sigma_reg_sum = rad_sigma_normal_reg_sum = 0

        out_features_max = 512
        alpha = self.alpha
        if self.saved:
            alpha = 1

        if 'conv' in self.model_name:
            if self.data_name == 'omniglot':
                prev_weight_strength = nn.Parameter(
                    torch.Tensor(1, 1, 1, 1).uniform_(0, 0)).cuda()
            elif self.data_name == 'cifar':
                prev_weight_strength = nn.Parameter(
                    torch.Tensor(3, 1, 1, 1).uniform_(0, 0)).cuda()
        else:
            prev_weight_strength = nn.Parameter(
                torch.Tensor(28 * 28, 1).uniform_(0, 0)).cuda()

        for (saver_name,
             saver_layer), (trainer_name,
                            trainer_layer) in zip(saver_net.items(),
                                                  trainer_net.items()):
            # calculate mu regularization
            trainer_dir_loc = trainer_layer['dir_loc']
            trainer_dir_concentration = F.softplus(
                trainer_layer['dir_softplus_inv_concentration'])
            # trainer_dir_loc = trainer_layer['dir_rsampler.loc']
            # trainer_dir_concentration = F.softplus(trainer_layer['dir_rsampler.softplus_inv_concentration'])
            trainer_rad_mu = trainer_layer['rad_mu']
            trainer_rad_sigma = F.softplus(trainer_layer['rad_rho'])
            trainer_bias = trainer_layer['bias']

            saver_dir_loc = saver_layer['dir_loc']
            saver_dir_concentration = F.softplus(
                saver_layer['dir_softplus_inv_concentration'])
            # saver_dir_loc = saver_layer['dir_rsampler.loc']
            # saver_dir_concentration = F.softplus(saver_layer['dir_rsampler.softplus_inv_concentration'])
            saver_rad_mu = saver_layer['rad_mu']
            saver_rad_sigma = F.softplus(saver_layer['rad_rho'])
            saver_bias = saver_layer['bias']

            fan_in, fan_out = _calculate_fan_in_and_fan_out(trainer_dir_loc)

            concentration_init = ml_kappa(dim=fan_in, eps=self.model.eps)

            if 'fc' in trainer_name:
                std_init = math.sqrt((2 / fan_in) * self.model.ratio)
            if 'conv' in trainer_name:
                std_init = math.sqrt((2 / fan_out) * self.model.ratio)

            saver_weight_strength = (std_init / saver_rad_sigma)

            if len(saver_dir_loc.shape) == 4:
                out_features, in_features, _, _ = saver_dir_loc.shape
                curr_strength = saver_weight_strength.expand(
                    out_features, in_features, 1, 1)
                prev_strength = prev_weight_strength.permute(
                    1, 0, 2, 3).expand(out_features, in_features, 1, 1)

            else:
                out_features, in_features = saver_dir_loc.shape
                curr_strength = saver_weight_strength.expand(
                    out_features, in_features)
                if len(prev_weight_strength.shape) == 4:
                    feature_size = in_features // (
                        prev_weight_strength.shape[0])
                    prev_weight_strength = prev_weight_strength.reshape(
                        prev_weight_strength.shape[0], -1)
                    prev_weight_strength = prev_weight_strength.expand(
                        prev_weight_strength.shape[0], feature_size)
                    prev_weight_strength = prev_weight_strength.reshape(-1, 1)
                prev_strength = prev_weight_strength.permute(1, 0).expand(
                    out_features, in_features)

            L2_strength = torch.max(curr_strength, prev_strength)  #(4)
            #L2_strength = (1.0 / saver_weight_sigma) #(3a)
            bias_strength = torch.squeeze(saver_weight_strength)
            rad_mu_strength = torch.squeeze(saver_weight_strength)

            L1_sigma = saver_rad_sigma
            bias_sigma = torch.squeeze(saver_rad_sigma)

            prev_weight_strength = saver_weight_strength

            dir_loc_reg = (L2_strength *
                           (trainer_dir_loc - saver_dir_loc)).norm(2)**2
            mu_bias_reg = (bias_strength *
                           (trainer_bias - saver_bias)).norm(2)**2
            rad_mu_reg = (rad_mu_strength *
                          (trainer_rad_mu - saver_rad_mu)).norm(2)**2
            # (5)
            L1_rad_mu_reg = (torch.div(saver_rad_mu**2, L1_sigma**2) *
                             (trainer_rad_mu - saver_rad_mu)).norm(1)
            L1_mu_bias_reg = (torch.div(saver_bias**2, bias_sigma**2) *
                              (trainer_bias - saver_bias)).norm(1)

            L1_rad_mu_reg = L1_rad_mu_reg * (std_init**2)
            L1_mu_bias_reg = L1_mu_bias_reg * (std_init**2)
            #
            rad_sigma = (trainer_rad_sigma**2 / saver_rad_sigma**2)

            normal_rad_sigma = trainer_rad_sigma**2

            rad_sigma_reg_sum = rad_sigma_reg_sum + (
                rad_sigma - torch.log(rad_sigma)).sum()  # (3b)
            # rad_sigma_normal_reg_sum = rad_sigma_normal_reg_sum + (normal_rad_sigma - torch.log(normal_rad_sigma)).sum() #(6)

            # dir_loc_reg_sum = dir_loc_reg_sum + dir_loc_reg
            mu_bias_reg_sum = mu_bias_reg_sum + mu_bias_reg
            rad_mu_reg_sum = rad_mu_reg_sum + rad_mu_reg
            L1_rad_mu_reg_sum = L1_rad_mu_reg_sum + L1_rad_mu_reg
            L1_mu_bias_reg_sum = L1_mu_bias_reg_sum + L1_mu_bias_reg

        # elbo loss
        loss = loss / mini_batch_size
        # L2 loss
        loss = loss + alpha * (mu_bias_reg_sum +
                               rad_mu_reg_sum) / (2 * mini_batch_size)

        # loss = loss + self.saved * dir_loc_reg_sum / (mini_batch_size)
        # L1 loss
        loss = loss + self.saved * (L1_rad_mu_reg_sum +
                                    L1_mu_bias_reg_sum) / (mini_batch_size)
        # sigma regularization
        loss = loss + alpha * (rad_sigma_reg_sum) / (mini_batch_size)

        # note: trainer_dir_* / saver_dir_* here hold the values from the final
        # iteration of the layer loop above
        q_dist = PowerSpherical(trainer_dir_loc, trainer_dir_concentration)
        p_dist = PowerSpherical(saver_dir_loc, saver_dir_concentration)
        kld_dir = KL_Powerspherical(q_dist, p_dist)

        # reg_strength = L2_strength if self.saved else 1
        # kld_dir = KL_vMF_kappa_full(trainer_dir_loc, trainer_dir_concentration, saver_dir_loc, saver_dir_concentration, 1)
        loss = loss + alpha * kld_dir.sum() / (mini_batch_size)

        return loss
Example #7
    def kl_divergence(self, saver_net, trainer_net):
        kld = 0
        prev_weight_strength = nn.Parameter(
            torch.Tensor(28 * 28, 1).uniform_(0, 0)).cuda()
        alpha = self.alpha
        if self.saved:
            alpha = 1
        for (saver_name,
             saver_layer), (trainer_name,
                            trainer_layer) in zip(saver_net.items(),
                                                  trainer_net.items()):

            trainer_dir_loc = trainer_layer['dir_loc']
            trainer_dir_concentration = F.softplus(
                trainer_layer['dir_softplus_inv_concentration'])
            trainer_rad_mu = trainer_layer['rad_mu']
            trainer_rad_sigma = F.softplus(trainer_layer['rad_rho'])
            trainer_bias = trainer_layer['bias']

            saver_dir_loc = saver_layer['dir_loc']
            saver_dir_concentration = F.softplus(
                saver_layer['dir_softplus_inv_concentration'])
            saver_rad_mu = saver_layer['rad_mu']
            saver_rad_sigma = F.softplus(saver_layer['rad_rho'])
            saver_bias = saver_layer['bias']

            fan_in, fan_out = _calculate_fan_in_and_fan_out(trainer_dir_loc)
            concentration_init = ml_kappa(dim=fan_in, eps=self.model.eps)

            if 'fc' in trainer_name:
                std_init = math.sqrt((2 / fan_in) * self.model.ratio)
            if 'conv' in trainer_name:
                std_init = math.sqrt((2 / fan_out) * self.model.ratio)

            out_features, in_features = saver_dir_loc.shape
            saver_weight_strength = (std_init / saver_rad_sigma)
            curr_strength = saver_weight_strength.expand(
                out_features, in_features)
            prev_strength = prev_weight_strength.permute(1, 0).expand(
                out_features, in_features)
            L2_strength = torch.max(curr_strength, prev_strength)

            prev_weight_strength = saver_weight_strength

            dir_loc_reg = (
                (L2_strength * trainer_dir_loc * saver_dir_loc) /
                (trainer_dir_loc.norm(2, dim=-1) *
                 saver_dir_loc.norm(2, dim=-1)).unsqueeze(-1)).sum()

            q_dir = PowerSpherical(trainer_dir_loc, trainer_dir_concentration)
            p_dir = PowerSpherical(saver_dir_loc, saver_dir_concentration)
            kld_dir = KL_Powerspherical(q_dir, p_dir)

            q_rad = LogNormal(trainer_rad_mu, trainer_rad_sigma)
            p_rad = LogNormal(saver_rad_mu, saver_rad_sigma)
            kld_rad = kl_divergence(q_rad, p_rad)

            mu_bias_reg = ((trainer_bias - saver_bias) /
                           saver_rad_sigma.squeeze()).norm(2)**2

            kld += (kld_dir.sum() + 100 * kld_rad.sum() +
                    100 * mu_bias_reg + 100 * dir_loc_reg)

        return kld
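
The radius term uses PyTorch's `kl_divergence` on two LogNormals, which equals the KL between the underlying Normals; a standalone sketch (shapes are illustrative):

    import torch
    from torch.distributions import LogNormal, kl_divergence

    q_rad = LogNormal(torch.zeros(4), 0.1 * torch.ones(4))
    p_rad = LogNormal(torch.zeros(4), 0.2 * torch.ones(4))
    print(kl_divergence(q_rad, p_rad).sum())  # per-layer radius KL contribution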