def forward(self, input_batch, return_more=True):
    """Match every input against ``self.samples`` and report the closest one.

    The input must be 4-dimensional and its last three dimensions must equal
    the last three dimensions of ``self.samples``.

    :param input_batch: tensor with four dimensions; a singleton axis is
        inserted at position 1 before the distances are computed
    :param return_more: if True also return the matched samples and distances
    :return: ``pred`` (result of ``self.get_classes``), or the tuple
        ``(pred, imgs, l2s)`` when ``return_more`` is True
    """
    assert len(input_batch.shape) == 4
    for axis in (-1, -2, -3):
        assert input_batch.size()[axis] == self.samples.size()[axis]

    bs = input_batch.shape[0]
    # Insert a singleton axis so the batch lines up with the samples:
    # (bs, 1, nch, x, y)
    input_batch = input_batch.unsqueeze(1).to(u.dev())

    def closest_sample(chunk):
        # Distance of every input in the chunk to every stored sample
        # (reduced over axes 2, 3, 4), then the minimum along axis 1.
        distances = u.L2(self.samples, chunk, axes=[2, 3, 4])
        smallest, smallest_idx = distances.min(dim=1)
        return smallest, smallest_idx

    # Evaluate in chunks limited by self.max_bs.
    l2s, best_ind_classes = u.auto_batch(self.max_bs, closest_sample, input_batch)

    pred = self.get_classes(bs, input_batch, best_ind_classes)
    imgs = self.samples[0, best_ind_classes]

    return (pred, imgs, l2s) if return_more else pred
def ELBOs(x_rec: torch.Tensor, samples_latent: torch.Tensor, x_orig: torch.Tensor, beta=1,
          dist_fct=squared_L2_loss, auto_batch_size=1600):
    """Return ``(reconstruction loss + beta * KLD)`` divided by the image size.

    Works with extra leading dimensions (e.g. label and sample axes) on the
    tensors, such as [bs=1, n_classes, n_samples, n_ch, nx, ny].

    NOTE(review): the whole sum, KLD term included, is divided by
    ``n_ch * nx * ny``; the original author questioned why - confirm intent.

    :param x_rec: shape (..., n_channels, nx, ny)
    :param samples_latent: (..., n_latent, 1, 1)
    :param x_orig: (..., n_channels, nx, ny)
    :param beta: weight of the KLD term
    :param dist_fct: callable ``dist_fct(x_orig, x_rec, axes=...)`` giving the
        reconstruction loss
    :param auto_batch_size: chunk size handed to ``u.auto_batch``; ``None``
        calls ``dist_fct`` directly on the full tensors
    :return: the normalised loss
    """
    channels, height, width = x_rec.shape[-3:]
    kl_term = KLD(samples_latent, sig_q=1.)

    # Reduce over nx, ny and n_ch.
    reduce_axes = [-1, -2, -3]
    if auto_batch_size is None:
        reconstruction = dist_fct(x_orig, x_rec, axes=reduce_axes)
    else:
        reconstruction = u.auto_batch(auto_batch_size, dist_fct, x_orig, x_rec,
                                      axes=reduce_axes)

    total = reconstruction + beta * kl_term
    return total / (channels * height * width)
def GD_inference(AEs, l_v_best, x_inp, clip=5, lr=0.01, n_iter=20, beta=1,
                 dist_fct=loss_functions.squared_L2_loss):
    """Refine per-class latents by gradient descent on the ELBO.

    Starting from ``l_v_best``, Adam is run for ``n_iter - 1`` update steps on
    the summed ELBO of all classes. The last iteration only evaluates the
    ELBOs and reconstructions at the final latents, without an update. The
    work is split into chunks of at most 1000 via ``u.auto_batch``.

    :param AEs: indexable collection of models; each must provide ``eval()``
        and ``Decoder.forward(latent)``
    :param l_v_best: initial latents; per the original author's note the shape
        is (batch_size, n_classes == 10, n_latents == 8) + singleton dims
    :param x_inp: inputs, unpacked as (bs, n_ch, nx, ny)
    :param clip: radius passed to ``u.clip_to_sphere`` after each step
    :param lr: Adam learning rate
    :param n_iter: number of iterations (the last one is evaluation only)
    :param beta: KLD weight forwarded to ``loss_functions.ELBOs``
    :param dist_fct: reconstruction distance forwarded to ``loss_functions.ELBOs``
    :return: ``(ELBOs, l_v_best, all_recs)`` as assembled by ``u.auto_batch``
    """
    n_classes = len(AEs)

    def gd_inference_b(l_v_best, x_inp, AEs, n_classes=10, clip=5, lr=0.01, n_iter=20, beta=1,
                       dist_fct=loss_functions.squared_L2_loss):
        bs, n_ch, nx, ny = x_inp.shape
        with torch.enable_grad():
            # Move to the target device BEFORE enabling gradients. Calling
            # .to() after requires_grad_() yields a non-leaf tensor whenever
            # a real device copy happens, and the optimizer rejects non-leaf
            # tensors.
            l_v_best = l_v_best.data.clone().detach().to(u.dev()).requires_grad_(True)
            opti = optim.Adam([l_v_best], lr=lr)
            for i in range(n_iter):
                is_last = i == n_iter - 1
                ELBOs = []
                all_recs = []
                if is_last:
                    l_v_best = l_v_best.detach()  # no gradients in last run
                for j in range(n_classes):
                    AEs[j].eval()
                    rec = torch.sigmoid(AEs[j].Decoder.forward(l_v_best[:, j]))
                    ELBOs.append(loss_functions.ELBOs(rec,  # (bs, n_ch, nx, ny)
                                                      l_v_best[:, j],  # (bs, n_latent, 1, 1)
                                                      x_inp,  # (bs, n_ch, nx, ny)
                                                      beta=beta,
                                                      dist_fct=dist_fct))
                    if is_last:
                        all_recs.append(rec.view(bs, 1, n_ch, nx, ny).detach())

                ELBOs = torch.cat(ELBOs, dim=1)
                if not is_last:
                    # The constant offset does not change the gradient; it is
                    # kept "for historic reasons" (original author's note).
                    loss = (torch.sum(ELBOs)) - 8./784./2
                    opti.zero_grad()
                    loss.backward()
                    opti.step()
                    # Project the updated latents back with clip_to_sphere.
                    l_v_best.data = u.clip_to_sphere(l_v_best.data, clip, channel_dim=2)
                else:
                    opti.zero_grad()
                    all_recs = torch.cat(all_recs, dim=1)

        return ELBOs.detach(), l_v_best.detach(), all_recs

    ELBOs, l_v_best, all_recs = u.auto_batch(1000, gd_inference_b, [l_v_best, x_inp], AEs,
                                             n_classes=n_classes, clip=clip, lr=lr,
                                             n_iter=n_iter, beta=beta, dist_fct=dist_fct)
    return ELBOs, l_v_best, all_recs
def ELBOs(x_rec: torch.Tensor, samples_latent: torch.Tensor, x_orig: torch.Tensor, beta=1,
          dist_fct=squared_L2_loss, auto_batch_size=1600):
    """Return ``(reconstruction loss + beta * KLD)`` divided by the image size.

    :param x_rec: shape (..., n_channels, nx, ny)
    :param samples_latent: (..., n_latent, 1, 1)
    :param x_orig: (..., n_channels, nx, ny)
    :param beta: multiplier applied to the KLD term
    :param dist_fct: callable ``dist_fct(x_orig, x_rec, axes=...)`` giving the
        reconstruction loss
    :param auto_batch_size: chunk size handed to ``u.auto_batch``; ``None``
        calls ``dist_fct`` directly on the full tensors
    :return: the loss divided by ``n_ch * nx * ny``
    """
    n_channels, size_x, size_y = x_rec.shape[-3:]
    divergence = KLD(samples_latent, sig_q=1.)

    image_axes = [-1, -2, -3]  # sum over nx, ny, n_ch
    if auto_batch_size is not None:
        distance = u.auto_batch(auto_batch_size, dist_fct, x_orig, x_rec, axes=image_axes)
    else:
        distance = dist_fct(x_orig, x_rec, axes=image_axes)

    normaliser = n_channels * size_x * size_y
    return (distance + beta * divergence) / normaliser
def forward(self, x, return_more=False):
    """Score ``x`` against decoded latent samples and refine the best ones.

    Steps, in order:

    1. Re-seed torch (all CUDA devices and CPU) and NumPy with 101.
    2. Obtain latents from ``self.sampler`` and decode them under every label
       with ``self.CVAE.decode``.
    3. Score every (input, label, sample) combination with
       ``loss_functions.ELBOs`` and keep, per input and label, the
       ``self.n_samples_grad`` samples with the smallest value.
    4. Run ``gd_inference_cvae`` on those latents through ``u.auto_batch``
       (chunks of at most 500) and pass the result through ``self.rescale``.

    NOTE(review): the top-k indices are flattened in (bs, n_labels,
    n_samples_grad) order, but the gathered latents are viewed as
    (bs * n_samples_grad, n_labels, n_latent, 1, 1). The two layouts only
    coincide for n_samples_grad == 1 - confirm for larger values.

    :param x: input batch; viewed as (bs, n_ch, nx, ny) with bs = x.shape[0]
    :param return_more: if True return confidences plus intermediate results
    :return: ``-ELBOs[:, :, 0, 0]`` (used like logits), or
        ``(p_c, ELBOs, l_v_classes, reconsts)`` when ``return_more`` is True
    """
    sample_distance_function = loss_functions.squared_L2_loss
    sgd_distance_function = loss_functions.squared_L2_loss

    # Every call starts from the same RNG state.
    seed = 101
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    bs = x.shape[0]
    img_shape = (self.n_ch, self.nx, self.ny)

    # Initial sampling; the best n_samples_grad are selected further below.
    latent_samples = self.sampler(
        self.n_samples, self.n_latent,
        device=self.device, mus=None, fraction_to_dismiss=0.1, sample_sigma=1,
    )

    # Decode every sample once per label into a NumPy buffer.
    gen_imgs = np.empty((self.n_labels, self.n_samples) + img_shape)
    self.CVAE.eval()
    for label in range(self.n_labels):
        label_ids = np.repeat(label, self.n_samples)
        tensor_label = torch.from_numpy(label_ids).type(torch.LongTensor).to(self.device)
        decoded = self.CVAE.decode(latent_samples.squeeze(), tensor_label)
        gen_imgs[label] = decoded.cpu().data.numpy()
    gen_imgs = tensor(gen_imgs).type(torch.FloatTensor).to(self.device)
    print('done creating samples')

    # Score all samples; the singleton axes let inputs, labels and samples
    # broadcast against each other.
    with torch.no_grad():
        all_ELBOs = loss_functions.ELBOs(
            gen_imgs.view(1, self.n_labels, self.n_samples, *img_shape),
            latent_samples.view(1, 1, self.n_samples, self.n_latent, 1, 1),
            x.view(bs, 1, 1, *img_shape),
            beta=self.beta,
            dist_fct=sample_distance_function,
            auto_batch_size=8,
        )
    x = x.view(bs, *img_shape)

    # Keep only the indices of the smallest values along the sample axis.
    _, top_idx = torch.topk(all_ELBOs, k=self.n_samples_grad, dim=2,
                            largest=False, sorted=True)
    flat_idx = top_idx.view(bs * self.n_labels * self.n_samples_grad)

    # The extra samples are folded into the batch dimension.
    latent_samples_best = latent_samples[flat_idx].view(
        bs * self.n_samples_grad, self.n_labels, self.n_latent, 1, 1)

    # Gradient descent on the selected latents, evaluated in chunks.
    gd_options = dict(n_classes=self.n_labels, clip=self.clip, lr=self.lr,
                      n_iter=self.n_iter, beta=self.beta,
                      dist_fct=sgd_distance_function)
    ELBOs, l_v_classes, reconsts = u.auto_batch(
        500, gd_inference_cvae, [latent_samples_best, x],
        self.CVAE, self.device, **gd_options)

    ELBOs = self.rescale(ELBOs)  # class specific fine-scaling

    if not return_more:
        return -ELBOs[:, :, 0, 0]  # like logits

    p_c = u.confidence_softmax(-ELBOs * self.logit_scale,
                               const=self.confidence_level, dim=1)
    return p_c, ELBOs, l_v_classes, reconsts