def CowMix(X, target, model):
    """
    Mixes images within a batch using an irregularly shaped mask.
    """
    B, _, H, W = X.shape
    # Proportion of pixels that are replaced
    p = torch.rand(B, 1, 1, 1, device=X.device)
    # Size of the mask blobs
    r = torch.randint(4, 16, (1,), device=X.device)
    mask = torch.randn(B, 1, r, r, device=X.device)
    mask = F.interpolate(mask, size=(H, W), mode="bilinear", align_corners=False)
    mean = mask.mean(dim=(1, 2, 3), keepdim=True)
    std = mask.std(dim=(1, 2, 3), keepdim=True)
    tao = mean + 1.4 * std * torch.erfinv(2 * p - 1)
    mask = mask > tao
    idx = torch.randperm(B)
    X = mask * X + torch.logical_not(mask) * X[idx]
    out = model(X)
    p = p.squeeze()
    loss = ((1 - p) * F.cross_entropy(out, target, reduction='none')
            + p * F.cross_entropy(out, target[idx], reduction='none'))
    return loss.mean(), out
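# Hypothetical sanity check (not part of the original snippet): for Gaussian noise,
# thresholding at mean + sqrt(2) * std * erfinv(2p - 1) marks roughly a (1 - p)
# fraction of pixels as True, so about a fraction p of pixels gets replaced by X[idx].
# The 1.4 factor above is an approximation of sqrt(2) ~ 1.4142.
import math
import torch

p = 0.3
noise = torch.randn(1, 1, 256, 256)
tau = noise.mean() + math.sqrt(2) * noise.std() * torch.erfinv(torch.tensor(2 * p - 1))
keep = noise > tau
print(keep.float().mean())  # ~ 0.7, i.e. 1 - p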
def draw(
    self, n: int = 1, out: Optional[Tensor] = None, dtype: torch.dtype = torch.float
) -> Optional[Tensor]:
    r"""Draw `n` qMC samples from the standard Normal.

    Args:
        n: The number of samples to draw.
        out: An optional output tensor. If provided, draws are put into this
            tensor, and the function returns None.
        dtype: The desired torch data type (ignored if `out` is provided).

    Returns:
        An `n x d` tensor of samples if `out=None` and `None` otherwise.
    """
    # get base samples
    samples = self._sobol_engine.draw(n, dtype=dtype)
    if self._inv_transform:
        # apply inverse transform (values too close to 0/1 result in inf values)
        v = 0.5 + (1 - 1e-10) * (samples - 0.5)
        samples_tf = torch.erfinv(2 * v - 1) * math.sqrt(2)
    else:
        # apply Box-Muller transform (note: [1] indexes starting from 1)
        even = torch.arange(0, samples.shape[-1], 2)
        Rs = (-2 * torch.log(samples[:, even])).sqrt()
        thetas = 2 * math.pi * samples[:, 1 + even]
        cos = torch.cos(thetas)
        sin = torch.sin(thetas)
        samples_tf = torch.stack([Rs * cos, Rs * sin], -1).reshape(n, -1)
    # make sure we only return the number of dimensions requested
    samples_tf = samples_tf[:, : self._d]
    if out is None:
        return samples_tf
    else:
        out.copy_(samples_tf)
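# Hypothetical standalone sketch of the inverse-transform branch above, assuming
# torch.quasirandom.SobolEngine as the low-discrepancy source (the class that owns
# this `draw` method is not shown here).
import math
import torch

engine = torch.quasirandom.SobolEngine(dimension=2, scramble=True, seed=0)
u = engine.draw(1024)                       # qMC uniforms in [0, 1)^2
u = 0.5 + (1 - 1e-10) * (u - 0.5)           # pull away from 0/1 to avoid inf
z = torch.erfinv(2 * u - 1) * math.sqrt(2)  # standard-normal qMC samples
print(z.mean(dim=0), z.std(dim=0))          # ~ 0 and ~ 1 per dimension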
def sample_truncated_normal(shape, mu, sigma, a, b):
    '''
    PyTorch implementation of a truncated normal distribution sampler.

    Parameters:
    ----------
    @param numpy array or list - shape : size should be (popsize x sol_dim)
    @param numpy array or list - mu, sigma : size should be (sol_dim)
    @param tensor - a, b : lower and upper bounds of the sampling range, size should be (sol_dim)

    Return:
    ----------
    @param tensor - x : size should be (popsize x sol_dim)
    '''
    uniform = torch.rand(shape)
    normal = torch.distributions.normal.Normal(0, 1)
    alpha = (a - mu) / sigma
    beta = (b - mu) / sigma
    alpha_normal_cdf = normal.cdf(alpha)
    p = alpha_normal_cdf + (normal.cdf(beta) - alpha_normal_cdf) * uniform
    p = p.numpy()
    one = np.array(1, dtype=p.dtype)
    epsilon = np.array(np.finfo(p.dtype).eps, dtype=p.dtype)
    v = np.clip(2 * p - 1, -one + epsilon, one - epsilon)
    x = mu + sigma * np.sqrt(2) * torch.erfinv(torch.from_numpy(v))
    x = torch.clamp(x, a[0], b[0])
    return x
def SWD_prepare(Npercentile=100, device=torch.device("cuda:0"), gaussian=True):
    start = 50 / Npercentile
    end = 100 - start
    q = torch.linspace(start, end, Npercentile, device=device)
    if gaussian:
        pg = 2**0.5 * torch.erfinv(2 * q / 100 - 1)
        return q, pg
    else:
        return q
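# Hypothetical check (not from the original source): the `pg` values returned above
# are the standard-normal quantiles at the percentiles `q`, so they should match
# torch.distributions.Normal(0, 1).icdf(q / 100). Assumes SWD_prepare from the
# snippet above is in scope; run on CPU here for portability.
import torch

q, pg = SWD_prepare(Npercentile=5, device=torch.device("cpu"))
ref = torch.distributions.Normal(0.0, 1.0).icdf(q / 100)
print(torch.allclose(pg, ref, atol=1e-6))  # True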
def manifold_sample(self, n_samples):
    n = int(math.sqrt(n_samples))
    xy = torch.zeros(n_samples, 2)
    xy[:, 0] = torch.arange(0.01, n, 1 / n) % 1
    xy[:, 1] = (torch.arange(0.01, n_samples, 1) / n).float() / n
    z = torch.erfinv(2 * xy - 1) * math.sqrt(2)
    with torch.no_grad():
        mean = self.decoder(z)
    return mean
def erfinv(t):
    """
    Element-wise inverse error function computed using cross-approximation; see PyTorch's `erfinv()`.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """
    return tn.cross(lambda x: torch.erfinv(x), tensors=t, verbose=False)
def sample(self):
    alpha = (self.a - self.mu) / self.sigma
    beta = (self.b - self.mu) / self.sigma
    alpha_normal_cdf = self.normal.cdf(alpha)
    p = alpha_normal_cdf + (self.normal.cdf(beta) - alpha_normal_cdf) * torch.cuda.FloatTensor(self.mu.shape).uniform_()
    v = (2 * p - 1).clamp(-1 + self.eps, 1 - self.eps)
    x = self.mu + self.sigma * 1.4142135623730951 * torch.erfinv(v)
    return x.clamp(self.a, self.b)
def inv_gaussian_mixture_cdf(z, m, v):
    """
    UNDER CONSTRUCTION: HAVEN'T FOUND SOLUTION TO INVERT MIXTURE OF GAUSSIAN CDFS
    :param z: [batch, k_mixture]
    :param m: [batch, k_mixture]
    :param v: [batch, k_mixture]
    :return: inverse gaussian CDF of z [batch]
    """
    probs = torch.sqrt(2 * v) * torch.erfinv(2 * z - 1) + m
    return probs.sum(1)
def manifold_plot_n_save(model):
    with torch.no_grad():
        num_row = 20
        grid = torch.linspace(0, 1, num_row)
        samples = [
            torch.erfinv(2 * torch.tensor([x, y]) - 1) * np.sqrt(2)
            for x in grid for y in grid
        ]
        samples = torch.stack(samples).to(model.device)
        manifold = model.decoder(samples).view(-1, 1, 28, 28)
        image = make_grid(manifold, nrow=num_row)
        plt.imsave("manifold.png", image.cpu().numpy().transpose(1, 2, 0))
def plot_manifold(model):
    with torch.no_grad():
        row_count = 20
        grid = torch.linspace(0, 1, row_count)
        samples = [
            torch.erfinv(2 * torch.tensor([x, y]) - 1) * math.sqrt(2)
            for x in grid for y in grid
        ]
        samples = torch.stack(samples).to(model.device)
        manifold = model.decoder(samples).view(-1, 1, 28, 28)
        img = make_grid(manifold, nrow=row_count)
        plt.imsave("manifold.png", img.cpu().numpy().transpose(1, 2, 0))
def main():
    data = bmnist()[:2]  # ignore test split
    model = VAE(z_dim=ARGS.zdim)
    optimizer = torch.optim.Adam(model.parameters())

    train_curve, val_curve = [], []
    for epoch in range(ARGS.epochs):
        elbos = run_epoch(model, data, optimizer)
        train_elbo, val_elbo = elbos
        train_curve.append(train_elbo)
        val_curve.append(val_elbo)
        print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M")}] Epoch {epoch:02d}, '
              f'train elbo: {train_elbo:07f}, val_elbo: {val_elbo:07f}')

        # --------------------------------------------------------------------
        #  Add functionality to plot samples from model during training.
        #  You can use the make_grid functionality that is already imported.
        # --------------------------------------------------------------------

        # plot the samples
        sampled_ims, im_means = model.sample(16)
        save_samples_plot(sampled_ims, im_means, model, epoch, ARGS.zdim)

    # save the results
    results = {
        'train elbo': train_curve,
        'val elbo': val_curve
    }
    results_df = df.from_dict(results)
    results_path = RESULTS_DIR / f'{model.__class__.__name__}_results.csv'
    results_df.to_csv(results_path, sep=';', encoding='utf-8', index=False)

    # save the model
    model_path = RESULTS_DIR / f'{model.__class__.__name__}_model.pt'
    torch.save(model.state_dict(), model_path)

    # --------------------------------------------------------------------
    #  Add functionality to plot the learned data manifold after training
    #  if required (i.e., if zdim == 2). You can use the make_grid
    #  functionality that is already imported.
    # --------------------------------------------------------------------
    if ARGS.zdim == 2:
        # create the data manifold; loop variables are renamed so they do not
        # shadow the grid arrays x and y inside the comprehension
        n_rows = 20
        x = y = np.linspace(0, 1, n_rows)
        samples = [
            torch.erfinv(2 * torch.tensor([xi, yi], device='cpu') - 1) * np.sqrt(2)
            for xi in x for yi in y
        ]
        samples = torch.stack(samples)
        manifold = model.decoder(samples).view(-1, 1, 28, 28)
        grid = make_grid(manifold, nrow=n_rows)
        plt.imsave(RESULTS_DIR / 'VAE_manifold.png', grid.cpu().numpy().transpose(1, 2, 0))

    plot_path = RESULTS_DIR / f'{model.__class__.__name__}_elbo.png'
    save_elbo_plot(train_curve, val_curve, plot_path)
def icdf(self, value):
    '''
    TODO: Fixed inf bug by a dirty hack!
    We replace the Inf produced by torch.erfinv with torch.erfinv(torch.tensor(0.99999997)).
    We should implement our own icdf to fix the bug properly.
    '''
    if self._validate_args:
        self._validate_sample(value)
    res = self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
    res = torch.clamp(res, -5.3, 5.3)
    # if torch.any(res == float('inf')) or torch.any(res == float('-inf')): import ipdb; ipdb.set_trace()
    return res
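# Hypothetical illustration of the issue the hack above works around (not from the
# original source): erfinv is +/-inf at exactly +/-1, so an icdf evaluated at
# probabilities of exactly 0 or 1 blows up unless the result (or input) is clamped.
import math
import torch

value = torch.tensor([0.0, 0.5, 1.0])
raw = torch.erfinv(2 * value - 1) * math.sqrt(2)
print(raw)                          # tensor([-inf, 0., inf])
print(torch.clamp(raw, -5.3, 5.3))  # finite, as in the clamp above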
def Gaussian_ppf(Nsample, weight=None, device=torch.device("cuda:0")):
    if weight is None:
        start = 50 / Nsample
        end = 100 - start
        q = torch.linspace(start, end, Nsample, device=device, dtype=torch.float64)
    else:
        q = torch.cumsum(weight, dim=1, dtype=torch.float64)
        q = q - 0.5 * weight
    pg = 2**0.5 * torch.erfinv(2 * q / 100 - 1).to(torch.get_default_dtype())
    return pg
def _cow_mask(input_size, lam, config_mix):
    """
    Compute masks for CowMix
    lam is overridden by CowMask's parameters
    https://github.com/google-research/google-research/tree/master/milking_cowmask/masking
    """
    # Default CowMix config
    misc.ifnotfound_update(
        config_mix,
        {
            "cow_p_max": 0.8,
            "cow_p_min": 0.2,
            "cow_sigma_max": 16.0,
            "cow_sigma_min": 4.0,
        }
    )

    # Get lam ratio for the mask
    p_max = config_mix["cow_p_max"]
    p_min = config_mix["cow_p_min"]
    proba = torch.tensor(p_min + np.random.rand(1) * (p_max - p_min))

    # Get sigma for Gaussian kernel
    sigma_max = config_mix["cow_sigma_max"]
    sigma_min = config_mix["cow_sigma_min"]
    sigma = np.exp(math.log(sigma_min) + np.random.rand(1) * (math.log(sigma_max) - math.log(sigma_min)))

    # Compute Gaussian kernel
    gaussian_kernel = _gaussian_blur_kernel(sigma, sigma_max)
    gaussian_kernel = gaussian_kernel.unsqueeze(-1)
    gaussian_kernel = gaussian_kernel.T * gaussian_kernel
    # Shape it as a proper kernel
    gaussian_kernel = gaussian_kernel.unsqueeze(0).unsqueeze(0)
    gaussian_kernel = gaussian_kernel.repeat((input_size[0], 1, 1, 1)).float()

    noise = torch.randn(1, 1, input_size[1], input_size[2])
    blurred_noise = F.conv2d(noise, gaussian_kernel, padding=gaussian_kernel.size()[-1] // 2)
    noise_mean = blurred_noise.mean()
    noise_std = blurred_noise.std()

    # Get thresholded cowmask
    threshold_stat = noise_mean + math.sqrt(2) * torch.erfinv(2 * proba - 1) * noise_std
    mask = blurred_noise <= threshold_stat
    return mask.squeeze(0).float()
def main():
    data = bmnist()[:2]  # ignore test split
    model = VAE(z_dim=ARGS.zdim).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters())
    os.makedirs('images_vae', exist_ok=True)

    train_curve, val_curve = [], []
    for epoch in range(ARGS.epochs):
        elbos = run_epoch(model, data, optimizer)
        train_elbo, val_elbo = elbos
        train_curve.append(train_elbo)
        val_curve.append(val_elbo)
        print(f"[Epoch {epoch}] train elbo: {train_elbo} val_elbo: {val_elbo}")

        # --------------------------------------------------------------------
        #  Add functionality to plot samples from model during training.
        #  You can use the make_grid functionality that is already imported.
        # --------------------------------------------------------------------
        ims_per_row = 5
        sampled_ims, _ = model.sample(ims_per_row * ims_per_row)
        grid = make_grid(sampled_ims, nrow=ims_per_row)
        save_image(grid, 'images_vae/epoch{}_{}z.png'.format(epoch, ARGS.zdim), normalize=True)

    torch.save(model.state_dict(), "models/VAE_{}epochs_{}z.pt".format(ARGS.epochs, ARGS.zdim))

    # --------------------------------------------------------------------
    #  Add functionality to plot the learned data manifold after training
    #  if required (i.e., if zdim == 2). You can use the make_grid
    #  functionality that is already imported.
    # --------------------------------------------------------------------
    if ARGS.zdim == 2:
        with torch.no_grad():
            steps = 20
            density_points = torch.linspace(0, 1, steps)
            # Basically use an adaptation of torch.distributions' icdf here for the manifold z's
            z_tensors = [
                torch.erfinv(2 * torch.tensor([x, y]) - 1) * np.sqrt(2)
                for x in density_points for y in density_points
            ]
            z = torch.stack(z_tensors).to(DEVICE)
            _, manifold = model.sample(1, z)
            image = make_grid(manifold, nrow=steps)
            save_image(image, "images_vae/manifold.pdf")

    save_elbo_plot(train_curve, val_curve, 'elbo.pdf')
def parameterized_truncated_normal(uniform, mu, sigma, a, b):
    normal = torch.distributions.normal.Normal(0, 1)
    alpha = (a - mu) / sigma
    beta = (b - mu) / sigma
    alpha_normal_cdf = normal.cdf(alpha)
    p = alpha_normal_cdf + (normal.cdf(beta) - alpha_normal_cdf) * uniform
    p = p.numpy()
    one = np.array(1, dtype=p.dtype)
    epsilon = np.array(np.finfo(p.dtype).eps, dtype=p.dtype)
    v = np.clip(2 * p - 1, -one + epsilon, one - epsilon)
    x = mu + sigma * np.sqrt(2) * torch.erfinv(torch.from_numpy(v))
    x = torch.clamp(x, a, b)
    return x
def compute_all_i_cdfs(this_means, this_stds, sorted_weights, directions):
    transformed_means, transformed_stds = transform_gaussian_by_dirs(
        this_means, th.abs(this_stds), directions)
    n_virtual_samples = th.sum(sorted_weights[:, 0])
    start = 1 / (2 * n_virtual_samples)
    wanted_sum = 1 - (2 / (n_virtual_samples))
    probs = sorted_weights * wanted_sum / n_virtual_samples
    empirical_cdf = start + th.cumsum(probs, dim=0)
    # see https://en.wikipedia.org/wiki/Normal_distribution -> Quantile function
    sqrt_2 = th.autograd.Variable(th.FloatTensor([np.sqrt(2.0)]))
    sqrt_2, empirical_cdf = ensure_on_same_device(sqrt_2, empirical_cdf)
    i_cdf = sqrt_2 * th.erfinv(2 * empirical_cdf - 1)
    i_cdf = i_cdf.squeeze()
    all_i_cdfs = i_cdf * transformed_stds.t() + transformed_means.t()
    return all_i_cdfs
def standard_normal_icdf(p):
    """
    Compute the componentwise inverse cdf of a Gaussian with 0 mean and identity covariance matrix.
    The clamp is there to avoid numerical instability.

    Parameters
    ----------
    p: float(s)
        Probability value in [0,1]

    Returns
    -------
    float(s): quantile of p
    """
    p = torch.clamp(p, min=0 + tollerance, max=1 - tollerance)
    return torch.erfinv(2 * p - 1) * math.sqrt(2)
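# Hypothetical cross-check (not from the original source): sqrt(2) * erfinv(2p - 1)
# is exactly the standard-normal quantile function, so it should agree with
# torch.distributions.Normal(0, 1).icdf away from the clamped endpoints.
import math
import torch

p = torch.tensor([0.025, 0.5, 0.975])
via_erfinv = torch.erfinv(2 * p - 1) * math.sqrt(2)
via_icdf = torch.distributions.Normal(0.0, 1.0).icdf(p)
print(torch.allclose(via_erfinv, via_icdf))  # True (~ [-1.96, 0, 1.96])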
def sample_manifold(self, n_samples, start, stop):
    xy = torch.linspace(start, stop, n_samples)
    # we can't use torch.distributions. Below is a workaround to get a similar function
    # https://pytorch.org/docs/stable/torch.html#torch.erf
    # as the inverse error function has a different normalizing parameter than the
    # normal distribution, we rescale it a bit
    grid = [
        torch.erfinv(2 * torch.tensor([x, y], device='cpu') - 1) * math.sqrt(2)
        for x in xy for y in xy
    ]
    zs = torch.stack(grid)
    # create all the image means
    constructions = self.decoder.forward(zs)
    im_means = constructions.reshape(
        (constructions.size()[0], 1,
         int(np.sqrt(constructions.size()[1])),
         int(np.sqrt(constructions.size()[1]))))
    return im_means
def erfinv(self, tensor_in):
    """
    The inverse of the error function.

    Example:
        >>> import pyhf
        >>> pyhf.set_backend("pytorch")
        >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
        >>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
        tensor([-2.0000, -1.0000,  0.0000,  1.0000,  2.0000])

    Args:
        tensor_in (:obj:`tensor`): The input tensor object

    Returns:
        PyTorch Tensor: The values of the inverse of the error function at the given points.
    """
    return torch.erfinv(tensor_in)
def parameterized_truncated_normal(uniform, mu, sigma, a, b):
    r"""Implements sampling from a truncated normal distribution using the
    inverse cumulative distribution function (CDF) method.

    .. _Truncated Normal\: normal distribution in which the range of definition is
        made finite at one or both ends of the interval.
        https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """
    normal = torch.distributions.normal.Normal(0, 1)
    alpha, beta = (a - mu) / sigma, (b - mu) / sigma
    p = normal.cdf(alpha) + (normal.cdf(beta) - normal.cdf(alpha)) * uniform
    p = p.numpy()
    one = np.array(1, dtype=p.dtype)
    epsilon = np.array(np.finfo(p.dtype).eps, dtype=p.dtype)
    v = np.clip(2 * p - 1, -one + epsilon, one - epsilon)
    x = mu + sigma * np.sqrt(2) * torch.erfinv(torch.from_numpy(v))
    x = torch.clamp(x, a, b)
    return x.type(torch.get_default_dtype())
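# Hypothetical usage sketch for the sampler above (not from the original source).
# It assumes the `parameterized_truncated_normal` defined above is in scope, that all
# arguments are CPU float tensors, and a PyTorch version that accepts tensor bounds
# in torch.clamp.
import torch

mu = torch.zeros(10000)
sigma = torch.ones(10000)
a = torch.full((10000,), -1.0)   # lower bound
b = torch.full((10000,), 2.0)    # upper bound
u = torch.rand(10000)            # uniform base samples

x = parameterized_truncated_normal(u, mu, sigma, a, b)
print(x.min().item() >= -1.0, x.max().item() <= 2.0)  # True True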
def icdf(self, value):
    if self._validate_args:
        self._validate_sample(value)
    return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
def normal_icdf(value, loc, scale):
    return loc + scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
def forward(self, images, debug_percentile=None):
    assert isinstance(images, torch.Tensor) and images.ndim == 4
    batch_size, num_channels, height, width = images.shape
    device = images.device
    if debug_percentile is not None:
        debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

    # -------------------------------------
    # Select parameters for pixel blitting.
    # -------------------------------------

    # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
    I_3 = torch.eye(3, device=device)
    G_inv = I_3

    # Apply x-flip with probability (xflip * strength).
    if self.xflip > 0:
        i = torch.floor(torch.rand([batch_size], device=device) * 2)
        i = torch.where(
            torch.rand([batch_size], device=device) < self.xflip * self.p,
            i, torch.zeros_like(i))
        if debug_percentile is not None:
            i = torch.full_like(i, torch.floor(debug_percentile * 2))
        G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

    # Apply 90 degree rotations with probability (rotate90 * strength).
    if self.rotate90 > 0:
        i = torch.floor(torch.rand([batch_size], device=device) * 4)
        i = torch.where(
            torch.rand([batch_size], device=device) < self.rotate90 * self.p,
            i, torch.zeros_like(i))
        if debug_percentile is not None:
            i = torch.full_like(i, torch.floor(debug_percentile * 4))
        G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

    # Apply integer translation with probability (xint * strength).
    if self.xint > 0:
        t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
        t = torch.where(
            torch.rand([batch_size, 1], device=device) < self.xint * self.p,
            t, torch.zeros_like(t))
        if debug_percentile is not None:
            t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
        G_inv = G_inv @ translate2d_inv(torch.round(t[:, 0] * width), torch.round(t[:, 1] * height))

    # --------------------------------------------------------
    # Select parameters for general geometric transformations.
    # --------------------------------------------------------

    # Apply isotropic scaling with probability (scale * strength).
    if self.scale > 0:
        s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
        s = torch.where(
            torch.rand([batch_size], device=device) < self.scale * self.p,
            s, torch.ones_like(s))
        if debug_percentile is not None:
            s = torch.full_like(
                s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
        G_inv = G_inv @ scale2d_inv(s, s)

    # Apply pre-rotation with probability p_rot.
    p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1))  # P(pre OR post) = p
    if self.rotate > 0:
        theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
        theta = torch.where(
            torch.rand([batch_size], device=device) < p_rot,
            theta, torch.zeros_like(theta))
        if debug_percentile is not None:
            theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
        G_inv = G_inv @ rotate2d_inv(-theta)  # Before anisotropic scaling.

    # Apply anisotropic scaling with probability (aniso * strength).
    if self.aniso > 0:
        s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
        s = torch.where(
            torch.rand([batch_size], device=device) < self.aniso * self.p,
            s, torch.ones_like(s))
        if debug_percentile is not None:
            s = torch.full_like(
                s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
        G_inv = G_inv @ scale2d_inv(s, 1 / s)

    # Apply post-rotation with probability p_rot.
    if self.rotate > 0:
        theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
        theta = torch.where(
            torch.rand([batch_size], device=device) < p_rot,
            theta, torch.zeros_like(theta))
        if debug_percentile is not None:
            theta = torch.zeros_like(theta)
        G_inv = G_inv @ rotate2d_inv(-theta)  # After anisotropic scaling.

    # Apply fractional translation with probability (xfrac * strength).
    if self.xfrac > 0:
        t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
        t = torch.where(
            torch.rand([batch_size, 1], device=device) < self.xfrac * self.p,
            t, torch.zeros_like(t))
        if debug_percentile is not None:
            t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
        G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height)

    # ----------------------------------
    # Execute geometric transformations.
    # ----------------------------------

    # Execute if the transform is not identity.
    if G_inv is not I_3:
        # Calculate padding.
        cx = (width - 1) / 2
        cy = (height - 1) / 2
        cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device)  # [idx, xyz]
        cp = G_inv @ cp.t()  # [batch, xyz, idx]
        Hz_pad = self.Hz_geom.shape[0] // 4
        margin = cp[:, :2, :].permute(1, 0, 2).flatten(1)  # [xy, batch * idx]
        margin = torch.cat([-margin, margin]).max(dim=1).values  # [x0, y0, x1, y1]
        margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
        margin = margin.max(misc.constant([0, 0] * 2, device=device))
        margin = margin.min(misc.constant([width - 1, height - 1] * 2, device=device))
        mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

        # Pad image and adjust origin.
        images = torch.nn.functional.pad(input=images, pad=[mx0, mx1, my0, my1], mode='reflect')
        G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

        # Upsample.
        images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
        G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
        G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

        # Execute transformation.
        shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
        G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
        grid = torch.nn.functional.affine_grid(theta=G_inv[:, :2, :], size=shape, align_corners=False)
        images = grid_sample_gradfix.grid_sample(images, grid)

        # Downsample and crop.
        images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad * 2, flip_filter=True)

    # --------------------------------------------
    # Select parameters for color transformations.
    # --------------------------------------------

    # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
    I_4 = torch.eye(4, device=device)
    C = I_4

    # Apply brightness with probability (brightness * strength).
    if self.brightness > 0:
        b = torch.randn([batch_size], device=device) * self.brightness_std
        b = torch.where(
            torch.rand([batch_size], device=device) < self.brightness * self.p,
            b, torch.zeros_like(b))
        if debug_percentile is not None:
            b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
        C = translate3d(b, b, b) @ C

    # Apply contrast with probability (contrast * strength).
    if self.contrast > 0:
        c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
        c = torch.where(
            torch.rand([batch_size], device=device) < self.contrast * self.p,
            c, torch.ones_like(c))
        if debug_percentile is not None:
            c = torch.full_like(
                c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
        C = scale3d(c, c, c) @ C

    # Apply luma flip with probability (lumaflip * strength).
    v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)  # Luma axis.
    if self.lumaflip > 0:
        i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
        i = torch.where(
            torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p,
            i, torch.zeros_like(i))
        if debug_percentile is not None:
            i = torch.full_like(i, torch.floor(debug_percentile * 2))
        C = (I_4 - 2 * v.ger(v) * i) @ C  # Householder reflection.

    # Apply hue rotation with probability (hue * strength).
    if self.hue > 0 and num_channels > 1:
        theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
        theta = torch.where(
            torch.rand([batch_size], device=device) < self.hue * self.p,
            theta, torch.zeros_like(theta))
        if debug_percentile is not None:
            theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
        C = rotate3d(v, theta) @ C  # Rotate around v.

    # Apply saturation with probability (saturation * strength).
    if self.saturation > 0 and num_channels > 1:
        s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
        s = torch.where(
            torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p,
            s, torch.ones_like(s))
        if debug_percentile is not None:
            s = torch.full_like(
                s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
        C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

    # ------------------------------
    # Execute color transformations.
    # ------------------------------

    # Execute if the transform is not identity.
    if C is not I_4:
        images = images.reshape([batch_size, num_channels, height * width])
        if num_channels == 4:
            alpha = images[:, 3, :].unsqueeze(dim=1)  # [batch_size, 1, ...]
            rgb = C[:, :3, :3] @ images[:, :3, :] + C[:, :3, 3:]  # [batch_size, 3, ...]
            images = torch.cat([rgb, alpha], dim=1)  # [batch_size, 4, ...]
        elif num_channels == 3:
            images = C[:, :3, :3] @ images + C[:, :3, 3:]
        elif num_channels == 1:
            C = C[:, :3, :].mean(dim=1, keepdims=True)
            images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
        else:
            raise ValueError(
                'Image must be RGBA (4 channels), RGB (3 channels) or L (1 channel)')
        images = images.reshape([batch_size, num_channels, height, width])

    # ----------------------
    # Image-space filtering.
    # ----------------------

    if self.imgfilter > 0:
        num_bands = self.Hz_fbank.shape[0]
        assert len(self.imgfilter_bands) == num_bands
        expected_power = misc.constant(
            np.array([10, 1, 1, 1]) / 13, device=device)  # Expected power spectrum (1/f).

        # Apply amplification for each band with probability (imgfilter * strength * band_strength).
        g = torch.ones([batch_size, num_bands], device=device)  # Global gain vector (identity).
        for i, band_strength in enumerate(self.imgfilter_bands):
            t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
            t_i = torch.where(
                torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength,
                t_i, torch.ones_like(t_i))
            if debug_percentile is not None:
                t_i = torch.full_like(
                    t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)
                ) if band_strength > 0 else torch.ones_like(t_i)
            t = torch.ones([batch_size, num_bands], device=device)  # Temporary gain vector.
            t[:, i] = t_i  # Replace i'th element.
            t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt()  # Normalize power.
            g = g * t  # Accumulate into global gain.

        # Construct combined amplification filter.
        Hz_prime = g @ self.Hz_fbank  # [batch, tap]
        Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1])  # [batch, channels, tap]
        Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1])  # [batch * channels, 1, tap]

        # Apply filter.
        p = self.Hz_fbank.shape[1] // 2
        images = images.reshape([1, batch_size * num_channels, height, width])
        images = torch.nn.functional.pad(input=images, pad=[p, p, p, p], mode='reflect')
        images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size * num_channels)
        images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size * num_channels)
        images = images.reshape([batch_size, num_channels, height, width])

    # ------------------------
    # Image-space corruptions.
    # ------------------------

    # Apply additive RGB noise with probability (noise * strength).
    if self.noise > 0:
        sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
        sigma = torch.where(
            torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p,
            sigma, torch.zeros_like(sigma))
        if debug_percentile is not None:
            sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
        images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma

    # Apply cutout with probability (cutout * strength).
    if self.cutout > 0:
        size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
        size = torch.where(
            torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p,
            size, torch.zeros_like(size))
        center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
        if debug_percentile is not None:
            size = torch.full_like(size, self.cutout_size)
            center = torch.full_like(center, debug_percentile)
        coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
        coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
        mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
        mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
        mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
        images = images * mask

    return images
def _standard_normal_quantile(u):
    # Ref: https://en.wikipedia.org/wiki/Normal_distribution
    return math.sqrt(2) * torch.erfinv(2 * u - 1)
def icdf(self, value):
    return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
def pointwise_ops(self):
    a = torch.randn(4)
    b = torch.randn(4)
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
    f = torch.zeros(3)
    g = torch.tensor([-1, 0, 1])
    w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
    return (
        torch.abs(torch.tensor([-1, -2, 3])),
        torch.absolute(torch.tensor([-1, -2, 3])),
        torch.acos(a),
        torch.arccos(a),
        torch.acosh(a.uniform_(1.0, 2.0)),
        torch.add(a, 20),
        torch.add(a, torch.randn(4, 1), alpha=10),
        torch.addcdiv(torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1),
        torch.addcmul(torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1),
        torch.angle(a),
        torch.asin(a),
        torch.arcsin(a),
        torch.asinh(a),
        torch.arcsinh(a),
        torch.atan(a),
        torch.arctan(a),
        torch.atanh(a.uniform_(-1.0, 1.0)),
        torch.arctanh(a.uniform_(-1.0, 1.0)),
        torch.atan2(a, a),
        torch.bitwise_not(t),
        torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.ceil(a),
        torch.clamp(a, min=-0.5, max=0.5),
        torch.clamp(a, min=0.5),
        torch.clamp(a, max=0.5),
        torch.clip(a, min=-0.5, max=0.5),
        torch.conj(a),
        torch.copysign(a, 1),
        torch.copysign(a, b),
        torch.cos(a),
        torch.cosh(a),
        torch.deg2rad(torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]])),
        torch.div(a, b),
        torch.divide(a, b, rounding_mode="trunc"),
        torch.divide(a, b, rounding_mode="floor"),
        torch.digamma(torch.tensor([1.0, 0.5])),
        torch.erf(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
        torch.exp(torch.tensor([0.0, math.log(2.0)])),
        torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
        torch.expm1(torch.tensor([0.0, math.log(2.0)])),
        torch.fake_quantize_per_channel_affine(
            torch.randn(2, 2, 2),
            (torch.randn(2) + 1) * 0.05,
            torch.zeros(2),
            1,
            0,
            255,
        ),
        torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
        torch.float_power(torch.randint(10, (4, )), 2),
        torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4, -5])),
        torch.floor(a),
        # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
        # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
        torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
        torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.frac(torch.tensor([1.0, 2.5, -3.2])),
        torch.randn(4, dtype=torch.cfloat).imag,
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
        torch.lerp(torch.arange(1.0, 5.0), torch.empty(4).fill_(10), 0.5),
        torch.lerp(
            torch.arange(1.0, 5.0),
            torch.empty(4).fill_(10),
            torch.full_like(torch.arange(1.0, 5.0), 0.5),
        ),
        torch.lgamma(torch.arange(0.5, 2, 0.5)),
        torch.log(torch.arange(5) + 10),
        torch.log10(torch.rand(5)),
        torch.log1p(torch.randn(5)),
        torch.log2(torch.rand(5)),
        torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])),
        torch.logical_and(r, s),
        torch.logical_and(r.double(), s.double()),
        torch.logical_and(r.double(), s),
        torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
        torch.logical_not(torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
        torch.logical_not(
            torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
            out=torch.empty(3, dtype=torch.int16),
        ),
        torch.logical_or(r, s),
        torch.logical_or(r.double(), s.double()),
        torch.logical_or(r.double(), s),
        torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_xor(r, s),
        torch.logical_xor(r.double(), s.double()),
        torch.logical_xor(r.double(), s),
        torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logit(torch.rand(5), eps=1e-6),
        torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
        torch.i0(torch.arange(5, dtype=torch.float32)),
        torch.igamma(a, b),
        torch.igammac(a, b),
        torch.mul(torch.randn(3), 100),
        torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
        torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
        torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]),
        torch.nan_to_num(w),
        torch.nan_to_num(w, nan=2.0),
        torch.nan_to_num(w, nan=2.0, posinf=1.0),
        torch.neg(torch.randn(5)),
        # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
        torch.polygamma(1, torch.tensor([1.0, 0.5])),
        torch.polygamma(2, torch.tensor([1.0, 0.5])),
        torch.polygamma(3, torch.tensor([1.0, 0.5])),
        torch.polygamma(4, torch.tensor([1.0, 0.5])),
        torch.pow(a, 2),
        torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
        torch.rad2deg(torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])),
        torch.randn(4, dtype=torch.cfloat).real,
        torch.reciprocal(a),
        torch.remainder(torch.tensor([-3.0, -2.0]), 2),
        torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.round(a),
        torch.rsqrt(a),
        torch.sigmoid(a),
        torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sgn(a),
        torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sin(a),
        torch.sinc(a),
        torch.sinh(a),
        torch.sqrt(a),
        torch.square(a),
        torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
        torch.tan(a),
        torch.tanh(a),
        torch.trunc(a),
        torch.xlogy(f, g),
        torch.xlogy(f, g),
        torch.xlogy(f, 4),
        torch.xlogy(2, g),
    )
def normal_icdf(value, loc, scale):
    return (loc + scale * torch.erfinv(2 * value - 1) * math.sqrt(2)).to(loc.device)
def phi_inv(x):
    return r2 * torch.erfinv((2 * x - 1).clamp(-SAFE_BOUND, SAFE_BOUND))
def estimate_knots_gaussian(data, interp_nbin, above_noise, edge_bins=0, derivclip=None,
                            extrapolate='regression', alpha=(0.9, 0.99), KDE=True,
                            bw_factor=1, batchsize=None):
    start = 100 / (interp_nbin - 2 * edge_bins + 1)
    end = 100 - start
    q1 = torch.linspace(start, end, interp_nbin - 2 * edge_bins, device=data.device)

    if edge_bins > 0:
        start = start / (edge_bins + 1)
        end = q1[0] - start
        q0 = torch.linspace(start, end, edge_bins, device=data.device)
        end = 100 - start
        start = q1[-1] + start
        q2 = torch.linspace(start, end, edge_bins, device=data.device)
        q = torch.cat((q0, q1, q2), dim=0)
    else:
        q = q1

    x = Percentile(data.T, q).to(torch.get_default_dtype())
    y = x.clone()
    deriv = torch.ones_like(x)

    for i in range(data.shape[1]):
        if above_noise[i]:
            if KDE:
                rho = kde(data[:, i], bw_factor=bw_factor, batchsize=batchsize)
                scale = (rho.covariance[0, 0] + 1)**0.5
                y[i] = 2**0.5 * scale * torch.erfinv(2 * rho.cdf(x[i]) - 1)
                dy = y[i, 1:] - y[i, :-1]
                dx = x[i, 1:] - x[i, :-1]
                while (dy <= 0).any() or (dx <= 0).any():
                    select = torch.zeros(len(y[i]), dtype=bool, device=y.device)
                    select[1:] = dy <= 0
                    select[1:] += dx <= 0
                    x[i, select] = torch.rand(torch.sum(select).item(), device=x.device) * (x[i, -1] - x[i, 0]) + x[i, 0]
                    x[i] = torch.sort(x[i])[0]
                    y[i] = 2**0.5 * scale * torch.erfinv(2 * rho.cdf(x[i]) - 1)
                    dy = y[i, 1:] - y[i, :-1]
                    dx = x[i, 1:] - x[i, :-1]
            else:
                y[i] = 2**0.5 * torch.erfinv(2 * q / 100. - 1)
                dy = y[i, 1:] - y[i, :-1]
                dx = x[i, 1:] - x[i, :-1]
                q0 = q.clone()
                while (dy <= 0).any() or (dx <= 0).any():
                    select = torch.zeros(len(y[i]), dtype=bool, device=y.device)
                    select[1:] = dy <= 0
                    select[1:] += dx <= 0
                    q0[select] = torch.rand(torch.sum(select).item(), device=q.device) * 100
                    q0 = torch.sort(q0)[0]
                    x[i] = Percentile(data[:, i], q0).to(torch.get_default_dtype())
                    y[i] = 2**0.5 * torch.erfinv(2 * q0 / 100. - 1)
                    dy = y[i, 1:] - y[i, :-1]
                    dx = x[i, 1:] - x[i, :-1]

            h = dx
            s = dy / dx
            deriv[i, 1:-1] = (s[:-1] * h[1:] + s[1:] * h[:-1]) / (h[1:] + h[:-1])

            if derivclip == 1:
                deriv[i, 0] = 1
                deriv[i, -1] = 1
            else:
                if extrapolate == 'endpoint':
                    endx = torch.min(data[:, i])
                    deriv[i, 0] = (2**0.5 * torch.erfinv(2 * torch.tensor(1 / len(data), device=data.device) - 1) - y[i, 0]) / (endx - x[i, 0])
                    endx = torch.max(data[:, i])
                    deriv[i, -1] = (2**0.5 * torch.erfinv(2 * torch.tensor(1 - 1 / len(data), device=data.device) - 1) - y[i, -1]) / (endx - x[i, -1])
                elif extrapolate == 'regression':
                    endx = torch.sort(data[data[:, i] < x[i, 0], i])[0]
                    endy = 2**0.5 * torch.erfinv(2 * torch.linspace(0.5, len(endx) - 0.5, len(endx), device=data.device) / len(data) - 1) - y[i, 0]
                    endx -= x[i, 0]
                    deriv[i, 0] = torch.sum(endx * endy) / torch.sum(endx * endx)
                    endx = torch.sort(data[data[:, i] > x[i, -1], i], descending=True)[0]
                    endy = 2**0.5 * torch.erfinv(2 * (1 - torch.linspace(0.5, len(endx) - 0.5, len(endx), device=data.device) / len(data)) - 1) - y[i, -1]
                    endx -= x[i, -1]
                    deriv[i, -1] = torch.sum(endx * endy) / torch.sum(endx * endx)

            y[i] = (1 - alpha[0]) * y[i] + alpha[0] * x[i]
            deriv[i, 1:-1] = (1 - alpha[0]) * deriv[i, 1:-1] + alpha[0]
            deriv[i, 0] = (1 - alpha[1]) * deriv[i, 0] + alpha[1]
            deriv[i, -1] = (1 - alpha[1]) * deriv[i, -1] + alpha[1]

            if derivclip is not None and derivclip > 1:
                deriv[i, 0] = torch.clamp(deriv[i, 0], 1 / derivclip, derivclip)
                deriv[i, -1] = torch.clamp(deriv[i, -1], 1 / derivclip, derivclip)
        else:
            dx = x[i, 1:] - x[i, :-1]
            while (dx <= 0).any():
                select = torch.zeros(len(x[i]), dtype=bool, device=x.device)
                select[1:] = dx <= 0
                x[i, select] = torch.rand(torch.sum(select).item(), device=x.device) * (x[i, -1] - x[i, 0]) + x[i, 0]
                x[i] = torch.sort(x[i])[0]
                y[i] = x[i]
                dx = x[i, 1:] - x[i, :-1]

    return x, y, deriv
def icdf(self, value):
    self._validate_log_prob_arg(value)
    return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)