def compute_lossDSWD(
    self,
    discriminator,
    optimizer,
    minibatch,
    rand_dist,
    num_projections,
    tnet,
    op_tnet,
    p=2,
    max_iter=100,
    lam=1,
):
    """Step the discriminator on a real and a generated batch, then return the DSWD.

    The discriminator is trained with binary cross-entropy: target 1 for
    ``minibatch`` and target 0 for samples decoded from ``rand_dist``. Each
    of the two losses gets its own ``zero_grad`` / ``backward`` / ``step``
    cycle on ``optimizer``. The second output of each discriminator call is
    then flattened per sample and passed to
    ``distributional_sliced_wasserstein_distance``.

    Args:
        discriminator: Callable returning a pair ``(prediction, second_output)``.
            The prediction is fed to ``nn.BCELoss``, so it is presumably a
            probability in [0, 1] with one entry per sample -- confirm
            against the discriminator definition.
        optimizer: Optimizer stepped after each discriminator loss.
        minibatch: Batch of real samples; moved to ``self.device``.
        rand_dist: Callable taking a shape tuple and returning latent samples.
        num_projections, tnet, op_tnet, p, max_iter, lam: Forwarded unchanged
            to ``distributional_sliced_wasserstein_distance``.

    Returns:
        The value returned by ``distributional_sliced_wasserstein_distance``.
    """
    criterion = nn.BCELoss()
    data = minibatch.to(self.device)
    batch_size = data.shape[0]

    z_prior = rand_dist((batch_size, self.latent_size)).to(self.device)
    data_fake = self.decoder(z_prior)

    # --- discriminator step on the real batch ---
    y_data, real_features = discriminator(data)
    # BCELoss needs a floating-point target with the prediction's dtype.
    # An integer fill value without an explicit dtype makes torch.full
    # return an integer tensor on current PyTorch, which BCELoss rejects.
    real_target = torch.full(
        (batch_size,), 1.0, dtype=y_data.dtype, device=self.device
    )
    errD_real = criterion(y_data, real_target)
    optimizer.zero_grad()
    # retain_graph=True: the features from this forward pass are reused in
    # the distance computed below, so the graph must not be freed here.
    errD_real.backward(retain_graph=True)
    optimizer.step()

    # --- discriminator step on the generated batch ---
    y_fake, fake_features = discriminator(data_fake)
    # A separate target tensor is used instead of refilling the real-batch
    # target in place, which is still referenced by the retained graph.
    fake_target = torch.full(
        (batch_size,), 0.0, dtype=y_fake.dtype, device=self.device
    )
    errD_fake = criterion(y_fake, fake_target)
    optimizer.zero_grad()
    errD_fake.backward(retain_graph=True)
    optimizer.step()
    # NOTE(review): optimizer.step() has updated parameters in place while
    # the graphs above are still retained; if the caller backpropagates the
    # returned distance through the discriminator, confirm autograd accepts
    # that on the PyTorch version in use.

    # Flatten both feature tensors to (batch, -1); the batch size is taken
    # from the real features for both, as in the original code.
    n_samples = real_features.shape[0]
    _dswd = distributional_sliced_wasserstein_distance(
        real_features.view(n_samples, -1),
        fake_features.view(n_samples, -1),
        num_projections,
        tnet,
        op_tnet,
        p,
        max_iter,
        lam,
        self.device,
    )
    return _dswd
def compute_lossJDSWD(
    self,
    minibatch,
    rand_dist,
    num_projections,
    tnet,
    op_tnet,
    p=2,
    max_iter=100,
    lam=1,
):
    """Return the DSWD between joint (latent, sample) pairs of data and model.

    The first set of points concatenates ``self.encoder(minibatch)`` with the
    flattened minibatch; the second concatenates latent samples drawn from
    ``rand_dist`` with the flattened output of ``self.decoder`` on them.
    Both are concatenated along dimension 1.

    Args:
        minibatch: Batch of real samples; moved to ``self.device``.
        rand_dist: Callable taking a shape tuple and returning latent samples.
        num_projections, tnet, op_tnet, p, max_iter, lam: Forwarded unchanged
            to ``distributional_sliced_wasserstein_distance``.

    Returns:
        The value returned by ``distributional_sliced_wasserstein_distance``.
    """
    device = self.device
    real = minibatch.to(device)
    n_samples = real.shape[0]

    # Encode first, then sample the prior and decode, so calls happen in the
    # same order as before.
    real_latent = self.encoder(real)
    prior_latent = rand_dist((n_samples, self.latent_size)).to(device)
    generated = self.decoder(prior_latent)

    # Joint points: latent code followed by the per-sample flattened tensor.
    joint_real = torch.cat([real_latent, real.view(n_samples, -1)], dim=1)
    joint_generated = torch.cat(
        [prior_latent, generated.view(n_samples, -1)], dim=1
    )

    return distributional_sliced_wasserstein_distance(
        joint_real,
        joint_generated,
        num_projections,
        tnet,
        op_tnet,
        p,
        max_iter,
        lam,
        device,
    )