def compute_losses(self):
    """Compute losses."""
    self.reconstruction_loss_a = losses.reconstruction_loss(
        real_images=self.input_a, generated_images=self.ae_images_a)
    self.reconstruction_loss_b = losses.reconstruction_loss(
        real_images=self.input_b, generated_images=self.ae_images_b)
    self.lsgan_loss_fake_a = losses.lsgan_loss_generator(
        self.prob_fake_a_is_real)
    self.lsgan_loss_fake_b = losses.lsgan_loss_generator(
        self.prob_fake_b_is_real)
    self.cycle_consistency_loss_a = losses.cycle_consistency_loss(
        real_images=self.input_a, generated_images=self.cycle_images_a)
    self.cycle_consistency_loss_b = losses.cycle_consistency_loss(
        real_images=self.input_b, generated_images=self.cycle_images_b)

    # Total generator loss: weighted sum of reconstruction, cycle-consistency
    # and adversarial terms for both domains.
    self.g_loss = self._rec_lambda_a * self.reconstruction_loss_a + \
        self._rec_lambda_b * self.reconstruction_loss_b + \
        self._cycle_lambda_a * self.cycle_consistency_loss_a + \
        self._cycle_lambda_b * self.cycle_consistency_loss_b + \
        self._lsgan_lambda_a * self.lsgan_loss_fake_a + \
        self._lsgan_lambda_b * self.lsgan_loss_fake_b

    # Discriminator losses use samples drawn from the fake image pools.
    self.d_loss_A = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_a_is_real,
        prob_fake_is_real=self.prob_fake_pool_a_is_real)
    self.d_loss_B = losses.lsgan_loss_discriminator(
        prob_real_is_real=self.prob_real_b_is_real,
        prob_fake_is_real=self.prob_fake_pool_b_is_real)

    # Split the trainable variables between the two discriminators and the
    # (partially shared) autoencoder/generator.
    self.model_vars = tf.trainable_variables()
    d_a_vars = [
        var for var in self.model_vars
        if 'd1' in var.name or 'd_shared' in var.name
    ]
    d_b_vars = [
        var for var in self.model_vars
        if 'd2' in var.name or 'd_shared' in var.name
    ]
    g_vars = [
        var for var in self.model_vars
        if 'ae1' in var.name or 'ae2' in var.name or 'ae_shared' in var.name
    ]

    optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
    self.d_A_trainer = optimizer.minimize(self.d_loss_A, var_list=d_a_vars)
    self.d_B_trainer = optimizer.minimize(self.d_loss_B, var_list=d_b_vars)
    self.g_trainer = optimizer.minimize(self.g_loss, var_list=g_vars)

    self.create_summaries()
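# A minimal sketch of what the `losses.lsgan_loss_generator` and
# `losses.lsgan_loss_discriminator` helpers used above typically compute,
# assuming the standard least-squares GAN targets (1 = real, 0 = fake).
# Illustrative only; not the project's actual implementation.
import tensorflow as tf


def lsgan_loss_generator(prob_fake_is_real):
    # The generator is rewarded when the discriminator scores fakes as real.
    return tf.reduce_mean((prob_fake_is_real - 1.0) ** 2)


def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
    # The discriminator targets 1 for real samples and 0 for fake samples.
    return 0.5 * (tf.reduce_mean((prob_real_is_real - 1.0) ** 2) +
                  tf.reduce_mean(prob_fake_is_real ** 2))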
def _train_generator_step(self, real_img, noise, img_from_prev_scale,
                          rec_from_prev_scale):
    with tf.GradientTape() as tape:
        fake_img = self._generate_image_for_training(noise,
                                                     img_from_prev_scale)
        # Convert the image to a batch with a single element (to be fed to
        # the critic and the loss functions)
        batch_fake_img = fake_img[np.newaxis, :, :, :]

        # Get the discriminator logits for the fake patches
        fake_logits = self.critic(batch_fake_img, training=True)
        # Calculate the generator adversarial (Wasserstein) loss
        adv_loss = generator_wass_loss(fake_logits)

        # Calculate the reconstruction loss
        reconstructed = self._reconstruct_image_for_training(
            rec_from_prev_scale)
        rec_loss = reconstruction_loss(reconstructed, real_img)

        loss = adv_loss + self.rec_loss_weight * rec_loss

    # Get the gradients w.r.t. the generator loss
    gen_gradient = tape.gradient(loss, self.generator.trainable_variables)
    # Update the weights of the generator using the generator optimizer
    self.g_optimizer.apply_gradients(
        zip(gen_gradient, self.generator.trainable_variables))
    return adv_loss, self.rec_loss_weight * rec_loss
def test_step(model, config, inputs, logger=None):
    """Reconstruction test step during training."""
    outputs = inputs.clone().detach()
    with torch.no_grad():
        (preds, priors, posteriors), stored_vars = model(inputs, config, False)

        # Select the targets that follow the context frames
        targets = outputs[:, config['n_ctx']:]

        # Compute the reconstruction and prior losses
        loss_rec = losses.reconstruction_loss(config, preds, targets)
        if config['beta'] > 0:
            loss_prior = losses.kl_loss(config, priors, posteriors)
            loss = loss_rec + config['beta'] * loss_prior
        else:
            loss = loss_rec

    # Logs
    if logger is not None:
        logger.scalar('test_loss_rec', loss_rec.item())
        logger.scalar('test_loss', loss.item())
        if config['beta'] > 0:
            logger.scalar('test_loss_prior', loss_prior.item())
def train_step(self, data):
    with tf.GradientTape() as tape:
        # Encode the input, sample a latent vector and decode it
        z_mean, z_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        # Total loss = reconstruction term + KL divergence between the
        # approximate posterior and the unit-Gaussian prior
        reconstruction_loss_val = reconstruction_loss(data, reconstruction)
        kl_loss = kl_divergence(z_mean, z_var)
        total_loss = reconstruction_loss_val + kl_loss
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss_val)
    self.kl_loss_tracker.update_state(kl_loss)
    return {
        'loss': self.total_loss_tracker.result(),
        'reconstruction_loss': self.reconstruction_loss_tracker.result(),
        'kl_loss': self.kl_loss_tracker.result(),
    }
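# A minimal sketch of what the `reconstruction_loss` and `kl_divergence`
# helpers used in the VAE train_step above might look like. Assumptions:
# inputs are images scaled to [0, 1], and the encoder's second output
# (named `z_var` above) is actually the log-variance of the posterior.
# Illustrative only; not the project's actual helpers.
import tensorflow as tf


def reconstruction_loss(data, reconstruction):
    # Per-pixel binary cross-entropy, summed over the image and averaged
    # over the batch.
    per_pixel = tf.keras.losses.binary_crossentropy(data, reconstruction)
    return tf.reduce_mean(tf.reduce_sum(per_pixel, axis=(1, 2)))


def kl_divergence(z_mean, z_log_var):
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = -0.5 * (1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
    return tf.reduce_mean(tf.reduce_sum(kl, axis=1))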
def train_step(model, config, inputs, optimizer, batch_idx, logger=None):
    """Training step for the model."""
    outputs = inputs.clone().detach()

    # Forward pass
    (preds, priors, posteriors), stored_vars = model(inputs, config, False)

    # Select the targets that follow the context frames
    targets = outputs[:, config['n_ctx']:]

    # Compute the reconstruction loss
    loss_rec = losses.reconstruction_loss(config, preds, targets)

    # Compute the prior (KL) loss
    if config['beta'] > 0:
        loss_prior = losses.kl_loss(config, priors, posteriors)
        loss = loss_rec + config['beta'] * loss_prior
    else:
        loss_prior = 0.
        loss = loss_rec

    # Backward pass and optimizer step
    optimizer.zero_grad()
    if config['apex']:
        from apex import amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()

    # Logs
    if logger is not None:
        logger.scalar('train_loss_rec', loss_rec.item())
        logger.scalar('train_loss', loss.item())
        if config['beta'] > 0:
            logger.scalar('train_loss_prior', loss_prior.item())

    return (preds, targets, priors, posteriors, loss_rec, loss_prior, loss,
            stored_vars)
def compute_losses(self, img_real, label_org, label_trg):
    # Discriminator loss
    out_src_real, out_cls_real = self.discriminator(img_real)
    img_fake = self.generator([img_real, label_trg])
    out_src_fake, out_cls_fake = self.discriminator(img_fake)
    d_adv_loss = -adversarial_loss(out_src_fake, out_src_real)
    d_cls_loss = classification_loss(label_org, out_cls_real)
    # Gradient penalty evaluated on random interpolations of real and fake images
    alpha = tf.random.uniform(shape=[img_real.shape.as_list()[0], 1, 1, 1])
    img_hat = alpha * img_real + (1 - alpha) * img_fake
    d_gp_loss = self.gradient_penalty(img_hat)
    self.d_loss = (d_adv_loss + self.lambda_cls * d_cls_loss +
                   self.lambda_gp * d_gp_loss)

    # Generator loss
    g_adv_loss = adversarial_loss(out_src_fake)
    g_cls_loss = classification_loss(label_trg, out_cls_fake)
    img_rec = self.generator([img_fake, label_org])
    g_rec_loss = reconstruction_loss(img_real, img_rec)
    self.g_loss = (g_adv_loss + self.lambda_cls * g_cls_loss +
                   self.lambda_rec * g_rec_loss)
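# A minimal sketch of the WGAN-GP style penalty that `self.gradient_penalty`
# above presumably computes on the interpolated samples `img_hat`, written as
# a method body. Illustrative only; the method name and the discriminator's
# (out_src, out_cls) signature are taken from the surrounding code, the rest
# is assumed.
import tensorflow as tf


def gradient_penalty(self, img_hat):
    with tf.GradientTape() as tape:
        tape.watch(img_hat)
        out_src, _ = self.discriminator(img_hat)
    grads = tape.gradient(out_src, img_hat)
    # Penalize deviations of the critic's gradient norm from 1.
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(norm - 1.0))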
def test_fashion_mnist():
    CUDA = torch.cuda.is_available()
    net = CapsNet(1, 6 * 6 * 32)
    if CUDA:
        net.cuda()
    print(net)
    print("# parameters: ", sum(param.numel() for param in net.parameters()))

    transform = transforms.Compose([transforms.ToTensor()])
    trainset = torchvision.datasets.FashionMNIST(root='./data/fashion',
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=32,
                                              shuffle=True)
    testset = torchvision.datasets.FashionMNIST(root='./data/fashion',
                                                train=False,
                                                download=True,
                                                transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=32,
                                             shuffle=False)

    optimizer = Adam(net.parameters())
    n_epochs = 30
    print_every = 200 if CUDA else 2

    for epoch in range(n_epochs):
        train_acc = 0.
        time_start = time.time()
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            labels_one_hot = torch.eye(10).index_select(dim=0, index=labels)
            if CUDA:
                inputs, labels_one_hot, labels = (inputs.cuda(),
                                                  labels_one_hot.cuda(),
                                                  labels.cuda())
            optimizer.zero_grad()
            class_probs, recons = net(inputs, labels)
            acc = torch.mean(
                (labels == torch.max(class_probs, -1)[1]).double())
            train_acc += acc.item()
            # Margin loss on the class capsules plus a down-weighted
            # reconstruction loss
            loss = (margin_loss(class_probs, labels_one_hot) +
                    0.0005 * reconstruction_loss(recons, inputs))
            loss.backward()
            optimizer.step()
            if (i + 1) % print_every == 0:
                print(
                    '[epoch {}/{}, batch {}] train_loss: {:.5f}, train_acc: {:.5f}'
                    .format(epoch + 1, n_epochs, i + 1, loss.item(),
                            acc.item()))

        test_acc = 0.
        for j, data in enumerate(testloader, 0):
            inputs, labels = data
            labels_one_hot = torch.eye(10).index_select(dim=0, index=labels)
            if CUDA:
                inputs, labels_one_hot, labels = (inputs.cuda(),
                                                  labels_one_hot.cuda(),
                                                  labels.cuda())
            class_probs, recons = net(inputs)
            acc = torch.mean(
                (labels == torch.max(class_probs, -1)[1]).double())
            test_acc += acc.item()

        print(
            '[epoch {}/{} done in {:.2f}s] train_acc: {:.5f} test_acc: {:.5f}'.
            format(epoch + 1, n_epochs, (time.time() - time_start),
                   train_acc / (i + 1), test_acc / (j + 1)))
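# A minimal sketch of the CapsNet losses used in the training loops above.
# `margin_loss` follows the standard formulation from Sabour et al. (2017);
# the margins m_pos, m_neg and the down-weight lam are assumed defaults, and
# `reconstruction_loss` is taken to be a summed squared error. Illustrative
# only; not the project's actual implementation.
import torch


def margin_loss(class_probs, labels_one_hot, m_pos=0.9, m_neg=0.1, lam=0.5):
    # Each class capsule's length should exceed m_pos for the true class
    # and stay below m_neg for every other class.
    pos = labels_one_hot * torch.clamp(m_pos - class_probs, min=0.) ** 2
    neg = lam * (1. - labels_one_hot) * \
        torch.clamp(class_probs - m_neg, min=0.) ** 2
    return (pos + neg).sum(dim=1).mean()


def reconstruction_loss(recons, inputs):
    # Summed squared error between the decoder output and the input image,
    # averaged over the batch; the 0.0005 factor in the training loop keeps
    # it from dominating the margin loss.
    diff = recons.view(recons.size(0), -1) - inputs.view(inputs.size(0), -1)
    return (diff ** 2).sum(dim=1).mean()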
def main(args=None):
    device = torch.device('cuda:0') if torch.cuda.is_available() \
        else torch.device('cpu')

    # Analysis window: rectangular by default, Hann if requested
    win = lambda x: torch.ones(x) if not args.hann else torch.hann_window(x)

    dataset = se_dataset.SEDataset(args.clean_dir,
                                   args.noisy_dir,
                                   0.95,
                                   cache_dir=args.cache_dir,
                                   slice_size=args.slice_size,
                                   max_samples=1000)
    dloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)

    G = Generator().to(device)
    D = Discriminator().to(device)
    optimizerG = torch.optim.Adam(G.parameters(), lr=1e-4)
    optimizerD = torch.optim.RMSprop(D.parameters(), lr=1e-4)

    torch.autograd.set_detect_anomaly(True)

    g_loss = []
    d_loss = []
    writer = SummaryWriter("./logs")

    train = False
    if train:
        for epoch in range(args.num_epochs):
            print("EPOCH", epoch)
            # Divide both learning rates by 10 one third of the way through training
            if epoch == (args.num_epochs // 3):
                optimizerD.param_groups[0]['lr'] /= 10
                optimizerG.param_groups[0]['lr'] /= 10
            for bidx, batch in tqdm(enumerate(dloader, 1), total=len(dloader)):
                uttname, clean, noisy, slice_idx = batch
                clean, noisy = clean.to(device), noisy.to(device)

                # Get real data
                if args.task == 'sr':
                    real_full = full_lps(clean).to(device)
                    lf = lowres_lps(clean, args.rate_div).to(device)
                else:
                    real_full = full_lps(clean).detach().clone().to(device)
                    lf = full_lps(noisy)[:, :129, :].detach().clone().to(device)

                # Update D (adversarial training starts after the first third
                # of the epochs)
                if epoch >= (args.num_epochs // 3):
                    for p in D.parameters():
                        p.requires_grad = True
                    fake_hf = G(lf).to(device)
                    fake_full = torch.cat([lf, fake_hf], 1)
                    fake_logit, fake_prob = D(fake_full)
                    real_logit, real_prob = D(real_full)
                    optimizerD.zero_grad()
                    gan_loss_d = disc_loss(fake_prob, real_prob)
                    reg_d = regularization_term(fake_prob, real_prob,
                                                fake_logit, real_logit)
                    lossD = gan_loss_d + reg_d
                    lossD.backward()
                    writer.add_scalar("loss/D_loss", lossD)
                    writer.add_scalar("loss/D_loss_reg", reg_d)
                    writer.add_scalar("loss/D_loss_gan", gan_loss_d)
                    optimizerD.step()
                    d_loss.append(lossD.item())

                # Update G
                for p in D.parameters():
                    p.requires_grad = False
                fake_hf = G(lf)
                fake_full = torch.cat([lf, fake_hf], 1)
                fake_logit, fake_prob = D(fake_full)
                real_logit, real_prob = D(real_full)
                optimizerG.zero_grad()
                if epoch >= (args.num_epochs // 3):
                    gan = gen_loss(fake_prob)
                    rec_loss = reconstruction_loss(real_full, fake_full)
                    lossG = args.lambd * rec_loss - gan
                else:
                    rec_loss = reconstruction_loss(real_full, fake_full)
                    lossG = args.lambd * rec_loss
                writer.add_scalar("loss/G_loss", lossG)
                writer.add_scalar("rec_loss/rec_loss", rec_loss)
                lossG.backward()
                optimizerG.step()
                g_loss.append(lossG.item())

        with open("result.pth", "wb") as f:
            torch.save({
                "g_state_dict": G.state_dict(),
                "d_state_dict": D.state_dict(),
            }, f)
    else:
        with open("./result.pth", "rb") as f:
            G.load_state_dict(torch.load(f)["g_state_dict"])
        for bidx, batch in tqdm(enumerate(dloader, 1), total=len(dloader)):
            uttname, clean, noisy, slice_idx = batch
            clean, noisy = clean.to(device), noisy.to(device)
            real_full = full_lps(clean).to(device)
            lf = lowres_lps(clean, args.rate_div).to(device)
            fake_hf = G(lf)
            fake_full = torch.cat([lf, fake_hf], 1)
            break
def test_cifar10():
    CUDA = torch.cuda.is_available()
    net = CapsNet(8 * 8 * 32, [3, 32, 32])
    if CUDA:
        net.cuda()
    print(net)
    print("# parameters: ", sum(param.numel() for param in net.parameters()))

    transform = transforms.Compose([transforms.ToTensor()])
    trainset = torchvision.datasets.CIFAR10(root='./data/cifar10',
                                            train=True,
                                            download=True,
                                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=32,
                                              shuffle=True)
    testset = torchvision.datasets.CIFAR10(root='./data/cifar10',
                                           train=False,
                                           download=True,
                                           transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=32,
                                             shuffle=False)

    optimizer = Adam(net.parameters())
    n_epochs = 200
    print_every = 200 if CUDA else 2
    display_every = 10

    for epoch in range(n_epochs):
        train_acc = 0.
        time_start = time.time()
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            masked_inputs = put_mask(inputs)
            labels_one_hot = torch.eye(10).index_select(dim=0, index=labels)
            if CUDA:
                masked_inputs, inputs, labels_one_hot, labels = (
                    masked_inputs.cuda(), inputs.cuda(),
                    labels_one_hot.cuda(), labels.cuda())
            optimizer.zero_grad()
            class_probs, recons = net(masked_inputs, labels)
            acc = torch.mean(
                (labels == torch.max(class_probs, -1)[1]).double())
            train_acc += acc.data.item()
            loss = (margin_loss(class_probs, labels_one_hot) +
                    0.0005 * reconstruction_loss(recons, inputs))
            loss.backward()
            optimizer.step()
            if (i + 1) % print_every == 0:
                print(
                    '[epoch {}/{}, batch {}] train_loss: {:.5f}, train_acc: {:.5f}'
                    .format(epoch + 1, n_epochs, i + 1, loss.data.item(),
                            acc.data.item()))

        test_acc = 0.
        for j, data in enumerate(testloader, 0):
            inputs, labels = data
            masked_inputs = put_mask(inputs)
            labels_one_hot = torch.eye(10).index_select(dim=0, index=labels)
            if CUDA:
                masked_inputs, inputs, labels_one_hot, labels = (
                    masked_inputs.cuda(), inputs.cuda(),
                    labels_one_hot.cuda(), labels.cuda())
            class_probs, recons = net(masked_inputs)
            if (j + 1) % display_every == 0:
                display(inputs[0].cpu(), masked_inputs[0].cpu(),
                        recons[0].cpu().detach())
            acc = torch.mean(
                (labels == torch.max(class_probs, -1)[1]).double())
            test_acc += acc.data.item()

        print(
            '[epoch {}/{} done in {:.2f}s] train_acc: {:.5f} test_acc: {:.5f}'.
            format(epoch + 1, n_epochs, (time.time() - time_start),
                   train_acc / (i + 1), test_acc / (j + 1)))
def reconstruction_loss(X, x, x_raw, W):
    # Thin wrapper around losses.reconstruction_loss; relies on `self` being
    # available from the enclosing scope in which this closure is defined.
    return losses.reconstruction_loss(X, x, x_raw, W, self.output_activation,
                                      self.D, self.I, self.eps)