    def compute_losses(self):
        """Compute losses."""
        self.reconstruction_loss_a = losses.reconstruction_loss(
            real_images=self.input_a, generated_images=self.ae_images_a)
        self.reconstruction_loss_b = losses.reconstruction_loss(
            real_images=self.input_b, generated_images=self.ae_images_b)

        self.lsgan_loss_fake_a = losses.lsgan_loss_generator(
            self.prob_fake_a_is_real)
        self.lsgan_loss_fake_b = losses.lsgan_loss_generator(
            self.prob_fake_b_is_real)

        self.cycle_consistency_loss_a = losses.cycle_consistency_loss(
            real_images=self.input_a, generated_images=self.cycle_images_a)
        self.cycle_consistency_loss_b = losses.cycle_consistency_loss(
            real_images=self.input_b, generated_images=self.cycle_images_b)

        self.g_loss = self._rec_lambda_a * self.reconstruction_loss_a + \
            self._rec_lambda_b * self.reconstruction_loss_b + \
            self._cycle_lambda_a * self.cycle_consistency_loss_a + \
            self._cycle_lambda_b * self.cycle_consistency_loss_b + \
            self._lsgan_lambda_a * self.lsgan_loss_fake_a + \
            self._lsgan_lambda_b * self.lsgan_loss_fake_b

        self.d_loss_A = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real)
        self.d_loss_B = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real)

        self.model_vars = tf.trainable_variables()

        d_a_vars = [
            var for var in self.model_vars
            if 'd1' in var.name or 'd_shared' in var.name
        ]
        d_b_vars = [
            var for var in self.model_vars
            if 'd2' in var.name or 'd_shared' in var.name
        ]
        g_vars = [
            var for var in self.model_vars if 'ae1' in var.name
            or 'ae2' in var.name or 'ae_shared' in var.name
        ]

        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

        self.d_A_trainer = optimizer.minimize(self.d_loss_A, var_list=d_a_vars)
        self.d_B_trainer = optimizer.minimize(self.d_loss_B, var_list=d_b_vars)
        self.g_trainer = optimizer.minimize(self.g_loss, var_list=g_vars)

        self.create_summaries()
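
The `losses` module used above is not shown in this snippet. As a rough sketch, these helpers typically reduce to L1 reconstruction/cycle terms and least-squares GAN objectives (an assumption based on common CycleGAN/UNIT-style implementations, not this repository's actual code):

import tensorflow as tf


def reconstruction_loss(real_images, generated_images):
    # L1 distance between the input and its autoencoder reconstruction
    return tf.reduce_mean(tf.abs(real_images - generated_images))


def cycle_consistency_loss(real_images, generated_images):
    # L1 distance between the input and its cycle-reconstructed version
    return tf.reduce_mean(tf.abs(real_images - generated_images))


def lsgan_loss_generator(prob_fake_is_real):
    # Least-squares GAN generator loss: push D's output on fakes towards 1
    return tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 1.0))


def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
    # Least-squares GAN discriminator loss: reals towards 1, pooled fakes towards 0
    return 0.5 * (tf.reduce_mean(tf.squared_difference(prob_real_is_real, 1.0)) +
                  tf.reduce_mean(tf.square(prob_fake_is_real)))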
Example #2
    def _train_generator_step(self, real_img, noise, img_from_prev_scale,
                              rec_from_prev_scale):
        # Convert the generated image into a batch with a single element
        # (so it can be fed to the critic and the loss functions)

        with tf.GradientTape() as tape:
            fake_img = self._generate_image_for_training(
                noise, img_from_prev_scale)
            batch_fake_img = fake_img[np.newaxis, :, :, :]

            # Get the discriminator logits for fake patches
            fake_logits = self.critic(batch_fake_img, training=True)
            # Calculate the generator adversarial loss
            adv_loss = generator_wass_loss(fake_logits)
            # Calculate the reconstruction loss and update the noise sigma with it
            reconstructed = self._reconstruct_image_for_training(
                rec_from_prev_scale)
            rec_loss = reconstruction_loss(reconstructed, real_img)

            loss = adv_loss + self.rec_loss_weight * rec_loss
        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables))

        return adv_loss, self.rec_loss_weight * rec_loss
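
`generator_wass_loss` and `reconstruction_loss` are defined elsewhere in this project; a minimal sketch, assuming the usual Wasserstein generator objective and an MSE reconstruction term (the real definitions may differ):

import tensorflow as tf


def generator_wass_loss(fake_logits):
    # Wasserstein generator loss: maximize the critic's score on generated patches
    return -tf.reduce_mean(fake_logits)


def reconstruction_loss(reconstructed, real_img):
    # Mean squared error between the reconstructed and the real image
    return tf.reduce_mean(tf.square(reconstructed - real_img))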
Example #3
def test_step(model, config, inputs, logger=None):
    """Reconstruction Test step during training."""
    outputs = inputs.clone().detach()

    with torch.no_grad():
        (preds, priors, posteriors), stored_vars = model(
            inputs,
            config,
            False,
        )

        # Accumulate preds and select targets
        targets = outputs[:, config['n_ctx']:]

        # Compute the reconstruction and prior loss
        loss_rec = losses.reconstruction_loss(config, preds, targets)
        if config['beta'] > 0:
            loss_prior = losses.kl_loss(config, priors, posteriors)
            loss = loss_rec + config['beta'] * loss_prior
        else:
            loss = loss_rec

    # Logs
    if logger is not None:
        logger.scalar('test_loss_rec', loss_rec.item())
        logger.scalar('test_loss', loss.item())
        if config['beta'] > 0:
            logger.scalar('test_loss_prior', loss_prior.item())
Example #4
    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss_val = reconstruction_loss(data, reconstruction)
            kl_loss = kl_divergence(z_mean, z_var)
            total_loss = reconstruction_loss_val + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss_val)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            'loss': self.total_loss_tracker.result(),
            'reconstruction_loss': self.reconstruction_loss_tracker.result(),
            'kl_loss': self.kl_loss_tracker.result()
        }
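
`reconstruction_loss` and `kl_divergence` are not shown here; a sketch following the standard Keras VAE formulation, assuming `z_var` is actually a log-variance as in the usual example code:

import tensorflow as tf


def reconstruction_loss(data, reconstruction):
    # Per-pixel binary cross-entropy, summed over the image and averaged over the batch
    bce = tf.keras.losses.binary_crossentropy(data, reconstruction)
    return tf.reduce_mean(tf.reduce_sum(bce, axis=(1, 2)))


def kl_divergence(z_mean, z_log_var):
    # Analytic KL between the approximate posterior and a standard normal prior
    kl = -0.5 * (1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
    return tf.reduce_mean(tf.reduce_sum(kl, axis=1))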
Example #5
def train_step(model, config, inputs, optimizer, batch_idx, logger=None):
    """Training step for the model."""

    outputs = inputs.clone().detach()

    # Forward pass
    (preds, priors, posteriors), stored_vars = model(inputs, config, False)

    # Accumulate preds and select targets
    targets = outputs[:, config['n_ctx']:]

    # Compute the reconstruction loss
    loss_rec = losses.reconstruction_loss(config, preds, targets)

    # Compute the prior loss
    if config['beta'] > 0:
        loss_prior = losses.kl_loss(config, priors, posteriors)
        loss = loss_rec + config['beta'] * loss_prior
    else:
        loss_prior = 0.
        loss = loss_rec

    # Backward pass and optimizer step
    optimizer.zero_grad()
    if config['apex']:
        from apex import amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()

    # Logs
    if logger is not None:
        logger.scalar('train_loss_rec', loss_rec.item())
        logger.scalar('train_loss', loss.item())
        if config['beta'] > 0:
            logger.scalar('train_loss_prior', loss_prior.item())

    return preds, targets, priors, posteriors, loss_rec, loss_prior, loss, stored_vars
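
In this codebase the loss helpers take the config dict plus the model outputs. A hedged sketch of what they might compute, assuming `preds`/`targets` are tensors and `priors`/`posteriors` are lists of `torch.distributions` objects (the signatures are taken from the calls above; the bodies are guesses):

import torch
import torch.nn.functional as F
from torch.distributions import kl_divergence


def reconstruction_loss(config, preds, targets):
    # Mean squared error between predicted and target frames
    return F.mse_loss(preds, targets)


def kl_loss(config, priors, posteriors):
    # Average KL divergence between posterior and prior distributions, per time step
    kls = [kl_divergence(q, p).mean() for q, p in zip(posteriors, priors)]
    return torch.stack(kls).mean()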
Example #6
    def compute_losses(self, img_real, label_org, label_trg):
        # discriminator loss
        out_src_real, out_cls_real = self.discriminator(img_real)
        img_fake = self.generator([img_real, label_trg])
        out_src_fake, out_cls_fake = self.discriminator(img_fake)

        d_adv_loss = -adversarial_loss(out_src_fake, out_src_real)
        d_cls_loss = classification_loss(label_org, out_cls_real)

        alpha = tf.random.uniform(shape=[img_real.shape.as_list()[0], 1, 1, 1])
        img_hat = alpha * img_real + (1 - alpha) * img_fake
        d_gp_loss = self.gradient_penalty(img_hat)

        self.d_loss = d_adv_loss + self.lambda_cls * d_cls_loss + self.lambda_gp * d_gp_loss

        # generator loss
        g_adv_loss = adversarial_loss(out_src_fake)
        g_cls_loss = classification_loss(label_trg, out_cls_fake)

        img_rec = self.generator([img_fake, label_org])
        g_rec_loss = reconstruction_loss(img_real, img_rec)

        self.g_loss = g_adv_loss + self.lambda_cls * g_cls_loss + self.lambda_rec * g_rec_loss
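
`gradient_penalty` is a method of the same class; a free-function sketch of the WGAN-GP penalty it presumably implements, together with the L1 reconstruction term used for `g_rec_loss` (both assumptions based on the StarGAN recipe, not this repository's code):

import tensorflow as tf


def gradient_penalty(discriminator, img_hat):
    # WGAN-GP: penalize deviations of the critic's gradient norm from 1 on interpolates
    with tf.GradientTape() as tape:
        tape.watch(img_hat)
        out_src, _ = discriminator(img_hat)
    grads = tape.gradient(out_src, img_hat)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(norm - 1.0))


def reconstruction_loss(img_real, img_rec):
    # L1 cycle-reconstruction loss between the input and the re-translated image
    return tf.reduce_mean(tf.abs(img_real - img_rec))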
Example #7
def test_fashion_mnist():
    CUDA = torch.cuda.is_available()
    net = CapsNet(1, 6 * 6 * 32)
    if CUDA:
        net.cuda()
    print(net)
    print("# parameters: ", sum(param.numel() for param in net.parameters()))

    transform = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.FashionMNIST(root='./data/fashion',
                                                 train=True,
                                                 download=True,
                                                 transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=32,
                                              shuffle=True)

    testset = torchvision.datasets.FashionMNIST(root='./data/fashion',
                                                train=False,
                                                download=True,
                                                transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=32,
                                             shuffle=False)
    optimizer = Adam(net.parameters())

    n_epochs = 30
    print_every = 200 if CUDA else 2

    for epoch in range(n_epochs):
        train_acc = 0.
        time_start = time.time()
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            labels_one_hot = torch.eye(10).index_select(dim=0, index=labels)
            inputs, labels_one_hot, labels = Variable(inputs), Variable(
                labels_one_hot), Variable(labels)
            if CUDA:
                inputs, labels_one_hot, labels = inputs.cuda(
                ), labels_one_hot.cuda(), labels.cuda()
            optimizer.zero_grad()
            class_probs, recons = net(inputs, labels)
            acc = torch.mean((labels == torch.max(class_probs,
                                                  -1)[1]).double())
            train_acc += acc.item()
            loss = (margin_loss(class_probs, labels_one_hot) +
                    0.0005 * reconstruction_loss(recons, inputs))
            loss.backward()
            optimizer.step()
            if (i + 1) % print_every == 0:
                print(
                    '[epoch {}/{}, batch {}] train_loss: {:.5f}, train_acc: {:.5f}'
                    .format(epoch + 1, n_epochs, i + 1, loss.item(),
                            acc.item()))
        test_acc = 0.
        for j, data in enumerate(testloader, 0):
            inputs, labels = data
            labels_one_hot = torch.eye(10).index_select(dim=0, index=labels)
            inputs, labels_one_hot, labels = Variable(inputs), Variable(
                labels_one_hot), Variable(labels)
            if CUDA:
                inputs, labels_one_hot, labels = inputs.cuda(
                ), labels_one_hot.cuda(), labels.cuda()
            class_probs, recons = net(inputs)
            acc = torch.mean((labels == torch.max(class_probs,
                                                  -1)[1]).double())
            test_acc += acc.item()
        print(
            '[epoch {}/{} done in {:.2f}s] train_acc: {:.5f} test_acc: {:.5f}'.
            format(epoch + 1, n_epochs, (time.time() - time_start),
                   train_acc / (i + 1), test_acc / (j + 1)))
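
`margin_loss` and `reconstruction_loss` here are the usual CapsNet objectives; a sketch following the formulation in "Dynamic Routing Between Capsules" (assumed, not copied from this repository):

import torch
import torch.nn.functional as F


def margin_loss(class_probs, labels_one_hot, m_pos=0.9, m_neg=0.1, lam=0.5):
    # Margin loss: present classes pushed above m_pos, absent ones below m_neg
    pos = labels_one_hot * torch.clamp(m_pos - class_probs, min=0.) ** 2
    neg = lam * (1. - labels_one_hot) * torch.clamp(class_probs - m_neg, min=0.) ** 2
    return (pos + neg).sum(dim=1).mean()


def reconstruction_loss(recons, inputs):
    # Summed squared error between the reconstruction and the original image
    return F.mse_loss(recons, inputs.view_as(recons), reduction='sum') / inputs.size(0)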
Example #8
def main(args=None):
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device('cpu')
    # shutil.rmtree("./data")

    win = lambda x: torch.ones(x) if not args.hann else torch.hann_window(x)

    dataset = se_dataset.SEDataset(args.clean_dir, args.noisy_dir,
                                   0.95,
                                   cache_dir=args.cache_dir,
                                   slice_size=args.slice_size,
                                   max_samples=1000)
    dloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)

    G = Generator().to(device)
    D = Discriminator().to(device)
    optimizerG = torch.optim.Adam(G.parameters(), lr=1e-4)
    optimizerD = torch.optim.RMSprop(D.parameters(), lr=1e-4)


    torch.autograd.set_detect_anomaly(True)

    g_loss = []
    d_loss = []

    writer = SummaryWriter("./logs")
    train = False
    if train:
        for epoch in range(args.num_epochs):
            print("EPOCH", epoch)
            if epoch == (args.num_epochs // 3):
                optimizerD.param_groups[0]['lr'] = optimizerD.param_groups[0]['lr'] / 10
                optimizerG.param_groups[0]['lr'] = optimizerG.param_groups[0]['lr'] / 10
            
            for bidx, batch in tqdm(enumerate(dloader, 1), total=len(dloader)):
                uttname, clean, noisy, slice_idx = batch
                clean, noisy = clean.to(device), noisy.to(device)

                # Get real data
                if args.task == 'sr':
                    real_full = full_lps(clean).to(device)
                    lf = lowres_lps(clean, args.rate_div).to(device)
                else:
                    real_full = full_lps(clean).detach().clone().to(device)
                    lf = full_lps(noisy)[:,:129,:].detach().clone().to(device)

                # Update D
                if epoch >= (args.num_epochs // 3):
                    for p in D.parameters():
                        p.requires_grad = True
                    fake_hf = G(lf).to(device)
                    fake_full = torch.cat([lf, fake_hf], 1)
                    fake_logit, fake_prob = D(fake_full)
                    real_logit, real_prob = D(real_full)

                    optimizerD.zero_grad()
                    gan_loss_d = disc_loss(fake_prob, real_prob)
                    reg_d = regularization_term(fake_prob, real_prob, fake_logit, real_logit)
                    lossD = gan_loss_d + reg_d
                    lossD.backward()
                    writer.add_scalar("loss/D_loss", lossD)
                    writer.add_scalar("loss/D_loss_reg", reg_d)
                    writer.add_scalar("loss/D_loss_gan", gan_loss_d)
                    optimizerD.step()
                    d_loss.append(lossD.item())

                # Update G
                for p in D.parameters():
                    p.requires_grad = False

                fake_hf = G(lf)
                fake_full = torch.cat([lf, fake_hf], 1)
                fake_logit, fake_prob = D(fake_full)
                real_logit, real_prob = D(real_full)
                gan = None
                if epoch >= (args.num_epochs // 3):
                    gan = gen_loss(fake_prob)
                    rec_loss = reconstruction_loss(real_full, fake_full)
                    lossG = args.lambd * rec_loss - gan
                else:
                    rec_loss = reconstruction_loss(real_full, fake_full)
                    lossG = args.lambd * rec_loss
                writer.add_scalar("loss/G_loss", lossG)
                writer.add_scalar("rec_loss/rec_loss", rec_loss)
                optimizerG.zero_grad()
                lossG.backward()
                optimizerG.step()
                g_loss.append(lossG.item())
            with open("result.pth", "wb") as f:
                torch.save({
                    "g_state_dict": G.state_dict(),
                    "d_state_dict": D.state_dict(),
                }, f)
    else:
        with open("./result.pth", "rb") as f:
            G.load_state_dict(torch.load(f)["g_state_dict"])
        for bidx, batch in tqdm(enumerate(dloader, 1), total=len(dloader)):
            uttname, clean, noisy, slice_idx = batch
            clean, noisy = clean.to(device), noisy.to(device)
            real_full = full_lps(clean).to(device)
            lf = lowres_lps(clean, args.rate_div).to(device)
            fake_hf = G(lf)
            fake_full = torch.cat([lf, fake_hf], 1)
            break
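
`reconstruction_loss(real_full, fake_full)` on the log-power spectra is not shown; a minimal sketch, assuming a plain MSE between real and generated spectrogram frames:

import torch.nn.functional as F


def reconstruction_loss(real_full, fake_full):
    # MSE between the real and generated full-band log-power spectra
    return F.mse_loss(fake_full, real_full)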
Example #9
def test_cifar10():
    CUDA = torch.cuda.is_available()
    net = CapsNet(8 * 8 * 32, [3, 32, 32])
    if CUDA:
        net.cuda()
    print(net)
    print("# parameters: ", sum(param.numel() for param in net.parameters()))

    transform = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.CIFAR10(root='./data/cifar10',
                                            train=True,
                                            download=True,
                                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=32,
                                              shuffle=True)

    testset = torchvision.datasets.CIFAR10(root='./data/cifar10',
                                           train=False,
                                           download=True,
                                           transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=32,
                                             shuffle=False)
    optimizer = Adam(net.parameters())

    n_epochs = 200
    print_every = 200 if CUDA else 2
    display_every = 10

    for epoch in range(n_epochs):
        train_acc = 0.
        time_start = time.time()
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            masked_inputs = put_mask(inputs)
            labels_one_hot = torch.eye(10).index_select(dim=0, index=labels)
            if CUDA:
                masked_inputs, inputs, labels_one_hot, labels = masked_inputs.cuda(
                ), inputs.cuda(), labels_one_hot.cuda(), labels.cuda()
            optimizer.zero_grad()
            class_probs, recons = net(masked_inputs, labels)
            acc = torch.mean((labels == torch.max(class_probs,
                                                  -1)[1]).double())
            train_acc += acc.data.item()
            loss = (margin_loss(class_probs, labels_one_hot) +
                    0.0005 * reconstruction_loss(recons, inputs))
            loss.backward()
            optimizer.step()
            if (i + 1) % print_every == 0:
                print(
                    '[epoch {}/{}, batch {}] train_loss: {:.5f}, train_acc: {:.5f}'
                    .format(epoch + 1, n_epochs, i + 1, loss.data.item(),
                            acc.data.item()))
        test_acc = 0.
        for j, data in enumerate(testloader, 0):
            inputs, labels = data
            masked_inputs = put_mask(inputs)
            labels_one_hot = torch.eye(10).index_select(dim=0, index=labels)
            if CUDA:
                masked_inputs, inputs, labels_one_hot, labels = masked_inputs.cuda(
                ), inputs.cuda(), labels_one_hot.cuda(), labels.cuda()
            class_probs, recons = net(masked_inputs)
            if (j + 1) % display_every == 0:
                display(inputs[0].cpu(), masked_inputs[0].cpu(),
                        recons[0].cpu().detach())
            acc = torch.mean((labels == torch.max(class_probs,
                                                  -1)[1]).double())
            test_acc += acc.data.item()
        print(
            '[epoch {}/{} done in {:.2f}s] train_acc: {:.5f} test_acc: {:.5f}'.
            format(epoch + 1, n_epochs, (time.time() - time_start),
                   train_acc / (i + 1), test_acc / (j + 1)))
Example #10
    def reconstruction_loss(X, x, x_raw, W):
        return losses.reconstruction_loss(X, x, x_raw, W, self.output_activation,
                                          self.D, self.I, self.eps)