def compute_loss_GSWD(self, discriminator, optimizer, minibatch, rand_dist, g, r, num_projection, p=2):
    """Train ``discriminator`` on one batch, then return the GSWD in its feature space.

    The discriminator gets two sequential BCE updates: real samples with
    target 1, then decoded prior samples with target 0. Each update is
    followed by ``optimizer.step()``. Afterwards both batches go through the
    *updated* discriminator once more. The second output of each call (used as
    a feature tensor) is flattened per sample and passed to
    ``generalized_sliced_wasserstein_distance``.

    Args:
        discriminator: callable returning ``(prediction, features)``. The
            prediction is fed to ``nn.BCELoss``, so it must lie in [0, 1].
        optimizer: optimizer stepped on the discriminator losses (presumably
            over the discriminator's parameters only -- confirm with caller).
        minibatch: batch of real samples; dim 0 is the batch size.
        rand_dist: callable taking a shape tuple ``(batch, latent_size)`` and
            returning prior samples.
        g, r, num_projection, p: forwarded unchanged to
            ``generalized_sliced_wasserstein_distance`` (``p`` defaults to 2).

    Returns:
        The value returned by ``generalized_sliced_wasserstein_distance`` on
        the flattened real/fake feature batches. The fake features are *not*
        detached, so gradients of the result reach ``self.decoder``.
    """
    criterion = nn.BCELoss()
    data = minibatch.to(self.device)
    z_prior = rand_dist((data.shape[0], self.latent_size)).to(self.device)
    data_fake = self.decoder(z_prior)

    # Discriminator step on real data. ones_like gives a float target with
    # exactly the prediction's shape; torch.full(shape, 1) would be int64,
    # which BCELoss rejects, and (B,) may not match a (B, 1) prediction.
    y_real, _ = discriminator(data)
    errD_real = criterion(y_real, torch.ones_like(y_real))
    optimizer.zero_grad()
    errD_real.backward()
    optimizer.step()

    # Discriminator step on generated data. detach() keeps this loss from
    # writing gradients into the decoder, which `optimizer` would not zero.
    y_fake, _ = discriminator(data_fake.detach())
    errD_fake = criterion(y_fake, torch.zeros_like(y_fake))
    optimizer.zero_grad()
    errD_fake.backward()
    optimizer.step()

    # Recompute features *after* the updates. optimizer.step() changes the
    # weights in place, so features from earlier forwards would hold stale
    # saved tensors and back-propagating the distance would raise. This also
    # measures real and fake samples with the same discriminator weights.
    # NOTE(review): if the discriminator has BatchNorm/dropout in train mode,
    # these extra forwards also affect running stats / masks.
    _, feat_real = discriminator(data)
    _, feat_fake = discriminator(data_fake)

    return generalized_sliced_wasserstein_distance(
        feat_real.reshape(feat_real.shape[0], -1),
        feat_fake.reshape(feat_fake.shape[0], -1),
        g, r, num_projection, p, self.device)
def compute_loss_GSWD(self, minibatch, rand_dist, g_function, r, num_projection, p=2):
    """Return the GSWD between a real batch and decoder outputs of prior samples.

    One prior sample is drawn per real sample and decoded with
    ``self.decoder``. Both batches are flattened to ``(batch, -1)`` and passed
    to ``generalized_sliced_wasserstein_distance``.

    Args:
        minibatch: batch of real samples; dim 0 is the batch size.
        rand_dist: callable taking a shape tuple ``(batch, latent_size)`` and
            returning prior samples.
        g_function, r, num_projection, p: forwarded unchanged to
            ``generalized_sliced_wasserstein_distance`` (``p`` defaults to 2).

    Returns:
        Whatever ``generalized_sliced_wasserstein_distance`` returns.
    """
    real = minibatch.to(self.device)
    batch_size = real.shape[0]

    latent = rand_dist((batch_size, self.latent_size)).to(self.device)
    generated = self.decoder(latent)

    # Flatten every sample to a vector; both batches share the real batch size.
    real_flat = real.view(batch_size, -1)
    generated_flat = generated.view(batch_size, -1)

    return generalized_sliced_wasserstein_distance(
        real_flat, generated_flat, g_function, r, num_projection, p, self.device
    )