def _d_kl_beta(p, q):
    alpha_p, beta_p = convert_parameters_beta(p)
    alpha_q, beta_q = convert_parameters_beta(q)
    dist_p = dist.Beta(alpha_p, beta_p)
    dist_q = dist.Beta(alpha_q, beta_q)
    d_kl = dist.kl_divergence(dist_p, dist_q).mean(-1)
    return d_kl
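# Added usage sketch (hedged, not from the original source): torch.distributions
# registers a closed-form Beta/Beta KL divergence, so _d_kl_beta reduces to
# kl_divergence over the converted parameters. The shapes below are only an
# illustration.
def _example_beta_kl():
    import torch
    import torch.distributions as dist

    p = dist.Beta(torch.full((3,), 2.0), torch.full((3,), 5.0))
    q = dist.Beta(torch.ones(3), torch.ones(3))
    return dist.kl_divergence(p, q).mean(-1)  # scalar: KL averaged over the last dim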
def forward(self, x):
    """
    This function takes a batch of data `x` and returns:
    - nll: -\E_q[log p(x | z, A)]
    - q_z: a torch Bernoulli Distribution for q(z)
    - p_z: a torch Bernoulli Distribution for p(z | nu) *where nu ~ q(nu)*
      [because it's for the KL divergence]
    - q_nu: a torch Beta Distribution for q(nu)
    - p_nu: a torch Beta Distribution for p(nu)
    - q_a: a torch Normal Distribution (univariate / diagonal) for q(A)
    - p_a: a torch Normal Distribution for p(A)

    The negative ELBO can be computed as:
        -ELBO = nll + KL(q_z || p_z) + KL(q_nu || p_nu)
    """
    batch_sz = x.size()[0]

    # p(nu)
    sz = self.beta_a.size()
    p_nu = distributions.Beta(torch.ones(sz) * self.alpha0, torch.ones(sz))

    # compute q(nu) parameters, and take samples
    beta_a = F.softplus(self.beta_a) + 0.01
    beta_b = F.softplus(self.beta_b) + 0.01
    q_nu = distributions.Beta(beta_a, beta_b)
    nu = q_nu.rsample()  # NOTE: differentiable sample! via Knowles et al.

    # p(z | nu)
    logpi = torch.cumsum((nu + SMALL).log(), dim=-1).unsqueeze(0).repeat(batch_sz, 1)
    p_z = distributions.Bernoulli(probs=logpi.exp())

    # q(z)
    # machine/fp precision is higher near 0 than at 1 (crucial)
    probs = F.sigmoid(torch.clamp(self.encoder(x.view(-1, self.D)), -25, 9))
    q_z = shared.STRelaxedBernoulli(temperature=0.1, probs=probs)
    # q_z = distributions.RelaxedBernoulli(temperature=0.2, probs=probs)
    z = q_z.rsample()
    q_z = distributions.Bernoulli(probs=probs)
    # self.z_log_prob = q_z.log_prob(z)  # save for later

    # p(A)
    p_a = distributions.Normal(loc=0, scale=1)  # NOTE: this is broadcast up

    # q(A) - this is wrong, it normalizes the wrong thing
    q_a = distributions.Normal(loc=self.A_mean, scale=(self.A_logvar / 2).exp())
    A = self.A_mean
    # A = q_a.rsample()

    # now compute NLL:
    x_mean = torch.mm(z, A)
    nll = -(distributions.Normal(loc=x_mean, scale=self.sigma_n).log_prob(x))

    return nll, p_nu, q_nu, p_z, q_z, p_a, q_a
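# Added sketch (hedged): assembling the negative ELBO from the tuple returned by
# forward() above, following the formula stated in its docstring. Both the
# Bernoulli/Bernoulli and Beta/Beta KLs are registered in
# torch.distributions.kl_divergence. The reduction over latent dimensions is an
# assumption; the original reduction is not shown.
def neg_elbo_from_forward_outputs(nll, p_nu, q_nu, p_z, q_z, p_a, q_a):
    from torch import distributions  # same alias as used above

    kl_z = distributions.kl_divergence(q_z, p_z).sum(-1)
    kl_nu = distributions.kl_divergence(q_nu, p_nu).sum(-1)
    return nll.sum(-1) + kl_z + kl_nu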
def kl(self, dist_a, prior=None):
    if prior is None:
        # use standard reparameterizer
        return self._kld_beta_kerman_prior(
            dist_a['beta']['conc1'], dist_a['beta']['conc2']
        )

    # we have two distributions provided (eg: VRNN)
    return torch.sum(D.kl_divergence(
        D.Beta(dist_a['beta']['conc1'], dist_a['beta']['conc2']),
        D.Beta(prior['beta']['conc1'], prior['beta']['conc2'])
    ), -1)
def mutual_info(self, params, eps=1e-9):
    """ I(z_d; x) ~ H(z_prior, z_d) + H(z_prior)

    :param params: parameters of distribution
    :param eps: tolerance
    :returns: batch_size mutual information (prop-to) tensor.
    :rtype: torch.Tensor
    """
    z_true = D.Beta(params['beta']['conc1'],
                    params['beta']['conc2'])
    z_match = D.Beta(params['q_z_given_xhat']['beta']['conc1'],
                     params['q_z_given_xhat']['beta']['conc2'])
    kl_proxy_to_xent = torch.sum(D.kl_divergence(z_match, z_true), dim=-1)
    return self.config['continuous_mut_info'] * kl_proxy_to_xent
def _reparametrize_beta(self, conc1, conc2, force=False):
    """ Internal function to reparameterize beta distribution using concentrations.

    :param conc1: concentration 1
    :param conc2: concentration 2
    :returns: reparameterized sample, distribution params
    :rtype: torch.Tensor, dict
    """
    if self.training or force:
        beta = D.Beta(conc1, conc2).rsample()
        return beta, {'conc1': conc1, 'conc2': conc2}

    # can't use mean like in gaussian because beta mean can be > 1.0
    return D.Beta(conc1, conc2).sample(), {'conc1': conc1, 'conc2': conc2}
def _kld_beta_kerman_prior(self, conc1, conc2):
    """ Internal function to do a KL-div against the prior.

    :param conc1: concentration 1.
    :param conc2: concentration 2.
    :returns: batch_size tensor of kld against prior.
    :rtype: torch.Tensor
    """
    # prior = D.Beta(zeros_like(conc1) + 1/3,
    #                zeros_like(conc2) + 1/3)
    prior = D.Beta(zeros_like(conc1) + 1.1,
                   zeros_like(conc2) + 1.1)
    beta = D.Beta(conc1, conc2)
    return torch.sum(D.kl_divergence(beta, prior), -1)
def policy_to_action(self, alpha, beta):
    # alpha and beta must be non-negative float
    eps = 1e-6  # to avoid inf and nan
    p = dist.Beta(alpha + eps, beta + eps)
    action = p.sample()
    log_prob = p.log_prob(action)
    return action, log_prob
def _remix(x: List[torch.Tensor], remix_alpha: float) -> List[torch.Tensor]:
    # Create random permutation index in range 0 -> length of the mini-batch.
    idx = torch.randperm(x[0].shape[0])
    # Create beta dist over shape of alpha and sample from it
    mix = dist.Beta(remix_alpha + 1, remix_alpha).sample_n(x[0].shape[0])
    x = [(mix * t) + ((1 - mix) * t[idx]) for t in x]
    return x
def stick_breaking(alpha0, k):
    """ This function breaks a stick into k pieces """
    betas = dist.Beta(torch.tensor([1.]), torch.tensor([alpha0])).sample([k]).squeeze()
    remains = torch.cat(
        (torch.tensor([1.]), torch.cumprod(1 - betas[:-1], dim=0)), 0)
    p = betas * remains
    p /= p.sum()
    return p
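# Added usage sketch (hedged): draw a length-k probability vector with
# concentration alpha0; the explicit normalization above makes the pieces sum
# to one exactly rather than leaving residual mass on the unbroken remainder.
weights = stick_breaking(alpha0=1.0, k=10)
assert torch.isclose(weights.sum(), torch.tensor(1.0))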
def forward(self, x):
    batch_sz = x.size()[0]

    sz = self.q_pi_a.size()
    p_pi = distributions.Beta(
        torch.ones(sz) * self.p_pi_alpha,
        torch.ones(sz) * self.p_pi_beta)

    beta_a = F.softplus(self.q_pi_alpha) + 0.01
    beta_b = F.softplus(self.q_pi_beta) + 0.01
    q_pi = distributions.Beta(beta_a, beta_b)

    # Differentiable Sample Knowles et al.
    qpi_sample = q_pi.rsample()

    q_z = shared.STRelaxedBernoulli(temperature=0.1, probs=qpi_sample)
    z = q_z.rsample()
    q_z = distributions.Bernoulli(probs=qpi_sample)

    q_phi = distributions.Normal(loc=self.phi_mean, scale=(self.phi_logvar / 2).exp())
    q_w = distributions.Normal(loc=self.w_mean, scale=(self.w_logvar / 2).exp())

    # For now, just take the mean
    phi = q_phi.mean
    w = q_w.mean
    # Alternatively, sample
    # phi = q_phi.rsample()
    # w = q_w.rsample()

    # NLL
    sinbasis = torch.ones(K, N_SAMPLES) * torch.arange(0, N_SAMPLES, 1)
    for k in range(K):
        sinbasis[k] = torch.sin(sinbasis[k] * phi[k])
    x_mean = torch.mm(torch.mul(z, w), sinbasis)  # z and w multiplied elementwise
    nll = -(distributions.Normal(loc=x_mean, scale=self.sigma_n).log_prob(x))

    return nll, p_pi, q_pi, q_z, q_phi, q_w, sinbasis
def log_likelihood(self, z, params):
    """ Log-likelihood of z induced under params.

    :param z: inferred latent z
    :param params: the params of the distribution
    :returns: log-likelihood
    :rtype: torch.Tensor
    """
    return D.Beta(params['beta']['conc1'],
                  params['beta']['conc2']).log_prob(z)
def sample_params(self, n_sample=torch.Size([])):
    clusters = self.cluster_distr.rsample(n_sample)
    params = self.cluster_to_params_graph(clusters)

    alpha_hsl0 = F.softplus(params[0:3])
    beta_hsl0 = F.softplus(params[3:6])
    hsl0 = td.Beta(alpha_hsl0, beta_hsl0).rsample()

    alpha_hsl1 = F.softplus(params[6:9])
    beta_hsl1 = F.softplus(params[9:12])
    hsl1 = td.Beta(alpha_hsl1, beta_hsl1).rsample()

    shape_trans01 = F.softplus(params[12:15])
    scale_trans01 = F.softplus(params[15:18])
    trans01 = td.Gamma(shape_trans01, scale_trans01).rsample()

    shape_trans10 = F.softplus(params[18:21])
    scale_trans10 = F.softplus(params[21:24])
    trans10 = td.Gamma(shape_trans10, scale_trans10).rsample()

    return hsl0, hsl1, trans01, trans10
def __init__(self, in_features: int, out_channels: int, num_repetitions: int = 1, dropout=0.0):
    """Create a beta layer.

    Args:
        out_channels: Number of parallel representations for each input feature.
        in_features: Number of input features.
        num_repetitions: Number of parallel repetitions of this layer.
    """
    super().__init__(in_features, out_channels, num_repetitions, dropout)

    # Create beta parameters
    self.concentration0 = nn.Parameter(torch.rand(1, in_features, out_channels, num_repetitions))
    self.concentration1 = nn.Parameter(torch.rand(1, in_features, out_channels, num_repetitions))
    self.beta = dist.Beta(concentration0=self.concentration0, concentration1=self.concentration1)
def __init__(self, multiplicity, in_features, dropout=0.0):
    """Create a beta layer.

    Args:
        multiplicity: Number of parallel representations for each input feature.
        in_features: Number of input features.
    """
    super().__init__(multiplicity, in_features, dropout)

    # Create beta parameters
    self.concentration0 = nn.Parameter(torch.rand(1, in_features, multiplicity))
    self.concentration1 = nn.Parameter(torch.rand(1, in_features, multiplicity))
    self.beta = dist.Beta(
        concentration0=self.concentration0, concentration1=self.concentration1
    )
def with_beta_dist(
    cls: type[RandomMixUp],
    alpha: float = 0.2,
    *,
    beta: float | None = None,
    mode: MixUpMode | str = MixUpMode.linear,
    p: float = 1.0,
    num_classes: int | None = None,
    inplace: bool = False,
    featurewise: bool = False,
) -> RandomMixUp[td.Beta]:
    """
    Instantiate a :class:`RandomMixUp` with a Beta-distribution sampler.

    :param alpha: 1st concentration parameter of the distribution. Must be positive.
    :param beta: 2nd concentration parameter of the distribution. Must be positive.
        If ``None``, then the parameter will be set to ``alpha``.
    :param mode: Which mode to use to mix up samples: geometric or linear.

        .. note::
            The (weighted) geometric mean, enabled by ``mode=geometric``, is only valid
            for positive inputs.

    :param p: The probability with which the transform will be applied to a given sample.
    :param num_classes: The total number of classes in the dataset that needs to be
        specified if wanting to mix up targets that are label-encoded. Passing
        label-encoded targets without specifying ``num_classes`` will result in a
        RuntimeError.
    :param featurewise: Whether to sample feature-wise instead of sample-wise.
    :param inplace: Whether the transform should be performed in-place.

    :return: A :class:`RandomMixUp` instance with ``lambda_sampler`` set to a
        Beta-distribution with ``concentration1=alpha`` and ``concentration0=beta``.
    """
    beta = alpha if beta is None else beta
    lambda_sampler = td.Beta(concentration0=alpha, concentration1=beta)
    return cls(
        lambda_sampler=lambda_sampler,
        mode=mode,
        p=p,
        num_classes=num_classes,
        inplace=inplace,
        featurewise=featurewise,
    )
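# Added illustration (hedged): this does not reproduce RandomMixUp's internals,
# only the basic effect of a Beta(alpha, alpha) lambda-sampler in linear mode,
# where each sample is blended with a randomly chosen partner from the batch.
# The function name and signature are stand-ins, not part of the library.
def linear_mixup_sketch(inputs, alpha=0.2):
    import torch
    import torch.distributions as td

    lam = td.Beta(alpha, alpha).sample()     # scalar mixing weight in (0, 1)
    perm = torch.randperm(inputs.size(0))    # random partner for each sample
    return lam * inputs + (1 - lam) * inputs[perm]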
def __init__(
    self,
    alpha: float = 1.0,
    *,
    p: float = 0.5,
    num_classes: int | None = None,
    inplace: bool = False,
    seed: Optional[int] = None,
) -> None:
    """
    :param alpha: hyperparameter of the Beta distribution used for sampling the areas
        of the bounding boxes.
    :param num_classes: The total number of classes in the dataset that needs to be
        specified if wanting to mix up targets that are label-encoded. Passing
        label-encoded targets without specifying ``num_classes`` will result in a
        RuntimeError.
    :param p: The probability with which the transform will be applied to a given sample.
    :param inplace: Whether the transform should be performed in-place.
    :param seed: The PRNG seed to use for sampling pairs and bounding-box coordinates.

    :raises ValueError: if ``p`` is not in the range [0, 1], if ``num_classes < 1``,
        or if ``alpha`` is not a positive real number.
    """
    super().__init__()
    if not 0 <= p <= 1:
        raise ValueError("'p' must be in the range [0, 1].")
    self.p = p
    if alpha < 0:
        raise ValueError("'alpha' must be positive.")
    self.alpha = alpha
    if (num_classes is not None) and num_classes < 1:
        raise ValueError(f"'num_classes' must be greater than 0, got {num_classes}.")
    self.lambda_sampler = td.Beta(concentration0=alpha, concentration1=alpha)
    self.num_classes = num_classes
    self.inplace = inplace
    self.seed = seed
def prior(self, batch_size, **kwargs):
    """ Returns a Kerman beta prior.

    Kerman, J. (2011). Neutral noninformative and informative conjugate beta
    and gamma prior distributions. Electronic Journal of Statistics, 5, 1450-1470.

    :param batch_size: the number of prior samples
    :returns: prior
    :rtype: torch.Tensor
    """
    conc1 = Variable(
        same_type(self.config['half'], self.config['cuda'])(
            batch_size, self.output_size
        ).zero_() + 1/3
    )
    conc2 = Variable(
        same_type(self.config['half'], self.config['cuda'])(
            batch_size, self.output_size
        ).zero_() + 1/3
    )
    return D.Beta(conc1, conc2).sample()
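# Added sketch (hedged): the same Kerman Beta(1/3, 1/3) prior written with plain
# tensors instead of the Variable / same_type helpers above. `D` is assumed to be
# torch.distributions as in the snippet; batch_size and output_size are stand-ins.
def kerman_beta_prior(batch_size, output_size):
    conc = torch.full((batch_size, output_size), 1.0 / 3.0)
    return D.Beta(conc, conc).sample()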
def act(self, x):
    VARIANCE = 0.25

    # Set a 4D shape
    x = x.view(1, 1, self.IMAGE_SIZE, self.IMAGE_SIZE)

    # First get the hidden representation
    mu, log_sigma = self.encode(x)

    # Compute alpha and beta for the beta distribution
    z = F.relu(self.linear1(mu))

    if self.action_dist == 'beta':
        alpha = F.softplus(self.linear2(z)) + 1
        beta = F.softplus(self.linear3(z)) + 1

        # Sample the beta distribution
        a_dist = dist.Beta(alpha, beta)
        actions = a_dist.sample()[0]
        log_proba = torch.sum(a_dist.log_prob(actions))

        # Now move the 3 beta samples in the action space
        # Note: only the first action, steer, must be rescaled
        actions[0] = actions[0] * 2 - 1
    elif self.action_dist == 'gaussian':
        raise Exception('TODO')
    else:
        raise Exception('Unrecognized action distribution.')

    return actions.numpy(), log_proba
def go(arg):
    tbw = SummaryWriter(log_dir=arg.tb_dir)

    ## Load the data
    if arg.task == 'mnist':
        transform = tfs.Compose([tfs.Pad(padding=2), tfs.ToTensor()])

        trainset = torchvision.datasets.MNIST(root=arg.data_dir, train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.MNIST(root=arg.data_dir, train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=arg.batch_size, shuffle=False, num_workers=2)
        C, H, W = 1, 32, 32

    elif arg.task == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=arg.data_dir, train=True, download=True, transform=tfs.ToTensor())
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(root=arg.data_dir, train=False, download=True, transform=tfs.ToTensor())
        testloader = torch.utils.data.DataLoader(testset, batch_size=arg.batch_size, shuffle=False, num_workers=2)
        C, H, W = 3, 32, 32

    elif arg.task == 'cifar-gs':
        transform = tfs.Compose([tfs.Grayscale(), tfs.ToTensor()])

        trainset = torchvision.datasets.CIFAR10(root=arg.data_dir, train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(root=arg.data_dir, train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=arg.batch_size, shuffle=False, num_workers=2)
        C, H, W = 1, 32, 32

    elif arg.task == 'imagenet64':
        transform = tfs.Compose([tfs.ToTensor()])

        trainset = torchvision.datasets.ImageFolder(root=arg.data_dir + os.sep + 'train', transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.ImageFolder(root=arg.data_dir + os.sep + 'valid', transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=arg.batch_size, shuffle=False, num_workers=2)
        C, H, W = 3, 64, 64

    else:
        raise Exception('Task {} not recognized.'.format(arg.task))

    ## Set up the model
    out_channels = C
    if (arg.rloss == 'gauss' or arg.rloss == 'laplace' or arg.rloss == 'signorm'
            or arg.rloss == 'siglaplace' or arg.rloss == 'beta') and arg.scale is None:
        out_channels = 2 * C

    print(f'out channels: {out_channels}')

    encoder = Encoder(zsize=arg.zsize, colors=C)
    decoder = Decoder(zsize=arg.zsize, out_channels=out_channels, mult=arg.mult)

    if arg.testmodel:
        decoder = Test(out_channels=out_channels, height=H, width=W)

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()

    opt = torch.optim.Adam(lr=arg.lr, params=list(encoder.parameters()) + list(decoder.parameters()))

    if arg.esched is not None:
        start, end = int(arg.esched[0] * arg.epochs), (arg.esched[1] * arg.epochs)
        slope = 1.0 / (end - start)

    for epoch in range(arg.epochs):

        if arg.esched is not None:
            weight = (epoch - start) * slope
            weight = np.clip(weight, 0, 1)
        else:
            weight = 1.0

        for i, (input, _) in enumerate(tqdm.tqdm(trainloader)):
            if arg.limit is not None and i * arg.batch_size > arg.limit:
                break

            # Prepare the input
            b, c, w, h = input.size()
            if torch.cuda.is_available():
                input = input.cuda()

            # Forward pass
            if not arg.testmodel:
                zs = encoder(input)

                kloss = kl_loss(zs[:, :arg.zsize], zs[:, arg.zsize:])
                z = sample(zs[:, :arg.zsize], zs[:, arg.zsize:])

                out = decoder(z)
            else:
                out = decoder(input)
                kloss = 0

            # compute -log p per dimension
            if arg.rloss == 'xent':  # binary cross-entropy (not a proper log-prob)
                rloss = F.binary_cross_entropy_with_logits(out, input, reduction='none')

            elif arg.rloss == 'bdist':  # xent + correction
                rloss = F.binary_cross_entropy_with_logits(out, input, reduction='none')

                za = out.abs()
                eza = (-za).exp()

                # - np.log(za) + np.log1p(-eza + EPS) - np.log1p(eza + EPS)
                logpart = -(za + arg.eps).log() + (-eza + arg.eps).log1p() - (eza + arg.eps).log1p()
                rloss = rloss + weight * logpart

            elif arg.rloss == 'gauss':  # xent + correction
                if arg.scale is None:
                    means = T.sigmoid(out[:, :c, :, :])
                    vars = F.sigmoid(out[:, c:, :, :])

                    rloss = GAUSS_CONST + vars.log() + (1.0 / (2.0 * vars.pow(2.0))) * (input - means).pow(2.0)
                else:
                    means = T.sigmoid(out[:, :c, :, :])
                    var = arg.scale

                    rloss = GAUSS_CONST + ln(var) + (1.0 / (2.0 * (var * var))) * (input - means).pow(2.0)

            elif arg.rloss == 'mse':
                means = T.sigmoid(out[:, :c, :, :])
                rloss = (input - means).pow(2.0)

            elif arg.rloss == 'mae':
                means = T.sigmoid(out[:, :c, :, :])
                rloss = (input - means).abs()

            elif arg.rloss == 'laplace':  # xent + correction
                if arg.scale is None:
                    means = T.sigmoid(out[:, :c, :, :])
                    vars = F.softplus(out[:, c:, :, :])

                    rloss = (2.0 * vars).log() + (1.0 / vars) * (input - means).abs()
                else:
                    means = T.sigmoid(out[:, :c, :, :])
                    var = arg.scale

                    rloss = ln(2.0 * var) + (1.0 / var) * (input - means).abs()

            elif arg.rloss == 'signorm':
                if arg.scale is None:
                    mus = out[:, :c, :, :]
                    sgs, lsgs = T.exp(out[:, c:, :, :] * arg.varmult), out[:, c:, :, :] * arg.varmult
                else:
                    mus = out[:, :c, :, :]
                    sgs, lsgs = arg.scale, math.log(arg.scale)

                y = input
                lny = torch.log(y + arg.eps)
                ln1y = torch.log(1 - y + arg.eps)
                x = lny - ln1y

                rloss = lny + ln1y + lsgs + GAUSS_CONST + \
                        0.5 * (1.0 / (sgs * sgs + arg.eps)) * (x - mus) ** 2

            elif arg.rloss == 'siglaplace':
                if arg.scale is None:
                    mus = out[:, :c, :, :]
                    sgs, lsgs = T.exp(out[:, c:, :, :] * arg.varmult), out[:, c:, :, :] * arg.varmult
                else:
                    mus = out[:, :c, :, :]
                    sgs, lsgs = arg.scale, math.log(arg.scale)

                y = input
                lny = torch.log(y + arg.eps)
                ln1y = torch.log(1 - y + arg.eps)
                x = lny - ln1y

                rloss = lny + ln1y + lsgs + math.log(2.0) + \
                        (x - mus).abs() / sgs

            elif arg.rloss == 'beta':
                mean = T.sigmoid(out[:, :c, :, :])
                mult = F.softplus(out[:, c:, :, :] + arg.beta_add) + (1.0 / (mean + arg.eps)) + arg.eps

                alpha = mean * mult
                beta = (1 - mean) * mult

                part = alpha.lgamma() + beta.lgamma() - (alpha + beta).lgamma()

                x = input
                rloss = -(alpha - 1) * (x + arg.eps).log() - (beta - 1) * (1 - x + arg.eps).log() + part

            else:
                raise Exception(f'reconstruction loss {arg.rloss} not recognized.')

            if contains_nan(rloss):
                if arg.rloss == 'beta':
                    print('part contains nan', contains_nan(part))
                    print('alpha contains nan', contains_nan(alpha))
                    print('beta contains nan', contains_nan(beta))
                    print('log x contains nan', contains_nan((x + arg.eps).log()))
                    print('log (1-x) contains nan', contains_nan((1 - x + arg.eps).log()))
                raise Exception('rloss contains nan')

            rloss = rloss.reshape(b, -1).sum(dim=1)  # reduce

            loss = (rloss + kloss).mean()

            opt.zero_grad()
            loss.backward()
            opt.step()

        with torch.no_grad():
            N = 5

            # Plot reconstructions
            inputs, _ = next(iter(testloader))

            if torch.cuda.is_available():
                inputs = inputs.cuda()

            b, c, h, w = inputs.size()

            if not arg.testmodel:
                zs = encoder(inputs)
                res = decoder(zs[:, :arg.zsize])
            else:
                res = decoder(inputs)

            outputs = res[:, :c, :, :]
            means = T.sigmoid(outputs)
            samples = None

            if arg.rloss == 'signorm' and out_channels > c:
                means = res[:, :c, :, :]
                vars = res[:, c:, :, :] * arg.varmult

                dist = ds.Normal(means, vars)
                samples = T.sigmoid(dist.sample())
                means = T.sigmoid(dist.mean)

            if arg.rloss == 'siglaplace' and out_channels > c:
                means = res[:, :c, :, :]
                vars = res[:, c:, :, :] * arg.varmult

                dist = ds.Laplace(means, vars)
                samples = T.sigmoid(dist.sample())
                means = T.sigmoid(dist.mean)

            if arg.rloss == 'beta':
                mean = T.sigmoid(res[:, :c, :, :])
                mult = (res[:, c:, :, :] + arg.beta_add).exp() + (1.0 / mean) + arg.eps

                alpha = mean * mult
                beta = (1 - mean) * mult

                dist = ds.Beta(alpha, beta)
                samples = dist.sample()
                means = dist.mean
                vars = dist.variance

            plt.figure(figsize=(5, 4))

            for i in range(N):
                ax = plt.subplot(4, N, i + 1)
                inp = inputs[i].permute(1, 2, 0).cpu().numpy()
                if c == 1:
                    inp = inp.squeeze()
                ax.imshow(inp, cmap='gray_r')
                if i == 0:
                    ax.set_title('input')
                plt.axis('off')

                ax = plt.subplot(4, N, N + i + 1)
                outp = means[i].permute(1, 2, 0).cpu().numpy()
                if c == 1:
                    outp = outp.squeeze()
                ax.imshow(outp, cmap='gray_r')
                if i == 0:
                    ax.set_title('means/modes')
                plt.axis('off')

                if samples is not None:
                    # plot samples
                    ax = plt.subplot(4, N, 2 * N + i + 1)
                    outp = samples[i].permute(1, 2, 0).detach().cpu().numpy()
                    if c == 1:
                        outp = outp.squeeze()
                    ax.imshow(outp, cmap='gray_r')
                    if i == 0:
                        ax.set_title('sampled')
                    plt.axis('off')

                if out_channels > c:
                    # plot the variance (or other uncertainty)
                    ax = plt.subplot(4, N, 3 * N + i + 1)
                    outp = vars[i].permute(1, 2, 0).detach().cpu().numpy()
                    if c == 1:
                        outp = outp.squeeze()
                    ax.imshow(outp, cmap='copper')
                    if i == 0:
                        ax.set_title('var')
                    plt.axis('off')

            plt.tight_layout()
            plt.savefig(f'reconstruction.{arg.rloss}.{epoch:03}.png')

            if arg.zsize == 2:
                # latent space plot
                N = 2000

                # gather up first 200 batches into one big tensor
                numbatches = N // arg.batch_size
                images, labels = [], []
                for i, (ims, lbs) in enumerate(testloader):
                    images.append(ims)
                    labels.append(lbs)
                    if i > numbatches:
                        break

                images, labels = torch.cat(images, dim=0), torch.cat(labels, dim=0)

                imagesg = images
                if torch.cuda.is_available():
                    imagesg = imagesg.cuda()

                n, c, h, w = images.size()

                z = encoder(imagesg)
                latents = z[:, :2].data.detach().cpu()

                mn, mx = latents.min(), latents.max()
                size = 1.0 * (mx - mn) / math.sqrt(n)
                # Change the 1.0 factor to any value between ~0.5 and 1.5 to make
                # the digits smaller or bigger

                fig = plt.figure(figsize=(8, 8))

                # colormap for the images
                norm = mpl.colors.Normalize(vmin=0, vmax=9)
                cmap = mpl.cm.get_cmap('tab10')

                for i in range(n):
                    x, y = latents[i, 0:2]
                    l = labels[i]

                    im = images[i, :]
                    alpha_im = im.permute(1, 2, 0).detach().cpu().numpy()
                    color = cmap(norm(l))
                    color_im = np.asarray(color)[None, None, :3]
                    color_im = np.broadcast_to(color_im, (h, w, 3))
                    # -- To make the digits transparent we make them solid color images and use the
                    #    actual data as an alpha channel.
                    #    color_im: 3-channel color image, with solid color corresponding to class
                    #    alpha_im: 1-channel grayscale image corresponding to input data
                    im = np.concatenate([color_im, alpha_im], axis=2)
                    plt.imshow(im, extent=(x, x + size, y, y + size))

                plt.xlim(mn, mx)
                plt.ylim(mn, mx)

                plt.savefig(f'latent.{arg.rloss}.{epoch:03}.png')
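# Added check (hedged, not part of the original script): the hand-written 'beta'
# reconstruction loss above is the negative Beta log-density,
#   -log p(x) = -(alpha - 1) log x - (beta - 1) log(1 - x) + log B(alpha, beta),
# where log B(alpha, beta) = lgamma(alpha) + lgamma(beta) - lgamma(alpha + beta)
# is the `part` term. Away from the eps clamping it agrees with
# torch.distributions; `ds` is assumed to be the torch.distributions alias used above.
def _beta_rloss_matches_log_prob():
    alpha = torch.tensor(2.5)
    beta = torch.tensor(4.0)
    x = torch.tensor(0.3)
    part = alpha.lgamma() + beta.lgamma() - (alpha + beta).lgamma()
    rloss = -(alpha - 1) * x.log() - (beta - 1) * (1 - x).log() + part
    assert torch.isclose(rloss, -ds.Beta(alpha, beta).log_prob(x))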
def __init__(self, a):
    dist = dists.Beta(a[0], a[1])
    super().__init__(dist, "beta", 2, a[0], a[1])