def get_weight(self, input, reverse):
    """Return the 1x1 convolution kernel and its log-determinant contribution.

    Two parameterisations are supported, selected by ``self.LU``:

    * ``LU`` false: the kernel is the dense square matrix ``self.weight``.
    * ``LU`` true: the kernel is assembled as ``p @ (lower @ upper)`` where
      ``lower = self.l * self.l_mask + self.eye`` and
      ``upper = self.u * self.l_mask.T + diag(self.sign_s * exp(self.log_s))``.

    Args:
        input: activation the kernel will be applied to. Only its spatial size
            (through ``thops.pixels``) and, in the LU branch, its device are used.
        reverse: if true, return the inverse kernel instead of the forward one.

    Returns:
        ``(weight, dlogdet)``. ``weight`` has shape
        ``(w_shape[0], w_shape[1], 1, 1)``. ``dlogdet`` is the log-determinant
        term scaled by ``thops.pixels(input)``. Its sign is NOT flipped for
        ``reverse``; presumably the caller subtracts it in that case (confirm).
    """
    w_shape = self.w_shape
    if not self.LU:
        pixels = thops.pixels(input)
        dlogdet = torch.slogdet(self.weight)[1] * pixels
        if not reverse:
            weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
        else:
            # Invert in float64 for accuracy, then cast back to the parameter's
            # own dtype. Hard-coding .float() returned a float32 kernel for
            # float64/float16 models and broke the following convolution.
            weight = torch.inverse(self.weight.double()).to(self.weight.dtype)
            weight = weight.view(w_shape[0], w_shape[1], 1, 1)
        return weight, dlogdet

    # Move the fixed LU tensors to the input's device. Reassigning the
    # attributes keeps them there for subsequent calls.
    self.p = self.p.to(input.device)
    self.sign_s = self.sign_s.to(input.device)
    self.l_mask = self.l_mask.to(input.device)
    self.eye = self.eye.to(input.device)

    # NOTE(review): l_mask looks like a strictly-lower-triangular mask, which
    # makes `lower` unit-lower-triangular and `upper` upper-triangular with
    # diagonal sign_s * exp(log_s). Only then is log|det| == sum(log_s)
    # (assuming |det p| == 1). Confirm against the constructor.
    lower = self.l * self.l_mask + self.eye
    upper = (self.u * self.l_mask.transpose(0, 1).contiguous()
             + torch.diag(self.sign_s * torch.exp(self.log_s)))
    dlogdet = thops.sum(self.log_s) * thops.pixels(input)

    if not reverse:
        w = torch.matmul(self.p, torch.matmul(lower, upper))
    else:
        # W^-1 = U^-1 @ L^-1 @ P^-1. The triangular factors are inverted in
        # float64 and cast back to their own dtype (not hard-coded float32).
        lower_inv = torch.inverse(lower.double()).to(lower.dtype)
        upper_inv = torch.inverse(upper.double()).to(upper.dtype)
        w = torch.matmul(upper_inv, torch.matmul(lower_inv, self.p.inverse()))
    return w.view(w_shape[0], w_shape[1], 1, 1), dlogdet
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
    """Decode latent ``z`` back to image space, conditioned on ``lr``.

    Args:
        lr: conditioning image. It is indexed as ``lr[:, 0, 0, 0]``, so it must
            be at least 4-D.
        z: latent passed to ``self.flowUpsamplerNet`` as ``z``.
        y_onehot: accepted for interface compatibility; not used here.
        eps_std: forwarded to ``self.flowUpsamplerNet``.
        epses: forwarded to ``self.flowUpsamplerNet``.
        lr_enc: precomputed conditioning features. When ``None`` they are
            computed with ``self.rrdbPreprocessing(lr)``.
        add_gt_noise: when true, subtract the ``-log(quant) * pixels``
            dequantisation term from the starting log-determinant.

    Returns:
        ``(x, logdet)`` as produced by ``self.flowUpsamplerNet`` in reverse mode.
    """
    # One zero log-determinant per batch element, with lr's dtype/device.
    logdet = torch.zeros_like(lr[:, 0, 0, 0])
    # lr pixel count times scale**2 (presumably the output-resolution pixel count).
    hr_pixels = thops.pixels(lr) * self.opt['scale'] ** 2

    if add_gt_noise:
        # Mirrors the dequantisation term that normal_flow adds; removed here.
        quant_term = float(-np.log(self.quant) * hr_pixels)
        logdet = logdet - quant_term

    conditioning = lr_enc if lr_enc is not None else self.rrdbPreprocessing(lr)

    x, logdet = self.flowUpsamplerNet(rrdbResults=conditioning, z=z, eps_std=eps_std,
                                      reverse=True, epses=epses, logdet=logdet)
    return x, logdet
def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
    """Encode ``gt`` through the conditional flow and compute its NLL.

    Args:
        gt: target image. It is indexed as ``gt[:, 0, 0, 0]``, so it must be
            at least 4-D.
        lr: conditioning image, encoded with ``self.rrdbPreprocessing`` unless
            ``lr_enc`` is supplied.
        y_onehot: forwarded unchanged to ``self.flowUpsamplerNet``.
        epses: forwarded unchanged to ``self.flowUpsamplerNet``.
        lr_enc: precomputed conditioning features; skips ``rrdbPreprocessing``.
        add_gt_noise: when true, add ``-log(quant) * pixels`` to logdet. Also,
            if the ``network_G.flow.augmentation.noiseQuant`` option is set
            (default True), add uniform dequantisation noise to ``gt``.
        step: not used in this function.

    Returns:
        ``(latent, nll, logdet)``. ``latent`` is the whole ``epses`` value when
        the flow returns a list. When it returns a tuple, ``latent`` is the last
        element; otherwise it is the returned value itself. ``nll`` is
        ``-(logdet + GaussianDiag.logp(latent))`` divided by
        ``log(2) * thops.pixels(gt)``.
    """
    if lr_enc is None:
        lr_enc = self.rrdbPreprocessing(lr)

    logdet = torch.zeros_like(gt[:, 0, 0, 0])
    n_pixels = thops.pixels(gt)
    z = gt

    if add_gt_noise:
        noise_quant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
        if noise_quant:
            # Uniform noise in [-0.5, 0.5), scaled by 1/quant.
            jitter = (torch.rand(z.shape, device=z.device) - 0.5) / self.quant
            z = z + jitter
        logdet = logdet + float(-np.log(self.quant) * n_pixels)

    epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False,
                                          epses=epses, y_onehot=y_onehot)

    # NOTE(review): tuples are unwrapped to their last element here, but only
    # lists are returned whole below, so a tuple result yields just the last eps.
    latent = epses[-1] if isinstance(epses, (list, tuple)) else epses
    objective = logdet.clone() + flow.GaussianDiag.logp(None, None, latent)
    nll = -objective / float(np.log(2.) * n_pixels)

    if isinstance(epses, list):
        return epses, nll, logdet
    return latent, nll, logdet
def get_weight(self, input, reverse):
    """Return the 1x1 convolution kernel built from ``self.weight`` and its log-det term.

    Args:
        input: activation the kernel is applied to. Only its spatial size is
            used (through ``thops.pixels``) to scale the log-determinant.
        reverse: if true, return the inverse of ``self.weight`` instead.

    Returns:
        ``(weight, dlogdet)``. ``weight`` is viewed as
        ``(w_shape[0], w_shape[1], 1, 1)``, and
        ``dlogdet = log|det(self.weight)| * pixels``. ``dlogdet`` is identical
        in both directions; presumably the caller negates it for reverse
        (confirm).
    """
    w_shape = self.w_shape
    kernel_shape = (w_shape[0], w_shape[1], 1, 1)
    pixels = thops.pixels(input)
    dlogdet = torch.slogdet(self.weight)[1] * pixels
    if reverse:
        # Invert in float64 for accuracy, then cast back to the parameter's
        # dtype. Hard-coding .float() broke float64/float16 models.
        matrix = torch.inverse(self.weight.double()).to(self.weight.dtype)
    else:
        matrix = self.weight
    return matrix.view(*kernel_shape), dlogdet
def _scale(self, input, logdet=None, reverse=False, offset=None):
    """Apply the ActNorm scale (or its inverse) and update the log-determinant.

    Args:
        input: tensor multiplied elementwise (with broadcasting) by
            ``exp(logs)``, or by ``exp(-logs)`` when ``reverse`` is true.
        logdet: running log-determinant, or ``None`` to skip that bookkeeping.
        reverse: apply the inverse scaling and subtract the log-det term.
        offset: optional additive offset applied to ``self.logs``.

    Returns:
        ``(scaled, logdet)``. ``logdet`` stays ``None`` if it was ``None``.
    """
    log_scale = self.logs if offset is None else self.logs + offset
    # NOTE(review): log_scale is presumably shaped (1, C, 1, 1) or broadcast-
    # compatible with input; confirm.
    factor = torch.exp(-log_scale) if reverse else torch.exp(log_scale)
    scaled = input * factor

    if logdet is None:
        return scaled, logdet

    # Per the original note, logs is a per-channel log-std shared over spatial
    # positions, so its sum is multiplied by the pixel count of the output.
    delta = thops.sum(log_scale) * thops.pixels(scaled)
    if reverse:
        delta = -delta
    return scaled, logdet + delta