Пример #1
0
 def get_weight(self, input, reverse):
     """Return the (optionally inverted) weight kernel and its log-det term.

     Args:
         input: tensor handed to ``thops.pixels``; in the LU branch its
             ``device`` is also where the LU helper tensors are moved.
         reverse: when truthy, the inverse of the weight matrix is returned
             instead of the matrix itself.

     Returns:
         A ``(weight, dlogdet)`` pair. ``weight`` is viewed as
         ``(w_shape[0], w_shape[1], 1, 1)``. ``dlogdet`` is computed the same
         way for both directions; no sign flip is applied here when
         ``reverse`` is set (presumably the caller handles that -- confirm).
     """
     w_shape = self.w_shape
     rows, cols = w_shape[0], w_shape[1]

     if self.LU:
         # Keep the non-trainable LU pieces on the same device as the input.
         device = input.device
         self.p = self.p.to(device)
         self.sign_s = self.sign_s.to(device)
         self.l_mask = self.l_mask.to(device)
         self.eye = self.eye.to(device)

         # lower: masked entries of self.l plus self.eye.
         lower = self.l * self.l_mask + self.eye
         # upper: entries of self.u under the transposed mask, plus a diagonal
         # built from sign_s * exp(log_s).
         upper_mask = self.l_mask.transpose(0, 1).contiguous()
         diagonal = torch.diag(self.sign_s * torch.exp(self.log_s))
         upper = self.u * upper_mask + diagonal

         dlogdet = thops.sum(self.log_s) * thops.pixels(input)

         if reverse:
             # Invert each factor in float64, then cast back to float32.
             lower_inv = torch.inverse(lower.double()).float()
             upper_inv = torch.inverse(upper.double()).float()
             tail = torch.matmul(lower_inv, self.p.inverse())
             kernel = torch.matmul(upper_inv, tail)
         else:
             kernel = torch.matmul(self.p, torch.matmul(lower, upper))
         return kernel.view(rows, cols, 1, 1), dlogdet

     # Direct parameterisation: the log|det| comes straight from the matrix.
     pixels = thops.pixels(input)
     _, log_abs_det = torch.slogdet(self.weight)
     dlogdet = log_abs_det * pixels
     if reverse:
         # Invert in float64, then cast back to float32.
         matrix = torch.inverse(self.weight.double()).float()
     else:
         matrix = self.weight
     return matrix.view(rows, cols, 1, 1), dlogdet
Пример #2
0
    def _scale(self, input, logdet=None, reverse=False, offset=None):
        """Scale ``input`` by ``exp(logs)``, or by ``exp(-logs)`` in reverse.

        Args:
            input: tensor to rescale.
            logdet: running log-determinant, or ``None`` to skip the update.
            reverse: when truthy, apply the inverse scaling and subtract the
                log-det contribution instead of adding it.
            offset: optional value added to ``self.logs`` before it is used.

        Returns:
            ``(scaled_input, logdet)``. ``logdet`` is returned as ``None``
            when ``None`` was passed in.
        """
        log_scale = self.logs if offset is None else self.logs + offset

        # Original note: "should have shape batchsize, n_channels, 1, 1"
        # -- TODO confirm which operand that refers to.
        exponent = -log_scale if reverse else log_scale
        input = input * torch.exp(exponent)

        if logdet is None:
            return input, logdet

        # Per the original note, logs is a per-channel log-std, so its sum
        # is multiplied by the pixel count reported by thops.pixels.
        dlogdet = thops.sum(log_scale) * thops.pixels(input)
        if reverse:
            dlogdet = -dlogdet
        return input, logdet + dlogdet
Пример #3
0
 def get_score(self, disc_loss_sigma, z):
     """Return the negated "real" score of ``z`` for a given sigma.

     With ``D = z.shape[1] * z.shape[2] * z.shape[3]`` the value returned is
     ``-(0.5 * (1 - 1 / sigma**2) * sum(z**2) - D * log(sigma))``, where the
     sum is taken by ``thops.sum`` with ``dim=[1, 2, 3]``.

     NOTE(review): assuming ``thops.sum`` reduces over those dims, this is
     algebraically ``log N(z; 0, I) - log N(z; 0, sigma**2 * I)`` -- confirm
     that this is the intended reading.

     Args:
         disc_loss_sigma: scalar passed to ``math.log``, so it must be a
             positive number.
         z: tensor indexed on ``shape[1..3]`` (at least 4-D).
     """
     quad_coeff = 0.5 * (1 - 1 / (disc_loss_sigma ** 2))
     squared_norm = thops.sum(z ** 2, dim=[1, 2, 3])
     n_elements = z.shape[1] * z.shape[2] * z.shape[3]
     log_term = n_elements * math.log(disc_loss_sigma)
     score_real = quad_coeff * squared_norm - log_term
     return -score_real
 def get_logdet(self, scale):
     """Return ``thops.sum`` of ``log(scale)`` taken with ``dim=[1, 2, 3]``.

     Args:
         scale: tensor of scale factors; non-positive entries yield NaN or
             -inf from ``torch.log`` and are not guarded against here.
     """
     log_scale = torch.log(scale)
     return thops.sum(log_scale, dim=[1, 2, 3])