def forward(self, x):
    h = x * self.c
    h = self.conv(h)
    if self.act is not None:
        h = self.act(h)
    if self.pixelnorm:
        mean = torch.mean(h * h, 1, keepdim=True)
        dom = torch.rsqrt(mean + self.eps)
        h = h * dom
    return h
def rsample(self, sample_shape=torch.Size()):
    # NOTE: This does not agree with scipy implementation as much as other distributions.
    # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb).
    # Using DoubleTensor parameters seems to help.
    #   X ~ Normal(0, 1)
    #   Z ~ Chi2(df)
    #   Y = X / sqrt(Z / df) ~ StudentT(df)
    shape = self._extended_shape(sample_shape)
    X = self.df.new(*shape).normal_()
    Z = self._chi2.rsample(sample_shape)
    Y = X * torch.rsqrt(Z / self.df)
    return self.loc + self.scale * Y
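# A quick Monte-Carlo sanity check of the reparameterized Student-t sampler above
# (illustrative sketch, not part of the original module): draw X ~ Normal(0, 1) and
# Z ~ Chi2(df), then Y = X * rsqrt(Z / df) follows StudentT(df).
import torch

df = torch.tensor(5.0)
X = torch.randn(100000)
Z = torch.distributions.Chi2(df).sample((100000,))
Y = X * torch.rsqrt(Z / df)
# For df = 5 the variance of StudentT(df) is df / (df - 2) = 5 / 3 ~ 1.667.
print(Y.var())  # should land close to 1.667, up to sampling noise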
def forward(self, x):
    h = x.unsqueeze(2).unsqueeze(3)
    if self.normalize_latents:
        mean = torch.mean(h * h, 1, keepdim=True)
        dom = torch.rsqrt(mean + self.eps)
        h = h * dom
    h = self.block0(h, self.depth == 0)
    if self.depth > 0:
        for i in range(self.depth - 1):
            h = F.upsample(h, scale_factor=2)
            h = self.blocks[i](h)
        h = F.upsample(h, scale_factor=2)
        ult = self.blocks[self.depth - 1](h, True)
        if self.alpha < 1.0:
            if self.depth > 1:
                preult_rgb = self.blocks[self.depth - 2].toRGB(h)
            else:
                preult_rgb = self.block0.toRGB(h)
        else:
            preult_rgb = 0
        h = preult_rgb * (1 - self.alpha) + ult * self.alpha
    return h
def forward(self, _inp):
    x, mask = _inp
    norm = torch.rsqrt((x ** 2).mean(dim=1, keepdim=True) + 1e-7)
    return x * norm, mask
def forward(self, x):
    x = x - torch.mean(x, (2, 3), True)
    tmp = torch.mul(x, x)  # or x ** 2
    tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
    return x * tmp
def forward(self, x):
    x -= torch.mean(x, dim=[2, 3], keepdim=True)
    x *= torch.rsqrt(torch.mean(x ** 2, dim=[2, 3], keepdim=True) + self.epsilon)
    return x
def test_rsqrt(self):
    x = torch.randn(3, 4, requires_grad=True)
    self.assertONNX(lambda x: torch.rsqrt(x), x)
def forward(self, input):
    return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
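# A standalone sanity check of the pixel-norm forward above (usage sketch, the
# tensor shapes are illustrative): after normalization, the RMS over the channel
# dimension should be ~1 at every spatial position.
import torch

x = torch.randn(2, 512, 4, 4)
y = x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)
rms = torch.sqrt(torch.mean(y ** 2, dim=1))
print(rms.min(), rms.max())  # both ~1.0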
def forward(ctx, tensor, alpha=1):  # pylint: disable=arguments-differ
    """Calculates the forward pass for an ISRLU unit"""
    negatives = torch.min(tensor, torch.tensor((0,), dtype=tensor.dtype))
    nisr = torch.rsqrt(1. + alpha * (negatives ** 2))
    ctx.save_for_backward(nisr)
    return tensor * nisr
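# Hedged usage sketch for the ISRLU forward above: for x >= 0 the unit is the
# identity, while for x < 0 it smoothly saturates as x / sqrt(1 + alpha * x^2),
# bounded below by -1/sqrt(alpha).
import torch

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
negatives = torch.min(x, torch.zeros_like(x))
nisr = torch.rsqrt(1.0 + 1.0 * negatives ** 2)  # alpha = 1
print(x * nisr)  # negatives compressed toward -1; non-negatives unchanged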
def normalize(x, eps=1e-10):
    return x * torch.rsqrt(torch.sum(x ** 2, dim=1, keepdim=True) + eps)
def forward(self, x):
    self._check_input_dim(x)
    if self.training:
        N, C, H, W = x.size()
        G = self.groups
        x = x.transpose(0, 1).contiguous().view(C, -1)
        mu = x.mean(1, keepdim=True)
        x = x - mu
        xxt = torch.mm(x, x.t()) / (N * H * W) + torch.eye(C, out=torch.empty_like(x)) * self.eps
        assert C % G == 0
        length = int(C / G)
        xxti = torch.chunk(xxt, G, dim=0)
        xxtj = [torch.chunk(xxti[j], G, dim=1)[j] for j in range(G)]
        xg = list(torch.chunk(x, G, dim=0))
        xgr_list = []
        for i in range(G):
            u, e, v = torch.svd(xxtj[i])
            ratio = torch.cumsum(e, 0) / e.sum()
            counter_i = 1
            for j in range(length):
                if e[j] <= self.eps:  # ratio[j] >= (1 - self.eps) or e[j] <= self.eps:
                    # print('{}/{} eigen-vectors selected'.format(j + 1, length))
                    counter_i = j + 1  # at least keep 99.99% energy
                    break
            subspace = torch.zeros_like(xxtj[i])
            for j in range(counter_i):
                lambda_ij = e[j]
                if lambda_ij < 0:
                    print('eigenvalues: ', e)
                    print("Error message: negative SVD lambda_ij {} vs SVD lambda_ij {}..".format(lambda_ij, e[j]))
                    break
                eigenvector_ij = v[:, j][..., None]
                subspace += torch.mm(eigenvector_ij, torch.rsqrt(lambda_ij) * eigenvector_ij.t())
            xgr = torch.mm(subspace, xg[i])
            xgr_list.append(xgr)
            with torch.no_grad():
                running_subspace = self.__getattr__('running_subspace' + str(i))
                running_subspace.data = (1 - self.momentum) * running_subspace.data + self.momentum * subspace.data
        with torch.no_grad():
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
        xr = torch.cat(xgr_list, dim=0)
        xr = xr * self.weight + self.bias
        xr = xr.view(C, N, H, W).transpose(0, 1)
        return xr
    else:
        N, C, H, W = x.size()
        x = x.transpose(0, 1).contiguous().view(C, -1)
        x = (x - self.running_mean)
        G = self.groups
        xg = list(torch.chunk(x, G, dim=0))
        for i in range(G):
            subspace = self.__getattr__('running_subspace' + str(i))
            xg[i] = torch.mm(subspace, xg[i])
        x = torch.cat(xg, dim=0)
        x = x * self.weight + self.bias
        x = x.view(C, N, H, W).transpose(0, 1)
        return x
def cosineSim(wv1, wv2):
    mul = torch.mul(wv1, wv2)
    # mulabs = torch.sqrt(torch.sum(torch.mul(mul, mul)))
    wv1abs = torch.rsqrt(torch.sum(torch.mul(wv1, wv1)))
    wv2abs = torch.rsqrt(torch.sum(torch.mul(wv2, wv2)))
    # sum the elementwise products to get the dot product before rescaling,
    # so the function returns a scalar similarity rather than a vector
    return torch.sum(mul) * wv1abs * wv2abs
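# Cross-check of the (scalar-valued) cosineSim above against PyTorch's built-in;
# a hedged sketch with arbitrary vector length.
import torch
import torch.nn.functional as F

wv1 = torch.randn(300)
wv2 = torch.randn(300)
ours = torch.sum(wv1 * wv2) * torch.rsqrt(torch.sum(wv1 * wv1)) * torch.rsqrt(torch.sum(wv2 * wv2))
ref = F.cosine_similarity(wv1, wv2, dim=0)
print(torch.allclose(ours, ref, atol=1e-5))  # True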
def forward(self, x):
    self._check_input_dim(x)
    if self.training:
        N, C, H, W = x.size()
        G = self.groups
        x = x.transpose(0, 1).contiguous().view(C, -1)
        mu = x.mean(1, keepdim=True)
        x = x - mu
        xxt = torch.mm(x, x.t()) / (N * H * W) + torch.eye(C, out=torch.empty_like(x)) * self.eps
        assert C % G == 0
        length = int(C / G)
        xxti = torch.chunk(xxt, G, dim=0)
        xxtj = [torch.chunk(xxti[j], G, dim=1)[j] for j in range(G)]
        xg = list(torch.chunk(x, G, dim=0))
        xgr_list = []
        for i in range(G):
            subspace = torch.zeros_like(xxtj[i])
            for j in range(length):
                # initialize eigenvector with random values
                eigenvector_ij = self.__getattr__('eigenvector{}-{}'.format(i, j))
                v = l2normalize(torch.randn_like(eigenvector_ij))
                eigenvector_ij.data = v.data
                eigenvector_ij = self.power_layer(xxtj[i], eigenvector_ij)
                lambda_current = torch.mm(xxtj[i].mm(eigenvector_ij).t(), eigenvector_ij) \
                    / torch.mm(eigenvector_ij.t(), eigenvector_ij)
                if j == 0:
                    lambda_ij = lambda_current
                elif lambda_ij < lambda_current or lambda_current < self.eps:
                    break
                else:
                    lambda_ij = lambda_current
                subspace += torch.mm(eigenvector_ij, torch.rsqrt(lambda_ij).mm(eigenvector_ij.t()))
                # remove projections on the eigenvectors
                xxtj[i] = xxtj[i] - torch.mm(xxtj[i], eigenvector_ij.mm(eigenvector_ij.t()))
            xgr = torch.mm(subspace, xg[i])
            xgr_list.append(xgr)
            with torch.no_grad():
                running_subspace = self.__getattr__('running_subspace' + str(i))
                running_subspace.data = (1 - self.momentum) * running_subspace.data + self.momentum * subspace.data
        with torch.no_grad():
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
        xr = torch.cat(xgr_list, dim=0)
        xr = xr * self.weight + self.bias
        xr = xr.view(C, N, H, W).transpose(0, 1)
        return xr
    else:
        N, C, H, W = x.size()
        x = x.transpose(0, 1).contiguous().view(C, -1)
        x = (x - self.running_mean)
        G = self.groups
        xg = list(torch.chunk(x, G, dim=0))
        for i in range(G):
            subspace = self.__getattr__('running_subspace' + str(i))
            xg[i] = torch.mm(subspace, xg[i])
        x = torch.cat(xg, dim=0)
        x = x * self.weight + self.bias
        x = x.view(C, N, H, W).transpose(0, 1)
        return x
def forward(self, x):
    self._check_input_dim(x)
    if self.training:
        N, C, H, W = x.size()
        G = self.groups
        x = x.transpose(0, 1).contiguous().view(C, -1)
        mu = x.mean(1, keepdim=True)
        x = x - mu
        xxt = torch.mm(x, x.t()) / (N * H * W) + torch.eye(C, out=torch.empty_like(x)) * self.eps
        assert C % G == 0
        length = int(C / G)
        xxti = torch.chunk(xxt, G, dim=0)
        xxtj = [torch.chunk(xxti[j], G, dim=1)[j] for j in range(G)]
        xg = list(torch.chunk(x, G, dim=0))
        xgr_list = []
        for i in range(G):
            counter_i = 0
            # compute eigenvectors of subgroups no grad
            with torch.no_grad():
                u, e, v = torch.svd(xxtj[i])
                ratio = torch.cumsum(e, 0) / e.sum()
                for j in range(length):
                    if ratio[j] >= (1 - self.eps) or e[j] <= self.eps:
                        print('{}/{} eigen-vectors selected'.format(j + 1, length))
                        print(e[0:counter_i])
                        break
                    eigenvector_ij = self.__getattr__('eigenvector{}-{}'.format(i, j))
                    eigenvector_ij.data = v[:, j][..., None].data
                    counter_i = j + 1
            # feed eigenvectors to Power Iteration Layer with grad and compute whitened tensor
            subspace = torch.zeros_like(xxtj[i])
            for j in range(counter_i):
                eigenvector_ij = self.__getattr__('eigenvector{}-{}'.format(i, j))
                eigenvector_ij = self.power_layer(xxtj[i], eigenvector_ij)
                lambda_ij = torch.mm(xxtj[i].mm(eigenvector_ij).t(), eigenvector_ij) \
                    / torch.mm(eigenvector_ij.t(), eigenvector_ij)
                if lambda_ij < 0:
                    print('eigenvalues: ', e)
                    print("Warning message: negative PI lambda_ij {} vs SVD lambda_ij {}..".format(lambda_ij, e[j]))
                    break
                diff_ratio = (lambda_ij - e[j]).abs() / e[j]
                if diff_ratio > 0.1:
                    break
                subspace += torch.mm(eigenvector_ij, torch.rsqrt(lambda_ij).mm(eigenvector_ij.t()))
                xxtj[i] = xxtj[i] - torch.mm(xxtj[i], eigenvector_ij.mm(eigenvector_ij.t()))
            xgr = torch.mm(subspace, xg[i])
            xgr_list.append(xgr)
            with torch.no_grad():
                running_subspace = self.__getattr__('running_subspace' + str(i))
                running_subspace.data = (1 - self.momentum) * running_subspace.data + self.momentum * subspace.data
        with torch.no_grad():
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
        xr = torch.cat(xgr_list, dim=0)
        xr = xr * self.weight + self.bias
        xr = xr.view(C, N, H, W).transpose(0, 1)
        return xr
    else:
        N, C, H, W = x.size()
        x = x.transpose(0, 1).contiguous().view(C, -1)
        x = (x - self.running_mean)
        G = self.groups
        xg = list(torch.chunk(x, G, dim=0))
        for i in range(G):
            subspace = self.__getattr__('running_subspace' + str(i))
            xg[i] = torch.mm(subspace, xg[i])
        x = torch.cat(xg, dim=0)
        x = x * self.weight + self.bias
        x = x.view(C, N, H, W).transpose(0, 1)
        return x
def forward(self, input, style, labels=None):
    batch, in_channel, height, width = input.shape

    if not self.fused:
        weight = self.scale * self.weight.squeeze(0)
        if self.conditional_bias:
            style = self.modulation(style, labels)
        else:
            style = self.modulation(style)
        if self.demodulate:
            w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
            dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
        input = input * style.reshape(batch, in_channel, 1, 1)
        if self.upsample:
            weight = weight.transpose(0, 1)
            out = conv2d_gradfix.conv_transpose2d(input, weight, padding=0, stride=2)
            out = self.blur(out)
        elif self.downsample:
            input = self.blur(input)
            out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
        else:
            out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
        if self.demodulate:
            out = out * dcoefs.view(batch, -1, 1, 1)
        return out

    if self.conditional_bias:
        style = self.modulation(style, labels).view(batch, 1, in_channel, 1, 1)
    else:
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
    weight = self.scale * self.weight * style
    if self.demodulate:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
    weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size)

    if self.upsample:
        input = input.view(1, batch * in_channel, height, width)
        weight = weight.view(batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size)
        weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel,
                                                self.kernel_size, self.kernel_size)
        out = conv2d_gradfix.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
        _, _, height, width = out.shape
        out = out.view(batch, self.out_channel, height, width)
        out = self.blur(out)
    elif self.downsample:
        input = self.blur(input)
        _, _, height, width = input.shape
        input = input.view(1, batch * in_channel, height, width)
        out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2, groups=batch)
        _, _, height, width = out.shape
        out = out.view(batch, self.out_channel, height, width)
    else:
        input = input.view(1, batch * in_channel, height, width)
        out = conv2d_gradfix.conv2d(input, weight, padding=self.padding, groups=batch)
        _, _, height, width = out.shape
        out = out.view(batch, self.out_channel, height, width)
    return out
# # variance = torch.var(prob_i, 0)
# # print(variance)
# # variance = torch.rsqrt(variance)
# variance = torch.log(variance)
# print(variance)
# a1 = torch.tensor([5.0])
# print(a1 - torch.mean(a1))
# print(torch.norm(a1 - torch.mean(a1), p=2))
# print(torch.log(1 / torch.norm(a1 - torch.mean(a1), p=2)))
# print(torch.var(a1, 0))
# print(a[0].index_select(0, torch.tensor([0, 2, 4])))
# print(torch.index_select(a, mask))
# print(a.index_select(0, mask))

# variance
var = torch.var(a, 1)
var2 = torch.var_mean(a, 1)
# reciprocal of the square root
var3 = torch.rsqrt(var)
b = torch.nn.functional.softmax(a, 1)
# b = tensor([[4, 3, 2, 1, ], [1, 2, 3, 4]])
# a = a.cuda()
# b = b.cuda()
# print(a)
# a_soft = F.softmax(a, 1)
# print(a_soft)
def forward(self, x):
    # pixel norm scales the input by the reciprocal RMS over channels,
    # so the normalizer multiplies x rather than being added to it
    return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)
def step(self, steps=1):
    """
    Take a projection step.

    Arguments:
        steps (int): Number of steps to take. If this exceeds the remaining
            steps of the projection, that number of steps is taken instead.
            Default value is 1.
    """
    self._check_job()
    remaining_steps = self._job.num_steps - self._job.current_step
    if not remaining_steps > 0:
        warnings.warn(
            'Trying to take a projection step after the ' +
            'final projection iteration has been completed.'
        )
    if steps < 0:
        steps = remaining_steps
    steps = min(remaining_steps, steps)
    if not steps > 0:
        return
    for _ in range(steps):
        if self._job.current_step >= self._job.num_steps:
            break

        # Hyperparameters.
        t = self._job.current_step / self._job.num_steps
        noise_strength = self._dlatent_std * self._job.initial_noise_factor \
            * max(0.0, 1.0 - t / self._job.noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / self._job.lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / self._job.lr_rampup_length)
        learning_rate = self._job.initial_learning_rate * lr_ramp
        for param_group in self._job.opt.param_groups:
            param_group['lr'] = learning_rate

        dlatents = self._job.dlatent_param + noise_strength * self._job.noise_tensor.normal_()
        output = self.G_synthesis(dlatents)
        assert output.size() == self._job.target.size(), \
            'target size {} does not fit output size {} of generator'.format(
                self._job.target.size(), output.size())
        output_scaled = self._scale_for_lpips(output)

        # Main loss: LPIPS distance of output and target
        lpips_distance = torch.mean(
            self.lpips_model(output_scaled, self._job.target_scaled))

        # Calculate noise regularization loss
        reg_loss = 0
        for p in self._job.noise_params:
            size = min(p.size()[2:])
            dim = p.dim() - 2
            while True:
                reg_loss += torch.mean(
                    (p * p.roll(shifts=[1] * dim, dims=list(range(2, 2 + dim)))) ** 2)
                if size <= 8:
                    break
                p = F.interpolate(p, scale_factor=0.5, mode='area')
                size = size // 2

        # Combine loss, backward and update params
        loss = lpips_distance + self._job.regularize_noise_weight * reg_loss
        self._job.opt.zero_grad()
        loss.backward()
        self._job.opt.step()

        # Normalize noise values
        for p in self._job.noise_params:
            with torch.no_grad():
                p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True)
                p_rstd = torch.rsqrt(
                    torch.mean((p - p_mean) ** 2,
                               dim=list(range(1, p.dim())),
                               keepdim=True) + 1e-8)
                p.data = (p.data - p_mean) * p_rstd

        self._job.current_step += 1

        if self._job.verbose:
            self._job.value_tracker.add('loss', float(loss))
            self._job.value_tracker.add('lpips_distance', float(lpips_distance))
            self._job.value_tracker.add('noise_reg', float(reg_loss))
            self._job.value_tracker.add('lr', learning_rate, beta=0)
            self._job.progress.write(self._job.verbose_prefix,
                                     str(self._job.value_tracker))

    if self._job.current_step >= self._job.num_steps:
        self._job.progress.close()
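# Minimal standalone sketch of the noise-buffer renormalization performed at the
# end of each projection step above: each noise map is shifted and rescaled with
# rsqrt so its entries have zero mean and unit standard deviation (hypothetical
# usage outside the projection job).
import torch

p = 3.0 + 2.0 * torch.randn(1, 1, 16, 16)
p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True)
p_rstd = torch.rsqrt(torch.mean((p - p_mean) ** 2,
                                dim=list(range(1, p.dim())), keepdim=True) + 1e-8)
p = (p - p_mean) * p_rstd
print(p.mean().item(), p.std().item())  # ~0 and ~1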
def forward(self, x):
    mean = torch.mean(x * x, 1, keepdim=True)
    dom = torch.rsqrt(mean + self.eps)
    x = x * dom
    return x
# real
torch.randn(4, dtype=torch.cfloat).real
# reciprocal
torch.reciprocal(a)
# remainder
torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2)
torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5)
# round
torch.round(a)
# rsqrt
torch.rsqrt(a)
# sigmoid
torch.sigmoid(a)
# sign
torch.sign(torch.tensor([0.7, -1.2, 0., 2.3]))
# sgn
torch.tensor([3 + 4j, 7 - 24j, 0, 1 + 2j]).sgn()
# signbit
torch.signbit(torch.tensor([0.7, -1.2, 0., 2.3]))
# sin
torch.sin(a)
def forward(self, x):
    mean = x.mean(-1, keepdim=True)
    s = (x - mean).pow(2).mean(-1, keepdim=True)
    x = (x - mean) * torch.rsqrt(s + self.eps)
    return self.scale * x + self.shift
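# The forward above is a hand-rolled LayerNorm over the last dimension; a hedged
# cross-check of the normalization core against F.layer_norm (ignoring the
# learned scale/shift, which F.layer_norm would take as weight/bias).
import torch
import torch.nn.functional as F

x = torch.randn(4, 10)
eps = 1e-5
mean = x.mean(-1, keepdim=True)
s = (x - mean).pow(2).mean(-1, keepdim=True)
manual = (x - mean) * torch.rsqrt(s + eps)
ref = F.layer_norm(x, (10,), eps=eps)
print(torch.allclose(manual, ref, atol=1e-5))  # True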
grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()
out_sbn_c_last = sbn_c_last(inp_sbn_c_last)
out_sbn_c_last.backward(grad_sbn_c_last)

sbn_result = True
sbn_result_c_last = True
bn_result = True

sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
# sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r

sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
compare("comparing bn output: ", out_bn, out_r, error)

grad_output_t = type_tensor(grad)
grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps)
                 * grad_output2_r).transpose(1, 0).contiguous().view(feature_size, -1).sum(1)
def normalize_adj(mx):
    rowsum = mx.sum(1, keepdim=False)
    r_inv_sqrt = torch.rsqrt(rowsum)
    r_inv_sqrt[torch.isinf(r_inv_sqrt)] = 0.
    r_mat_inv_sqrt = torch.diag(r_inv_sqrt)
    return mx.mm(r_mat_inv_sqrt).transpose(0, 1).mm(r_mat_inv_sqrt)
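# Usage sketch for normalize_adj: for an adjacency matrix A with degree matrix D,
# the result is the symmetric normalization D^{-1/2} A D^{-1/2}; rows with zero
# degree stay zero thanks to the isinf guard (since rsqrt(0) = inf).
import torch

mx = torch.tensor([[0., 1., 1.],
                   [1., 0., 0.],
                   [1., 0., 0.]])
rowsum = mx.sum(1)
r_inv_sqrt = torch.rsqrt(rowsum)
r_inv_sqrt[torch.isinf(r_inv_sqrt)] = 0.
d = torch.diag(r_inv_sqrt)
print(mx.mm(d).transpose(0, 1).mm(d))  # symmetric, entries 1/sqrt(d_i * d_j)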
def test(parser, visualisation=None):
    parser = updateParser(parser)
    kwargs = vars(parser.parse_args())

    # Parameters
    name = getVal(kwargs, "name", None)
    if name is None:
        raise ValueError("You need to input a name")

    module = getVal(kwargs, "module", None)
    if module is None:
        raise ValueError("You need to input a module")

    imgPath = getVal(kwargs, "inputImage", None)
    if imgPath is None:
        raise ValueError("You need to input an image path")

    scale = getVal(kwargs, "scale", None)
    iter = getVal(kwargs, "iter", None)
    nRuns = getVal(kwargs, "nRuns", 1)

    checkPointDir = os.path.join(kwargs["dir"], name)
    checkpointData = getLastCheckPoint(checkPointDir, name, scale=scale, iter=iter)
    weights = getVal(kwargs, 'weights', None)

    if checkpointData is None:
        raise FileNotFoundError(
            "No checkpoint found for model " + str(name) + " at directory "
            + str(checkPointDir) + ' cwd=' + str(os.getcwd()))

    modelConfig, pathModel, _ = checkpointData
    keysLabels = None
    with open(modelConfig, 'rb') as file:
        keysLabels = json.load(file)["attribKeysOrder"]
    if keysLabels is None:
        keysLabels = {}

    packageStr, modelTypeStr = getNameAndPackage(module)
    modelType = loadmodule(packageStr, modelTypeStr)
    visualizer = GANVisualizer(pathModel, modelConfig, modelType, visualisation)

    # Load the image
    targetSize = visualizer.model.getSize()
    baseTransform = standardTransform(targetSize)
    img = pil_loader(imgPath)
    input = baseTransform(img)
    input = input.view(1, input.size(0), input.size(1), input.size(2))

    pathsModel = getVal(kwargs, "featureExtractor", None)
    featureExtractors = []
    imgTransforms = []

    if weights is not None:
        if pathsModel is None or len(pathsModel) != len(weights):
            raise AttributeError(
                "The number of weights must match the number of models")

    if pathsModel is not None:
        for path in pathsModel:
            if path == "id":
                featureExtractor = IDModule()
                imgTransform = IDModule()
            else:
                featureExtractor, mean, std = buildFeatureExtractor(path, resetGrad=True)
                imgTransform = FeatureTransform(mean, std, size=128)  # None)
            featureExtractors.append(featureExtractor)
            imgTransforms.append(imgTransform)
    else:
        featureExtractors = IDModule()
        imgTransforms = IDModule()

    basePath = os.path.splitext(imgPath)[0] + "_" + kwargs['suffix']
    if not os.path.isdir(basePath):
        os.mkdir(basePath)
    basePath = os.path.join(basePath, os.path.basename(basePath))
    print("All results will be saved in " + basePath)

    outDictData = {}
    outPathDescent = None

    fullInputs = torch.cat([input for x in range(nRuns)], dim=0)

    if kwargs['save_descent']:
        outPathDescent = os.path.join(os.path.dirname(basePath), "descent")
        if not os.path.isdir(outPathDescent):
            os.mkdir(outPathDescent)

    img, outVectors, loss = gradientDescentOnInput(visualizer.model,
                                                   fullInputs,
                                                   featureExtractors,
                                                   imgTransforms,
                                                   visualizer=visualisation,
                                                   lambdaD=kwargs['lambdaD'],
                                                   nSteps=kwargs['nSteps'],
                                                   weights=weights,
                                                   randomSearch=kwargs['random_search'],
                                                   nevergrad=kwargs['nevergrad'],
                                                   lr=kwargs['learningRate'],
                                                   outPathSave=outPathDescent)

    pathVectors = basePath + "vector.pt"
    torch.save(outVectors, open(pathVectors, 'wb'))

    path = basePath + ".jpg"
    visualisation.saveTensor(img, (img.size(2), img.size(3)), path)
    outDictData[os.path.splitext(os.path.basename(path))[0]] = \
        [x.item() for x in loss]

    outVectors = outVectors.view(outVectors.size(0), -1)
    outVectors *= torch.rsqrt((outVectors ** 2).mean(dim=1, keepdim=True))

    barycenter = outVectors.mean(dim=0)
    barycenter *= torch.rsqrt((barycenter ** 2).mean())
    meanAngles = (outVectors * barycenter).mean(dim=1)
    meanDist = torch.sqrt(((barycenter - outVectors) ** 2).mean(dim=1)).mean(dim=0)
    outDictData["Barycenter"] = {"meanDist": meanDist.item(),
                                 "stdAngles": meanAngles.std().item(),
                                 "meanAngles": meanAngles.mean().item()}

    path = basePath + "_data.json"
    outDictData["kwargs"] = kwargs

    with open(path, 'w') as file:
        json.dump(outDictData, file, indent=2)

    pathVectors = basePath + "vectors.pt"
    torch.save(outVectors, open(pathVectors, 'wb'))
def zca_mean(
    inp: torch.Tensor,
    dim: int = 0,
    unbiased: bool = True,
    eps: float = 1e-6,
    return_inverse: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    r"""
    Computes the ZCA whitening matrix and mean vector. The output can be used with
    :py:meth:`~kornia.color.linear_transform`.

    See :class:`~kornia.color.ZCAWhitening` for details.

    args:
        inp (torch.Tensor) : input data tensor
        dim (int): Specifies the dimension that serves as the samples dimension. Default = 0
        unbiased (bool): Whether to use the unbiased estimate of the covariance matrix. Default = True
        eps (float) : a small number used for numerical stability. Default = 1e-6
        return_inverse (bool): Whether to return the inverse ZCA transform.

    shapes:
        - inp: :math:`(D_0,...,D_{\text{dim}},...,D_N)` is a batch of N-D tensors.
        - transform_matrix: :math:`(\Pi_{d=0,d\neq \text{dim}}^N D_d, \Pi_{d=0,d\neq \text{dim}}^N D_d)`
        - mean_vector: :math:`(1, \Pi_{d=0,d\neq \text{dim}}^N D_d)`
        - inv_transform: same shape as the transform matrix

    returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        A tuple containing the ZCA matrix and the mean vector. If return_inverse is set to True,
        it also returns the inverse ZCA matrix, otherwise that entry is None.

    Examples:
        >>> x = torch.tensor([[0,1],[1,0],[-1,0],[0,-1]], dtype = torch.float32)
        >>> transform_matrix, mean_vector, _ = zca_mean(x)  # Returns transformation matrix and data mean
        >>> x = torch.rand(3,20,2,2)
        >>> transform_matrix, mean_vector, inv_transform = zca_mean(x, dim = 1, return_inverse = True)
        >>> # transform_matrix.size() equals (12,12) and mean_vector.size() equals (1,12)
    """
    if not isinstance(inp, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(inp)))
    if not isinstance(eps, float):
        raise TypeError(f"eps type is not a float. Got {type(eps)}")
    if not isinstance(unbiased, bool):
        raise TypeError(f"unbiased type is not bool. Got {type(unbiased)}")
    if not isinstance(dim, int):
        raise TypeError("Argument 'dim' must be of type int. Got {}".format(type(dim)))
    if not isinstance(return_inverse, bool):
        raise TypeError("Argument return_inverse must be of type bool. Got {}".format(type(return_inverse)))

    inp_size = inp.size()
    if dim >= len(inp_size) or dim < -len(inp_size):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{},{}], but got {})"
            .format(-len(inp_size), len(inp_size) - 1, dim))
    if dim < 0:
        dim = len(inp_size) + dim

    feat_dims = torch.cat([torch.arange(0, dim), torch.arange(dim + 1, len(inp_size))])
    new_order: List[int] = torch.cat([torch.tensor([dim]), feat_dims]).tolist()
    inp_permute = inp.permute(new_order)

    N = inp_size[dim]
    feature_sizes = torch.tensor(inp_size[0:dim] + inp_size[dim + 1::])
    num_features: int = int(torch.prod(feature_sizes).item())

    mean: torch.Tensor = torch.mean(inp_permute, dim=0, keepdim=True)
    mean = mean.reshape((1, num_features))
    inp_center_flat: torch.Tensor = inp_permute.reshape((N, num_features)) - mean

    cov = inp_center_flat.t().mm(inp_center_flat)
    if unbiased:
        cov = cov / float(N - 1)
    else:
        cov = cov / float(N)

    U, S, _ = _torch_svd_cast(cov)
    S = S.reshape(-1, 1)
    S_inv_root: torch.Tensor = torch.rsqrt(S + eps)
    T: torch.Tensor = (U).mm(S_inv_root * U.t())

    T_inv: Optional[torch.Tensor] = None
    if return_inverse:
        T_inv = (U).mm(torch.sqrt(S + eps) * U.t())

    return T, mean, T_inv
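# A self-contained sketch of the same ZCA math as zca_mean above, using
# torch.linalg.svd in place of the module's _torch_svd_cast helper: whitening the
# empirical covariance with T = U * rsqrt(S + eps) * U^T drives the covariance of
# the transformed data to (approximately) the identity.
import torch

x = torch.randn(1000, 4) @ torch.randn(4, 4)  # correlated features
mean = x.mean(0, keepdim=True)
xc = x - mean
cov = xc.t() @ xc / (x.shape[0] - 1)
U, S, _ = torch.linalg.svd(cov)
T = U @ (torch.rsqrt(S.reshape(-1, 1) + 1e-6) * U.t())
x_white = xc @ T.t()  # T is symmetric, so T.t() == T
cov_w = x_white.t() @ x_white / (x.shape[0] - 1)
print(torch.allclose(cov_w, torch.eye(4), atol=0.1))  # True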
def forward(self, x):
    return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
def instance_norm(x, eps=1e-8):
    """Instance normalization."""
    assert len(x.shape) == 4, "shape of input should be NCHW!"
    x -= torch.mean(x, dim=(2, 3), keepdim=True)
    return x * torch.rsqrt(torch.mean(x ** 2, dim=(2, 3), keepdim=True) + eps)
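# Quick check for instance_norm above (assumes that function is in scope): every
# (sample, channel) map should come out with ~zero mean and ~unit second moment.
import torch

x = 5.0 + 2.0 * torch.randn(2, 3, 8, 8)
y = instance_norm(x.clone())  # clone, since the function modifies x in place
print(y.mean(dim=(2, 3)).abs().max())  # ~0
print((y ** 2).mean(dim=(2, 3)))       # ~1 everywhere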
def forward(self, x):
    tmp = torch.mul(x, x)  # or x ** 2
    tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)
    return x * tmp1
def forward(self, x, abs_x, deg, idx):
    batch, channels, npoints, neighbors = x.size()  # B, C, N, K

    ''' 1. get point features (B, C, N) '''
    x_q = abs_x  # B, C//2, N, 1
    x_kv = x     # B, C, N, K

    ''' 2. transform by Wq, Wk, Wv '''
    q_out = self.query_conv(x_q)   # B, C, N, 1
    k_out = self.key_conv(x_kv)    # B, C, N, K
    k_out_all = k_out[:, :, :, 0]  # B, C, N, 1
    v_out = self.value_conv(x_kv)  # B, C, N, K
    v_out_all = v_out[:, :, :, 0]  # B, C, N, 1

    ''' 3. relative positional encoding '''
    if self.rpe:
        k_out = k_out + self.rel_k  # k_out : B, C, N, K / self.rel_k : C, 1, K

    ''' 4. multi-head attention '''
    if self.scale:
        scaler = torch.tensor([self.out_channels / self.groups]).cuda()
        out = torch.rsqrt(scaler) * q_out * k_out
    else:
        out = q_out * k_out  # B, C, N, K
    out = F.softmax(out, dim=-1)  # B, C, N, K

    ''' 5. scoring '''
    idx = idx[:, :, :, -1].unsqueeze(1).expand_as(out).cuda()  # B, N, K -> B, 1, N, K -> B, C, N, K
    idx_scatter = torch.zeros(batch, self.out_channels, npoints, npoints,
                              device='cuda').detach()  # B, C, N, N
    # node-wise importance
    idx_scatter.scatter_(dim=3, index=idx, src=out)  # B, C, N, N
    score = idx_scatter.sum(dim=2, keepdim=True).transpose(2, 3).squeeze(3)  # B, C, 1, N -> B, C, N, 1 -> B, C, N
    idx_key, idx_salient = score.topk(k=20, dim=-1)  # B, C, S | here, S : sampled global points
    k_out_all = torch.gather(k_out_all, 2, idx_salient).unsqueeze(2).repeat(1, 1, npoints, 1)  # B, C, S -> B, C, N, S
    v_out_all = torch.gather(v_out_all, 2, idx_salient).unsqueeze(2).repeat(1, 1, npoints, 1)  # B, C, S -> B, C, N, S
    out_all = q_out * k_out_all  # B, C, N, S
    out_all = F.softmax(out_all, dim=-1)  # B, C, N, S
    out_all = torch.einsum('bcns,bcns -> bcn', out_all, v_out_all)  # B, C, N
    out_all = out_all.view(batch, -1, npoints, 1)  # B, C, N, 1

    # x : 6 x 3 / idx : 6
    # x : B,C,N x K / idx : B,C,N,K
    out = torch.einsum('bcnk,bcnk -> bcn', out, v_out)  # B, C, N
    out = out.view(batch, -1, npoints, 1)  # B, C, N, 1

    ''' 8. reshape for memory '''
    # out = torch.cat([out, out_all], dim=1)
    if self.layer > 0:
        # print("ratio : {}".format(out_all.mean() / out.mean()))
        out = torch.cat([out, out_all], dim=1)
    if self.return_kv:
        k_out = k_out.view(batch, -1, npoints, neighbors, 1)
        v_out = v_out.view(batch, -1, npoints, neighbors, 1)
        return out, k_out, v_out
    else:
        return out
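# The self.scale branch above applies the usual scaled dot-product temperature,
# multiplying q*k scores by rsqrt(d) before the softmax. A stripped-down CPU
# sketch (shapes and names here are illustrative, not from the module):
import torch
import torch.nn.functional as F

d = 64
q = torch.randn(2, d, 128, 1)   # B, C, N, 1
k = torch.randn(2, d, 128, 16)  # B, C, N, K
scores = torch.rsqrt(torch.tensor([float(d)])) * q * k
attn = F.softmax(scores, dim=-1)
print(attn.shape, attn.sum(-1).mean())  # weights sum to 1 over the K neighbors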
def forward(self, x):
    x2 = x ** 2
    temp = torch.mean(x2, dim=1, keepdim=True)
    return x * torch.rsqrt(temp + self.epsilon)
def test_rsqrt(x, y):
    c = torch.rsqrt(torch.add(x, y))
    return c
def forward(self, x, style):
    """Forward function.

    Args:
        x (Tensor): Input features with shape of (N, C, H, W).
        style (Tensor): Style latent with shape of (N, C).

    Returns:
        Tensor: Output feature with shape of (N, C, H, W).
    """
    n, c, h, w = x.shape
    # process style code
    style = self.style_modulation(style).view(n, 1, c, 1, 1) + self.style_bias
    # combine weight and style
    weight = self.weight * style
    if self.demodulate:
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
        weight = weight * demod.view(n, self.out_channels, 1, 1, 1)
    weight = weight.view(n * self.out_channels, c, self.kernel_size, self.kernel_size)

    if self.upsample and not self.deconv2conv:
        x = x.reshape(1, n * c, h, w)
        weight = weight.view(n, self.out_channels, c, self.kernel_size, self.kernel_size)
        weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                self.kernel_size, self.kernel_size)
        x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
        x = x.reshape(n, self.out_channels, *x.shape[-2:])
        x = self.blur(x)
    elif self.upsample and self.deconv2conv:
        if self.up_after_conv:
            x = x.reshape(1, n * c, h, w)
            x = conv2d(x, weight, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[2:4])
        if self.with_interp_pad:
            h_, w_ = x.shape[-2:]
            up_cfg_ = deepcopy(self.up_config)
            up_scale = up_cfg_.pop('scale_factor')
            size_ = (h_ * up_scale + self.interp_pad, w_ * up_scale + self.interp_pad)
            x = F.interpolate(x, size=size_, **up_cfg_)
        else:
            x = F.interpolate(x, **self.up_config)
        if not self.up_after_conv:
            h_, w_ = x.shape[-2:]
            x = x.view(1, n * c, h_, w_)
            x = conv2d(x, weight, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[2:4])
    elif self.downsample:
        x = self.blur(x)
        x = x.view(1, n * self.in_channels, *x.shape[-2:])
        x = conv2d(x, weight, stride=2, padding=0, groups=n)
        x = x.view(n, self.out_channels, *x.shape[-2:])
    else:
        x = x.view(1, n * c, h, w)
        x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
        x = x.view(n, self.out_channels, *x.shape[-2:])
    return x
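# Hedged sketch of the demodulation step shared by the modulated convolutions
# above: after scaling the weights by a per-sample style, each output filter is
# renormalized to unit L2 norm via rsqrt of its summed squares (shapes here are
# illustrative).
import torch

n, out_ch, in_ch, k = 2, 8, 4, 3
weight = torch.randn(1, out_ch, in_ch, k, k)
style = torch.randn(n, 1, in_ch, 1, 1)
w = weight * style
demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)
w = w * demod.view(n, out_ch, 1, 1, 1)
print(w.pow(2).sum([2, 3, 4]))  # ~1 for every (sample, filter) pair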
def forward(self, input, style, coords=None, calc_flops=False):
    batch, in_channel, height, width = input.shape

    if calc_flops:
        flops = self.get_flops(input, style)
    else:
        flops = 0

    # Special case for spatially-shaped style.
    # Here, we early justify whether the whole feature uses the same style.
    # If that's the case, we simply use the same style; otherwise, it will use another slower logic.
    if (style is not None) and (style.ndim == 4):
        mean_style = style.mean([2, 3], keepdim=True)
        is_mono_style = ((style - mean_style) < 1e-8).all()
        if is_mono_style:
            style = mean_style.squeeze()

    # Regular forward
    if style.ndim == 2:
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        # (1, ) * (1, out_ch, in_ch, k, k) * (B, 1, in_ch, 1, 1)
        # => (B, out_ch, in_ch, k, k)
        weight = self.scale * self.weight * style
        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
        weight = weight.view(batch * self.out_channel, in_channel,
                             self.kernel_size, self.kernel_size)

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(batch, self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size)
            weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel,
                                                    self.kernel_size, self.kernel_size)
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            if self.no_zero_pad:
                out = out[:, :, 1:-1, 1:-1]  # Clipping head and tail, which involves zero padding
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)
        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
    else:
        assert (not self.training), \
            "Only accepts spatially-shaped global-latent for testing-time manipulation!"
        assert (style.ndim == 4), \
            "Only considered BxCxHxW case, but got shape {}".format(style.shape)

        # For simplicity (and laziness), we sometimes feed spatial latents
        # that are larger than the input; center-crop for such kinds of cases.
        style = self._auto_shape_align(source=style, target=input)

        # [Note]
        # Original (lossy expression): input * (style * weight)
        # What we equivalently do here (still lossy): (input * style) * weight
        sb, sc, sh, sw = style.shape
        flat_style = style.permute(0, 2, 3, 1).reshape(-1, sc)  # (BxHxW, C)
        style_mod = self.modulation(flat_style)  # (BxHxW, C)
        style_mod = style_mod.view(sb, sh, sw, self.in_channel).permute(0, 3, 1, 2)  # (B, C, H, W)
        input_st = (style_mod * input)  # (B, C, H, W)
        weight = self.scale * self.weight

        if self.demodulate:
            # [Hubert]
            # This will be an estimation if spatially fused styles are different.
            # In practice, the interpolation of styles does not (numerically) change drastically,
            # so the approximation here is invisible.
            """
            # This is the implementation we showed in the paper Appendix; the for-loop is slow,
            # but this version surely allocates a constant amount of memory.
            for i in range(sh):
                for j in range(sw):
                    style_expand_s = style_mod[:, :, i, j].view(sb, 1, self.in_channel, 1, 1)  # shape: (B, 1, in_ch, 1, 1)
                    simulated_weight_s = weight * style_expand_s  # shape: (B, out_ch, in_ch, k, k)
                    demod_s[:, :, i, j] = torch.rsqrt(simulated_weight_s.pow(2).sum([2, 3, 4]) + 1e-8)  # shape: (B, out_ch)
            """
            # Logically equivalent version; omits one for-loop by batching one spatial dimension.
            demod = torch.zeros(sb, self.out_channel, sh, sw).to(style.device)
            for i in range(sh):
                style_expand = style_mod[:, :, i, :].view(sb, 1, self.in_channel, sw).pow(2)  # shape: (B, 1, in_ch, W)
                weight_expand = weight.pow(2).sum([3, 4]).unsqueeze(-1)  # shape: (B, out_ch, in_ch, 1)
                simulated_weight = weight_expand * style_expand  # shape: (B, out_ch, in_ch, W)
                demod[:, :, i, :] = torch.rsqrt(simulated_weight.sum(2) + 1e-8)  # shape: (B, out_ch, W)
            """
            # An even faster version that batches both height and width dimensions, but allocates
            # too much memory to be practical. For instance, it allocates 40GB with shape
            # (8, 512, 128, 3, 3, 31, 31).
            style_expand = style_mod.view(sb, 1, self.in_channel, 1, 1, sh, sw)  # (B, 1, in_ch, 1, 1, H, W)
            weight_expand = weight.unsqueeze(5).unsqueeze(6)  # (B, out_ch, in_ch, k, k, 1, 1)
            simulated_weight = weight_expand * style_expand  # shape: (B, out_ch, in_ch, k, k, H, W)
            demod = torch.rsqrt(simulated_weight.pow(2).sum([2, 3, 4]) + 1e-8)  # shape: (B, out_ch, H, W)
            """
            """
            # Just FYI. If you use the mean style over the patch, it creates blocky artifacts.
            mean_style = style_mod.mean([2, 3]).view(sb, 1, self.in_channel, 1, 1)
            simulated_weight_ = weight * mean_style  # shape: (B, out_ch, in_ch, k, k)
            demod_ = torch.rsqrt(simulated_weight_.pow(2).sum([2, 3, 4]) + 1e-8)
            demod_ = demod_.unsqueeze(2).unsqueeze(3)
            """

        weight = weight.view(self.out_channel, in_channel, self.kernel_size, self.kernel_size)

        if self.upsample:
            weight = weight.transpose(0, 1).contiguous()
            out = F.conv_transpose2d(input_st, weight, padding=0, stride=2, groups=1)
            out = out[:, :, 1:-1, 1:-1]  # Clipping head and tail, which involves zero padding
            _, _, height, width = out.shape
            if self.demodulate:
                demod = F.interpolate(demod, size=(height, width), mode="bilinear", align_corners=True)
                out = out * demod
            out = self.blur(out)
        elif self.downsample:
            input_st = self.blur(input_st)
            out = F.conv2d(input_st, weight, padding=self.padding, stride=2, groups=1)
            if self.demodulate:
                raise NotImplementedError("Unused, not implemented!")
                out = out * demod
        else:
            out = F.conv2d(input_st, weight, padding=self.padding, groups=1)
            if self.demodulate and (self.padding == 0):
                demod = demod[:, :, self.dirty_rm_size[0]:-self.dirty_rm_size[0],
                              self.dirty_rm_size[1]:-self.dirty_rm_size[1]]
                out = out * demod

    out = out.contiguous()  # Don't know where causes discontiguity.
    return out, flops