Example #1
0
def main(FLAGS):
    """Project MNIST test-set class latents to 2-D with t-SNE and save them.

    Builds the encoder/decoder, optionally restores their weights from
    ``checkpoints/``, encodes every MNIST test image, and groups the
    flattened class latent codes by label. Labels go 0-9 in ascending order,
    with dataset order kept within each label. It then fits a 2-component
    t-SNE on the codes and writes the embedding to ``s_2d.npz`` under the
    key ``s_2d``.

    Args:
        FLAGS: namespace providing ``style_dim``, ``class_dim``,
            ``load_saved``, ``encoder_save`` and ``decoder_save``.
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # Use the GPU when present instead of failing on CPU-only machines.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # load saved models if load_saved flag is true; map_location lets
    # checkpoints written on a GPU be restored on whichever device is in use.
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save),
                       map_location=device))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save),
                       map_location=device))

    decoder.to(device)
    encoder.to(device)
    # Pure inference: eval mode makes any BatchNorm/Dropout layers use their
    # inference behaviour and stops BatchNorm running statistics changing.
    encoder.eval()
    decoder.eval()

    tsne = TSNE(2)

    # Encode in batches rather than one image at a time; the grouping below
    # handles any batch size, including a smaller final batch.
    mnist = DataLoader(
        datasets.MNIST(root='mnist',
                       download=True,
                       train=False,
                       transform=transform_config),
        batch_size=100)

    s_dict = {}
    with torch.no_grad():
        for i, (images, labels) in enumerate(mnist):
            print('Encoding batch {}/{}'.format(i + 1, len(mnist)))
            _, _, class_latent = encoder(images.to(device))
            # Flatten each sample's class code and keep it on the CPU for t-SNE.
            class_latent = class_latent.reshape(class_latent.shape[0], -1).cpu()
            # Boolean masking preserves dataset order within each label.
            for label in labels.unique().tolist():
                s_dict.setdefault(label, []).append(
                    class_latent[labels == label])

    # Concatenate in label order 0..9; labels absent from the data are skipped.
    s_all = torch.cat([
        codes for label in range(10) for codes in s_dict.get(label, [])
    ])

    s_2d = tsne.fit_transform(s_all.numpy())

    np.savez('s_2d.npz', s_2d=s_2d)
Example #2
0
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on paired MNIST images.

    Every iteration performs two updates:

    A. Forward cycle. Encode an image pair from ``MNIST_Paired`` (presumably
       both images share a digit class, since their class latents are
       swapped). Decode each image's style with its partner's class latent.
       Minimise the reconstruction error plus a KL penalty on the style
       posterior. Encoder and decoder are updated.
    B. Reverse cycle. Decode one N(0, I) style sample with the class latents
       of two images. Re-encode both outputs and minimise the L1 distance
       between the recovered style codes. Only the encoder is updated.

    Losses are printed every 10 iterations, appended to ``FLAGS.log_file``
    and written to TensorBoard. Every 5 epochs, and after the last epoch,
    the weights are saved under ``checkpoints/`` and sample reconstructions
    are rendered with ``imshow_grid``.

    Args:
        FLAGS: namespace of hyper-parameters (latent sizes, image geometry,
            learning rate and Adam betas, loss coefficients, epoch range,
            checkpoint/log file names, ``cuda`` and ``load_saved`` switches).
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # Pre-allocated input buffers; each batch is copied into them with copy_.
    image_shape = (FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size,
                   FLAGS.image_size)
    X_1 = torch.FloatTensor(*image_shape)
    X_2 = torch.FloatTensor(*image_shape)
    X_3 = torch.FloatTensor(*image_shape)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        style_latent_space = style_latent_space.cuda()

    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))
    # The reverse cycle only updates the encoder.
    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # Divide the learning rate by a factor of 10 after 80 epochs. The
    # schedulers are stepped once at the END of every epoch, after the
    # optimizer steps, which is the order PyTorch >= 1.1 expects.
    schedulers = (
        optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80,
                                  gamma=0.1),
        optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80,
                                  gamma=0.1),
    )
    # When resuming at start_epoch > 0, replay the finished epochs so the
    # learning rate matches the resumed epoch. PyTorch may warn that the
    # scheduler stepped before the optimizer; that is expected here.
    for _ in range(FLAGS.start_epoch):
        for scheduler in schedulers:
            scheduler.step()

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\tReverse_cycle_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    iterations_per_epoch = len(paired_mnist) // FLAGS.batch_size
    # The "+ 1" leaves a one-step gap between epochs; it is kept so
    # TensorBoard step numbers stay comparable with earlier runs.
    tensorboard_stride = iterations_per_epoch + 1
    # The KL terms are normalised by the number of input values in a batch.
    values_per_batch = (FLAGS.batch_size * FLAGS.num_channels *
                        FLAGS.image_size * FLAGS.image_size)

    def kl_divergence(mu, logvar):
        """Scaled, per-value-normalised KL(N(mu, exp(logvar)) || N(0, I))."""
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return FLAGS.kl_divergence_coef * kl / values_per_batch

    def save_grid(batch, name):
        """Render an NCHW tensor batch with imshow_grid as NHWC images."""
        images = np.transpose(batch.cpu().numpy(), (0, 2, 3, 1))
        if FLAGS.num_channels == 1:
            # replicate grayscale to 3 channels before plotting
            images = np.concatenate((images, images, images), axis=3)
        imshow_grid(images, name=name, save=True)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(X_1)
            style_latent_space_1 = reparameterize(training=True,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)
            kl_divergence_loss_1 = kl_divergence(style_mu_1, style_logvar_1)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(X_2)
            style_latent_space_2 = reparameterize(training=True,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)
            kl_divergence_loss_2 = kl_divergence(style_mu_2, style_logvar_2)
            kl_divergence_loss_2.backward(retain_graph=True)

            # Decode each image's style with its partner's class latent.
            reconstructed_X_1 = decoder(style_latent_space_1,
                                        class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2,
                                        class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_1, X_1)
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_2, X_2)
            reconstruction_error_2.backward()

            # Unscaled values, for reporting only.
            reconstruction_error = (
                reconstruction_error_1 +
                reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2
                                   ) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(X_1)
            _, __, class_latent_space_2 = encoder(X_2)

            reconstructed_X_1 = decoder(style_latent_space,
                                        class_latent_space_1.detach())
            reconstructed_X_2 = decoder(style_latent_space,
                                        class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(
                style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()

            reverse_cycle_optimizer.step()

            # Loss values in print/log/TensorBoard order.
            scalars = [
                ('Reconstruction loss', reconstruction_error.item()),
                ('KL-Divergence loss', kl_divergence_error.item()),
                ('Reverse cycle loss',
                 (reverse_cycle_loss / FLAGS.reverse_cycle_coef).item()),
            ]

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                for name, value in scalars:
                    print(name + ': ' + str(value))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = [epoch, iteration] + [value for _, value in scalars]
                log.write('\t'.join(str(x) for x in row) + '\n')

            # write to tensorboard
            global_step = epoch * tensorboard_stride + iteration
            for name, value in scalars:
                writer.add_scalar(name, value, global_step)

        # update the learning rate schedulers once the epoch is finished
        for scheduler in schedulers:
            scheduler.step()

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))

            # Save reconstructed images and style-swapped generations to
            # check progress; no gradients are needed for this.
            with torch.no_grad():
                image_batch_1, image_batch_2, _ = next(loader)
                image_batch_3, _, __ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                style_mu_1, style_logvar_1, _ = encoder(X_1)
                _, __, class_latent_space_2 = encoder(X_2)
                style_mu_3, style_logvar_3, _ = encoder(X_3)

                style_latent_space_1 = reparameterize(training=False,
                                                      mu=style_mu_1,
                                                      logvar=style_logvar_1)
                style_latent_space_3 = reparameterize(training=False,
                                                      mu=style_mu_3,
                                                      logvar=style_logvar_3)

                reconstructed_X_1_2 = decoder(style_latent_space_1,
                                              class_latent_space_2)
                reconstructed_X_3_2 = decoder(style_latent_space_3,
                                              class_latent_space_2)

                save_grid(X_1, str(epoch) + '_original')
                save_grid(reconstructed_X_1_2, str(epoch) + '_target')
                save_grid(X_3, str(epoch) + '_style')
                save_grid(reconstructed_X_3_2, str(epoch) + '_style_target')
Example #3
0
def training_procedure(FLAGS):
    """Train the CIFAR cycle-consistent VAE, optionally with adversarial losses.

    Each iteration runs:

    A.   Forward cycle. Encode an image pair from ``CIFAR_Paired`` and decode
         each image's style with its partner's class latent. Minimise the
         reconstruction error plus a KL penalty. Encoder and decoder are
         updated.
    A-1. If ``FLAGS.forward_gan``: one generator (decoder) update and one
         discriminator update on freshly decoded forward-cycle images.
    B.   Reverse cycle. Decode an N(0, I) style sample with the class latents
         of two images, re-encode the outputs, and minimise the L1 distance
         between the recovered style codes. The encoder is updated.
    B-1. If ``FLAGS.reverse_gan``: the same adversarial updates on freshly
         decoded reverse-cycle images.

    Losses are printed every 10 iterations, appended to ``FLAGS.log_file``
    and written to TensorBoard. Every 5 epochs, and after the last one, the
    encoder/decoder weights are saved under ``checkpoints/``. A batch of
    sample images drawn once before training is then rendered with
    ``imshow_grid``.

    Args:
        FLAGS: namespace of hyper-parameters (latent sizes, image geometry,
            learning rate and Adam betas, loss coefficients, ``forward_gan``
            / ``reverse_gan`` switches, epoch range, file names, ``cuda``).

    Raises:
        NotImplementedError: if ``FLAGS.load_saved`` is set (resuming is not
            supported; note the discriminator is never checkpointed).
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    if FLAGS.load_saved:
        raise NotImplementedError('This is not implemented')

    # Pre-allocated input buffers; each batch is copied into them with copy_.
    image_shape = (FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size,
                   FLAGS.image_size)
    X_1 = torch.FloatTensor(*image_shape)
    X_2 = torch.FloatTensor(*image_shape)
    X_3 = torch.FloatTensor(*image_shape)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    adversarial_loss = nn.BCELoss()

    # Discriminator targets (1 = real, 0 = generated); constant, so built once.
    valid = torch.ones(FLAGS.batch_size, 1)
    fake = torch.zeros(FLAGS.batch_size, 1)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()
        style_latent_space = style_latent_space.cuda()
        valid = valid.cuda()
        fake = fake.cuda()

    adam_kwargs = dict(lr=FLAGS.initial_learning_rate,
                       betas=(FLAGS.beta_1, FLAGS.beta_2))
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()), **adam_kwargs)
    reverse_cycle_optimizer = optim.Adam(list(encoder.parameters()),
                                         **adam_kwargs)
    generator_optimizer = optim.Adam(list(decoder.parameters()), **adam_kwargs)
    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         **adam_kwargs)

    # Divide the learning rate by a factor of 10 after 80 epochs. The
    # schedulers are stepped once at the END of every epoch, after the
    # optimizer steps, which is the order PyTorch >= 1.1 expects.
    schedulers = tuple(
        optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.1)
        for optimizer in (auto_encoder_optimizer, reverse_cycle_optimizer,
                          generator_optimizer, discriminator_optimizer))
    # When starting at start_epoch > 0, replay the earlier epochs so the
    # learning rate matches. PyTorch may warn that the scheduler stepped
    # before the optimizer; that is expected here.
    for _ in range(FLAGS.start_epoch):
        for scheduler in schedulers:
            scheduler.step()

    def adversarial_step(real_1, real_2, fake_1, fake_2):
        """Run one generator update, then one discriminator update.

        ``fake_1``/``fake_2`` must still be attached to the decoder's graph
        so the generator loss can reach the decoder parameters.

        Returns:
            (generator_loss, discriminator_loss) tensors.
        """
        # Training generator (decoder): make the discriminator call fakes real.
        generator_optimizer.zero_grad()
        g_loss_1 = adversarial_loss(discriminator(fake_1), valid)
        g_loss_2 = adversarial_loss(discriminator(fake_2), valid)
        gen_loss = (g_loss_1 + g_loss_2) / 2.0
        gen_loss.backward()
        generator_optimizer.step()

        # Training discriminator. The grads that gen_loss left on the
        # discriminator are cleared here; the fakes are detached so this loss
        # does not reach the decoder.
        discriminator_optimizer.zero_grad()
        real_loss_1 = adversarial_loss(discriminator(real_1), valid)
        real_loss_2 = adversarial_loss(discriminator(real_2), valid)
        fake_loss_1 = adversarial_loss(discriminator(fake_1.detach()), fake)
        fake_loss_2 = adversarial_loss(discriminator(fake_2.detach()), fake)
        dis_loss = (real_loss_1 + real_loss_2 + fake_loss_1 +
                    fake_loss_2) / 4.0
        dis_loss.backward()
        discriminator_optimizer.step()
        return gen_loss, dis_loss

    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss',
                       'KL_divergence_loss', 'Reverse_cycle_loss']

            if FLAGS.forward_gan:
                headers.extend(['Generator_forward_loss',
                                'Discriminator_forward_loss'])

            if FLAGS.reverse_gan:
                headers.extend(['Generator_reverse_loss',
                                'Discriminator_reverse_loss'])

            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True,
                                transform=transform_config)
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size,
                              shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    iterations_per_epoch = len(paired_cifar) // FLAGS.batch_size
    # The "+ 1" leaves a one-step gap between epochs; it is kept so
    # TensorBoard step numbers stay comparable with earlier runs.
    tensorboard_stride = iterations_per_epoch + 1
    # The KL terms are normalised by the number of input values in a batch.
    values_per_batch = (FLAGS.batch_size * FLAGS.num_channels *
                        FLAGS.image_size * FLAGS.image_size)

    def kl_divergence(mu, logvar):
        """Scaled, per-value-normalised KL(N(mu, exp(logvar)) || N(0, I))."""
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return FLAGS.kl_divergence_coef * kl / values_per_batch

    def save_grid(batch, name):
        """Render an NCHW tensor batch with imshow_grid as NHWC images."""
        images = np.transpose(batch.cpu().numpy(), (0, 2, 3, 1))
        if FLAGS.num_channels == 1:
            # replicate grayscale to 3 channels before plotting
            images = np.concatenate((images, images, images), axis=3)
        imshow_grid(images, name=name, save=True)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(X_1)
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1,
                                                  logvar=style_logvar_1)
            kl_divergence_loss_1 = kl_divergence(style_mu_1, style_logvar_1)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(X_2)
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2,
                                                  logvar=style_logvar_2)
            kl_divergence_loss_2 = kl_divergence(style_mu_2, style_logvar_2)
            kl_divergence_loss_2.backward(retain_graph=True)

            # Decode each image's style with its partner's class latent.
            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, X_1)
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, X_2)
            reconstruction_error_2.backward()

            # Unscaled values, for reporting only.
            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Adversarial training during forward cycle.
            if FLAGS.forward_gan:
                # Re-decode from detached latents. The reconstructions above
                # cannot carry generator gradients: their graph was freed by
                # backward() and the weights were changed in place by
                # auto_encoder_optimizer.step(). Wrapping them in Variable()
                # would only detach them, leaving the decoder untrained.
                gen_fake_1 = decoder(style_latent_space_1.detach(),
                                     class_latent_space_2.detach())
                gen_fake_2 = decoder(style_latent_space_2.detach(),
                                     class_latent_space_1.detach())
                gen_f_loss, dis_f_loss = adversarial_step(X_1, X_2, gen_fake_1, gen_fake_2)

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(X_1)
            _, __, class_latent_space_2 = encoder(X_2)

            reconstructed_X_1 = decoder(style_latent_space, class_latent_space_1.detach())
            reconstructed_X_2 = decoder(style_latent_space, class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1,
                                                  logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2,
                                                  logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1,
                                                                    style_latent_space_2)
            reverse_cycle_loss.backward()

            reverse_cycle_optimizer.step()

            # B-1. Adversarial training during reverse cycle.
            if FLAGS.reverse_gan:
                # Fresh decoder pass for the same reason as in A-1: the graph
                # above was consumed by reverse_cycle_loss.backward().
                gen_fake_1 = decoder(style_latent_space, class_latent_space_1.detach())
                gen_fake_2 = decoder(style_latent_space, class_latent_space_2.detach())
                gen_r_loss, dis_r_loss = adversarial_step(X_1, X_2, gen_fake_1, gen_fake_2)

            # Loss values in print/log/TensorBoard order.
            scalars = [
                ('Reconstruction loss', reconstruction_error.item()),
                ('KL-Divergence loss', kl_divergence_error.item()),
                ('Reverse cycle loss', (reverse_cycle_loss / FLAGS.reverse_cycle_coef).item()),
            ]
            if FLAGS.forward_gan:
                scalars.append(('Generator F loss', gen_f_loss.item()))
                scalars.append(('Discriminator F loss', dis_f_loss.item()))
            if FLAGS.reverse_gan:
                scalars.append(('Generator R loss', gen_r_loss.item()))
                scalars.append(('Discriminator R loss', dis_r_loss.item()))

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                for name, value in scalars:
                    print(name + ': ' + str(value))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = [epoch, iteration] + [value for _, value in scalars]
                log.write('\t'.join(str(x) for x in row) + '\n')

            # write to tensorboard
            global_step = epoch * tensorboard_stride + iteration
            for name, value in scalars:
                writer.add_scalar(name, value, global_step)

        # update the learning rate schedulers once the epoch is finished
        for scheduler in schedulers:
            scheduler.step()

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            # Save reconstructed images and style-swapped generations of the
            # fixed sample batch to check progress; no gradients needed.
            with torch.no_grad():
                X_1.copy_(image_sample_1)
                X_2.copy_(image_sample_2)
                X_3.copy_(image_sample_3)

                style_mu_1, style_logvar_1, _ = encoder(X_1)
                _, __, class_latent_space_2 = encoder(X_2)
                style_mu_3, style_logvar_3, _ = encoder(X_3)

                style_latent_space_1 = reparameterize(training=False, mu=style_mu_1,
                                                      logvar=style_logvar_1)
                style_latent_space_3 = reparameterize(training=False, mu=style_mu_3,
                                                      logvar=style_logvar_3)

                reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
                reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

                save_grid(X_1, str(epoch) + '_original')
                save_grid(reconstructed_X_1_2, str(epoch) + '_target')
                save_grid(X_3, str(epoch) + '_style')
                save_grid(reconstructed_X_3_2, str(epoch) + '_style_target')
Example #4
0
def training_procedure(FLAGS):
    """Train a style/class latent auto-encoder on the MNIST training split.

    Builds ``Encoder``/``Decoder`` networks, optionally restores them from
    ``checkpoints/``, and trains them for epochs ``FLAGS.start_epoch`` up to
    (but not including) ``FLAGS.end_epoch``.  Every iteration back-propagates
    three terms -- a style KL divergence, a class KL divergence computed on
    label-grouped evidence, and a reconstruction loss -- and then takes one
    Adam step over both networks.  Loss values are printed every 50
    iterations, appended to ``FLAGS.log_file`` every iteration and written to
    TensorBoard; checkpoints are saved every 5 epochs and after the last one.

    Args:
        FLAGS: parsed options.  Fields read here: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``batch_size``, ``image_size``, ``num_channels``, ``cuda``,
            ``initial_learning_rate``, ``beta_1``, ``beta_2``, ``log_file``,
            ``start_epoch``, ``end_epoch``, ``kl_divergence_coef`` and
            ``reconstruction_coef``.
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # ---- variable definition ----
    # Reusable input buffer; every mini-batch is copied into it.  The channel
    # dimension is hard-coded to 1.
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)

    # ---- optional GPU placement ----
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        X = X.cuda()

    # ---- optimizer definition: a single Adam over both networks ----
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # ---- training ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration, so the
    # log file is (re)created with a header row.
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create an endlessly cycling data loader
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    iterations_per_epoch = int(len(mnist) / FLAGS.batch_size)
    # Both KL terms are divided by the number of input values in one batch.
    kl_normaliser = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(iterations_per_epoch):
            # load a mini-batch
            image_batch, labels_batch = next(loader)

            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            X.copy_(image_batch)

            style_mu, style_logvar, class_mu, class_logvar = encoder(
                Variable(X))
            # NOTE(review): `.data` hands the class statistics over without
            # their autograd history, so the class KL term may not produce
            # gradients for the encoder through class_mu / class_logvar --
            # confirm against accumulate_group_evidence.
            grouped_mu, grouped_logvar = accumulate_group_evidence(
                class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

            # kl-divergence error for style latent space
            style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                 style_logvar.exp())) / kl_normaliser
            style_kl_divergence_loss.backward(retain_graph=True)

            # kl-divergence error for class latent space
            class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                 grouped_logvar.exp())) / kl_normaliser
            class_kl_divergence_loss.backward(retain_graph=True)

            # reconstruct samples
            # Sampling from the group mu/logvar differently for each image in
            # the mini-batch would make the decoder treat the class latent
            # embeddings as random noise and ignore them, hence the
            # group-wise reparameterisation.
            style_latent_embeddings = reparameterize(training=True,
                                                     mu=style_mu,
                                                     logvar=style_logvar)
            class_latent_embeddings = group_wise_reparameterize(
                training=True,
                mu=grouped_mu,
                logvar=grouped_logvar,
                labels_batch=labels_batch,
                cuda=FLAGS.cuda)

            reconstructed_images = decoder(style_latent_embeddings,
                                           class_latent_embeddings)

            reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_images, Variable(X))
            reconstruction_error.backward()

            auto_encoder_optimizer.step()

            # Read each scalar loss once as a Python float.  Tensor.item()
            # replaces the deprecated `.data.storage().tolist()[0]`, which
            # also returns the wrong element for views with a storage offset.
            reconstruction_value = reconstruction_error.item()
            style_kl_value = style_kl_divergence_loss.item()
            class_kl_value = class_kl_divergence_loss.item()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_value))
                print('Style KL-Divergence loss: ' + str(style_kl_value))
                print('Class KL-Divergence loss: ' + str(class_kl_value))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration, reconstruction_value, style_kl_value,
                    class_kl_value))

            # write to tensorboard (the "+ 1" step spacing is kept so new
            # points line up with previously written runs)
            global_step = epoch * (iterations_per_epoch + 1) + iteration
            writer.add_scalar('Reconstruction loss', reconstruction_value,
                              global_step)
            writer.add_scalar('Style KL-Divergence loss', style_kl_value,
                              global_step)
            writer.add_scalar('Class KL-Divergence loss', class_kl_value,
                              global_step)

        # save checkpoints after every 5 epochs and after the final epoch
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
Example #5
0
    # sigma_q: (batch_size, n_dim, n_frames, n_frames), mu_q: (batch_size, d, nlen)

    l1 = torch.einsum('kij,mkji->mk', sigma_p_inv,
                      sigma_q)  # tr(sigma_p_inv sigma_q)
    l2 = torch.einsum('mki,mki->mk', mu_p - mu_q,
                      torch.einsum('kij,mkj->mki', sigma_p_inv,
                                   mu_p - mu_q))  # <mu_q, sigma_p_inv, mu_q>
    loss = torch.sum(l1 + l2 + torch.log(det_p) - torch.log(det_q), dim=1)
    return loss


if __name__ == '__main__':
    # Build each network and initialise its weights immediately after it is
    # constructed (encoder first, then decoder).
    encoder = Encoder()
    encoder.apply(weights_init)

    decoder = Decoder()
    decoder.apply(weights_init)

    # Optionally resume both networks from the checkpoints directory.
    if LOAD_SAVED:
        checkpoint_dir = 'checkpoints'
        encoder_state = torch.load(os.path.join(checkpoint_dir, ENCODER_SAVE))
        encoder.load_state_dict(encoder_state)
        decoder_state = torch.load(os.path.join(checkpoint_dir, DECODER_SAVE))
        decoder.load_state_dict(decoder_state)

    # Mean-squared-error criterion used as the loss.
    mse_loss = nn.MSELoss()

    # add option to run on gpu
Example #6
0
class BiGAN(object):
    """Bidirectional GAN: generator G(z), encoder E(x) and a joint
    discriminator D(x, z), trained adversarially with binary cross-entropy.

    Args:
        args: namespace providing ``z_dim``, ``decay_rate``,
            ``learning_rate``, ``model_name`` and ``batch_size``.
    """

    def __init__(self, args):

        self.z_dim = args.z_dim
        self.decay_rate = args.decay_rate  # stored; not used in this class
        self.learning_rate = args.learning_rate
        self.model_name = args.model_name
        self.batch_size = args.batch_size

        # Use the GPU when present (as before), but fall back to the CPU
        # instead of failing on machines without CUDA.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        #initialize networks
        self.Generator = Generator(self.z_dim).to(self.device)
        self.Encoder = Encoder(self.z_dim).to(self.device)
        self.Discriminator = Discriminator().to(self.device)

        # Generator and encoder are updated jointly by one optimizer;
        # the discriminator has its own.
        self.optimizer_G_E = torch.optim.Adam(
            list(self.Generator.parameters()) +
            list(self.Encoder.parameters()),
            lr=self.learning_rate,
            betas=(0.5, 0.999))

        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=self.learning_rate,
                                            betas=(0.5, 0.999))

        # initialize network weights
        # NOTE(review): this runs after the optimizers were built, which is
        # only correct if weights_init modifies parameters in place -- confirm.
        self.Generator.apply(weights_init)
        self.Encoder.apply(weights_init)
        self.Discriminator.apply(weights_init)

    def train(self, data):
        """Run one discriminator step followed by one generator/encoder step.

        Real pairs are ``(data, E(data))`` and fake pairs are ``(G(z), z)``
        with ``z ~ N(0, I)``.  The discriminator is trained to output 1 for
        real pairs and 0 for fake pairs; the generator and encoder are trained
        with the labels flipped.

        Args:
            data: batch of real samples; it is moved to the model device.

        Returns:
            ``(D_loss, G_E_loss)`` as Python floats.  The loss tensors are
            also kept on ``self.D_loss`` and ``self.G_E_loss``.
        """
        self.Generator.train()
        self.Encoder.train()
        self.Discriminator.train()

        self.optimizer_G_E.zero_grad()
        self.optimizer_D.zero_grad()

        data = data.to(self.device)
        # NOTE(review): BCELoss needs Discriminator outputs in [0, 1]
        # (e.g. a final sigmoid) -- confirm in Discriminator.
        bce = nn.BCELoss()

        # Sample the prior on the same device as the networks and with the
        # same batch size as the real data.
        self.z_fake = torch.randn((data.size(0), self.z_dim),
                                  device=self.device)

        # fake x from the generator, fake-free z from the encoder
        self.x_fake = self.Generator(self.z_fake)
        self.z_real = self.Encoder(data)

        # ---- discriminator update ----
        # Detached inputs keep this step from back-propagating into G or E.
        self.out_real = self.Discriminator(data, self.z_real.detach())
        self.out_fake = self.Discriminator(self.x_fake.detach(),
                                           self.z_fake.detach())
        self.D_loss = (bce(self.out_real, torch.ones_like(self.out_real)) +
                       bce(self.out_fake, torch.zeros_like(self.out_fake)))
        self.D_loss.backward()
        self.optimizer_D.step()

        # ---- generator/encoder update ----
        # Re-score the pairs without detaching so gradients reach G and E.
        # Gradients this also leaves on D are cleared by the next
        # optimizer_D.zero_grad() before D is stepped again.
        out_real_ge = self.Discriminator(data, self.z_real)
        out_fake_ge = self.Discriminator(self.x_fake, self.z_fake)
        self.G_E_loss = (bce(out_fake_ge, torch.ones_like(out_fake_ge)) +
                         bce(out_real_ge, torch.zeros_like(out_real_ge)))
        self.G_E_loss.backward()
        self.optimizer_G_E.step()

        return self.D_loss.item(), self.G_E_loss.item()
Example #7
0
def training_procedure(FLAGS):
    """Train a style/class auto-encoder on the matching time-series datasets.

    Lists the entries of ``./data`` and keeps those whose name, split on
    ``_``, has ``theta=-1`` as its third field.  The same ``Encoder``/
    ``Decoder`` pair and optimizer are trained on each kept dataset in turn;
    they are not re-initialised between datasets.  Each iteration
    back-propagates a style KL term, a class KL term on label-grouped
    evidence and a reconstruction loss, then takes one Adam step.
    Checkpoints go to ``checkpoints/encoder_<dsname>`` and
    ``checkpoints/decoder_<dsname>``.

    Args:
        FLAGS: parsed options.  Fields read here: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``cuda``, ``initial_learning_rate``, ``beta_1``, ``beta_2``,
            ``log_file``, ``batch_size``, ``start_epoch``, ``end_epoch``,
            ``kl_divergence_coef``, ``num_channels``, ``image_size`` and
            ``reconstruction_coef``.
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # ---- variable definition ----
    # Reusable input buffer with 784 features per sample (hard-coded).
    X = torch.FloatTensor(FLAGS.batch_size, 784)

    # ---- run on GPU if GPU is available ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The group-evidence helpers take a boolean CUDA switch (presumably to
    # decide where they create tensors).  Derive it from the device actually
    # selected above; FLAGS.cuda can disagree with where X and the networks
    # live.
    use_cuda = device.type == 'cuda'
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)

    # ---- optimizer definition ----
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # ---- training ----
    # NOTE(review): the device is chosen automatically above, so FLAGS.cuda
    # no longer controls placement and this warning is informational only.
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration, so the
    # log file is (re)created with a header row.
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # Both KL terms are divided by the number of input values in one batch.
    kl_normaliser = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        params = dsname.split('_')
        # Exact match on the third '_'-separated field.  The trailing comma
        # makes this a tuple; without it `in` performs a substring test.
        # Names with fewer than three fields are skipped rather than raising.
        if len(params) < 3 or params[2] not in ('theta=-1',):
            continue

        print('Running dataset ', dsname)
        ds = DoubleMulNormal(dsname)
        loader = cycle(
            DataLoader(ds,
                       batch_size=FLAGS.batch_size,
                       shuffle=True,
                       drop_last=True))

        # initialize summary writer
        writer = SummaryWriter()

        iterations_per_epoch = int(len(ds) / FLAGS.batch_size)
        encoder_path = os.path.join('checkpoints', 'encoder_' + dsname)
        decoder_path = os.path.join('checkpoints', 'decoder_' + dsname)

        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print()
            print('Epoch #' + str(epoch) +
                  '........................................................')

            # Sum of the three loss terms over the epoch.  Kept as a Python
            # float so no autograd history is retained across iterations
            # (and so it can be printed even if the epoch had no batches).
            total_loss = 0.0

            for iteration in range(iterations_per_epoch):
                # load a mini-batch
                image_batch, labels_batch = next(loader)

                # set zero_grad for the optimizer
                auto_encoder_optimizer.zero_grad()

                X.copy_(image_batch)

                style_mu, style_logvar, class_mu, class_logvar = encoder(
                    Variable(X))
                # NOTE(review): `.data` drops the autograd history of the
                # class statistics -- confirm this is intended.
                grouped_mu, grouped_logvar = accumulate_group_evidence(
                    class_mu.data, class_logvar.data, labels_batch, use_cuda)

                # kl-divergence error for style latent space
                style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                     style_logvar.exp())) / kl_normaliser
                style_kl_divergence_loss.backward(retain_graph=True)

                # kl-divergence error for class latent space
                class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                     grouped_logvar.exp())) / kl_normaliser
                class_kl_divergence_loss.backward(retain_graph=True)

                # reconstruct samples
                # Sampling from the group mu/logvar differently for each item
                # in the mini-batch would make the decoder treat the class
                # latent embeddings as random noise and ignore them, hence
                # the group-wise reparameterisation.
                style_latent_embeddings = reparameterize(
                    training=True, mu=style_mu, logvar=style_logvar)
                class_latent_embeddings = group_wise_reparameterize(
                    training=True,
                    mu=grouped_mu,
                    logvar=grouped_logvar,
                    labels_batch=labels_batch,
                    cuda=use_cuda)

                reconstructed_images = decoder(style_latent_embeddings,
                                               class_latent_embeddings)

                reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_images, Variable(X))
                reconstruction_error.backward()

                # Python floats for printing/logging (Tensor.item() replaces
                # the deprecated `.data.storage().tolist()[0]`).
                reconstruction_value = reconstruction_error.item()
                style_kl_value = style_kl_divergence_loss.item()
                class_kl_value = class_kl_divergence_loss.item()

                total_loss += style_kl_value + class_kl_value + reconstruction_value

                auto_encoder_optimizer.step()

                if (iteration + 1) % 50 == 0:
                    print('\tIteration #' + str(iteration))
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('Style KL loss: ' + str(style_kl_value))
                    print('Class KL loss: ' + str(class_kl_value))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        epoch, iteration, reconstruction_value,
                        style_kl_value, class_kl_value))

                # write to tensorboard
                global_step = epoch * (iterations_per_epoch + 1) + iteration
                writer.add_scalar('Reconstruction loss',
                                  reconstruction_value, global_step)
                writer.add_scalar('Style KL-Divergence loss', style_kl_value,
                                  global_step)
                writer.add_scalar('Class KL-Divergence loss', class_kl_value,
                                  global_step)

                # during the first epoch, checkpoint every 50 iterations
                if epoch == 0 and (iteration + 1) % 50 == 0:
                    torch.save(encoder.state_dict(), encoder_path)
                    torch.save(decoder.state_dict(), decoder_path)

            # save checkpoints after every 10 epochs and after the last one
            if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(), encoder_path)
                torch.save(decoder.state_dict(), decoder_path)

            print('Total loss at current epoch: ', total_loss)
Example #8
0
class TadGAN(pl.LightningModule):
    """Time-series anomaly detection GAN as a Lightning module.

    An encoder maps input windows ``x`` to latents, a generator maps latents
    back to windows, and two critics (``critic_x`` on windows, ``critic_z`` on
    latents) are trained with a Wasserstein loss plus a gradient penalty.
    Validation reconstructs the windows, scores them with
    ``score_anomalies``/``find_anomalies`` and logs figures and, when ground
    truth anomalies are available, contextual accuracy/precision/recall/F1.

    Args:
        in_size: input size forwarded to ``Encoder`` and ``CriticX``.
        ts_size: forwarded to ``Encoder`` as ``ts_size``.
        latent_dim: size of the latent vector sampled from the prior.
        lr: Adam learning rate for all three optimizers.
        weight_decay: Adam weight decay for all three optimizers.
        iterations_critic: the encoder/generator optimizer is only used on
            every ``iterations_critic``-th batch.
        gamma: weight of the reconstruction loss and of the gradient penalty.
        weighted: passed as ``weighted`` to the contextual metrics.
        use_gru: forwarded to the encoder and generator.
    """

    def __init__(self,
                 in_size: int,
                 ts_size: int = 100,
                 latent_dim: int = 20,
                 lr: float = 0.0005,
                 weight_decay: float = 1e-6,
                 iterations_critic: int = 5,
                 gamma: float = 10,
                 weighted: bool = True,
                 use_gru=False):
        super(TadGAN, self).__init__()
        self.in_size = in_size
        self.latent_dim = latent_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.iterations_critic = iterations_critic
        self.gamma = gamma
        self.weighted = weighted

        self.hparams = {
            'lr': self.lr,
            'weight_decay': self.weight_decay,
            'iterations_critic': self.iterations_critic,
            'gamma': self.gamma
        }

        self.encoder = Encoder(in_size,
                               ts_size=ts_size,
                               out_size=self.latent_dim,
                               batch_first=True,
                               use_gru=use_gru)
        self.generator = Generator(use_gru=use_gru)
        self.critic_x = CriticX(in_size=in_size)
        self.critic_z = CriticZ()

        self.encoder.apply(init_weights)
        self.generator.apply(init_weights)
        self.critic_x.apply(init_weights)
        self.critic_z.apply(init_weights)

        if self.logger is not None:
            self.logger.log_hyperparams(self.hparams)

        # Per-batch validation outputs, concatenated in validation_epoch_end.
        self.y_hat = []
        self.index = []
        self.critic = []

    def on_fit_start(self):
        """Log a plot of the ground-truth windows before training starts."""
        if self.logger is not None:
            fig = plot_rws(self.trainer.datamodule.X.cpu().numpy())
            self.logger.experiment.add_figure('Rolling windows/GT',
                                              fig,
                                              global_step=self.global_step)

    def forward(self, x):
        """Return ``(generator(encoder(x)), critic_x(x))``."""
        y_hat = self.generator(self.encoder(x))
        critic = self.critic_x(x)

        return y_hat, critic

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        0 -> encoder/generator, 1 -> ``critic_x``, 2 -> ``critic_z``.  The
        encoder/generator step returns ``None`` (skipped) except on every
        ``iterations_critic``-th batch.
        """
        x = batch[0]
        batch_size = x.size(0)
        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        # Wasserstein targets: -1 for "real", +1 for "fake".
        valid = -torch.ones(batch_size, 1, device=self.device)
        fake = torch.ones(batch_size, 1, device=self.device)

        if optimizer_idx == 0:
            if (batch_idx + 1) % self.iterations_critic != 0:
                return None
            z_gen = self.encoder(x)
            x_rec = self.generator(z_gen)
            fake_gen_z = self.critic_z(z_gen)
            fake_gen_x = self.critic_x(self.generator(z))

            wx_loss = self._wasserstein_loss(valid, fake_gen_x)
            wz_loss = self._wasserstein_loss(valid, fake_gen_z)
            rec_loss = F.mse_loss(x_rec, x)
            loss = wx_loss + wz_loss + self.gamma * rec_loss
            vals = {
                'train/Encoder_Generator/loss': loss,
                'train/Encoder_Generator/Wasserstein_x_loss': wx_loss,
                'train/Encoder_Generator/Wasserstein_z_loss': wz_loss,
                'train/Encoder_Generator/Reconstruction_loss': rec_loss
            }
            self.log_dict(vals)
        elif optimizer_idx == 1:
            valid_x = self.critic_x(x)
            x_gen = self.generator(z).detach()
            fake_x = self.critic_x(x_gen)

            wv_loss = self._wasserstein_loss(valid, valid_x)
            wf_loss = self._wasserstein_loss(fake, fake_x)
            gp_loss = self._calculate_gradient_penalty(self.critic_x, x, x_gen)
            loss = wv_loss + wf_loss + self.gamma * gp_loss
            vals = {
                'train/Critic_x/loss': loss,
                'train/Critic_x/Wasserstein_valid_loss': wv_loss,
                'train/Critic_x/Wasserstein_fake_loss': wf_loss,
                'train/Critic_x/gradient_penalty': gp_loss
            }
            self.log_dict(vals)
        elif optimizer_idx == 2:
            valid_z = self.critic_z(z)
            z_gen = self.encoder(x).detach()
            fake_z = self.critic_z(z_gen)

            wv_loss = self._wasserstein_loss(valid, valid_z)
            wf_loss = self._wasserstein_loss(fake, fake_z)
            gp_loss = self._calculate_gradient_penalty(self.critic_z, z, z_gen)
            loss = wv_loss + wf_loss + self.gamma * gp_loss
            vals = {
                'train/Critic_z/loss': loss,
                'train/Critic_z/Wasserstein_valid_loss': wv_loss,
                'train/Critic_z/Wasserstein_fake_loss': wf_loss,
                'train/Critic_z/gradient_penalty': gp_loss
            }
            self.log_dict(vals)
        else:
            raise NotImplementedError()
        return loss

    def validation_step(self, batch, batch_idx):
        """Cache reconstructions, critic scores and indices for the epoch."""
        x, index = batch
        y_hat, critic = self(x)

        self.y_hat.append(y_hat)
        self.index.append(index)
        self.critic.append(critic)
        return None

    def validation_epoch_end(self, validation_step_outputs):
        """Gather cached outputs, detect anomalies and log figures/metrics."""
        if self.logger is None:
            # Nothing is reported, but the cached outputs must still be
            # dropped or they accumulate across validation epochs.
            self.y_hat = []
            self.index = []
            self.critic = []
            return

        for net_name, net in zip(
            ['Encoder', 'Generator', 'Critic_X', 'Critic_Z'],
            [self.encoder, self.generator, self.critic_x, self.critic_z]):
            for m in net.modules():
                for name, param in m.named_parameters():
                    self.logger.experiment.add_histogram(
                        net_name + '/' + name, param.data)

        y_hat = torch.cat(self.y_hat)
        critic = torch.cat(self.critic)
        index = torch.cat(self.index)

        self.index = []
        self.y_hat = []
        self.critic = []

        # Pad every process's outputs to the same length (NaN / -1 fill) so
        # all_gather can stack them; padding rows are masked out below.
        n_batches = self.all_gather(y_hat.size(0))
        max_n_batches = n_batches.max()
        if y_hat.size(0) < max_n_batches:
            diff = max_n_batches - y_hat.size(0)
            add_cols = torch.full((diff, *y_hat.shape[1:]),
                                  fill_value=float('nan'),
                                  dtype=y_hat.dtype,
                                  device=y_hat.device)
            y_hat = torch.cat((y_hat, add_cols))
            add_cols = torch.full((diff, *critic.shape[1:]),
                                  fill_value=float('nan'),
                                  dtype=critic.dtype,
                                  device=critic.device)
            critic = torch.cat((critic, add_cols))
            add_cols = torch.full((diff, *index.shape[1:]),
                                  fill_value=-1,
                                  dtype=index.dtype,
                                  device=index.device)
            index = torch.cat((index, add_cols))

        y_hat, critic, index = self.all_gather((y_hat, critic, index))

        if len(y_hat.shape) == 4:
            y_hat = torch.flatten(y_hat, 0, 1)
            critic = torch.flatten(critic, 0, 1)
            index = torch.flatten(index, 0, 1)
        dm = self.trainer.datamodule

        y_shape = y_hat.shape[1:]
        critic_shape = critic.shape[1:]
        mask = ~torch.any(torch.flatten(y_hat, 1, -1).isnan(), dim=1)
        y_hat = y_hat[mask]
        y_hat = y_hat.view(y_hat.size(0), *y_shape)
        critic = critic[mask]
        critic = critic.view(critic.size(0), *critic_shape)
        index = index[mask]

        # restore the original window order
        idx = torch.argsort(index)
        y_hat = y_hat[idx]
        critic = critic[idx]

        assert y_hat.size(0) == critic.size(0)

        max_idx = min(y_hat.shape[0], dm.X.shape[0])

        y_hat = y_hat[:max_idx].cpu().numpy()
        critic = critic[:max_idx].cpu().numpy()
        X = dm.X[:max_idx].cpu().numpy()
        X_index = dm.X_index[:max_idx].cpu().numpy()
        index = dm.index

        # flatten the predicted windows
        # plot the time series
        fig = plot_ts([dm.y, unroll_ts(y_hat)],
                      labels=['original', 'reconstructed'])
        self.logger.experiment.add_figure('TS reconstruction',
                                          fig,
                                          global_step=self.global_step)
        if y_hat.shape[0] == dm.X.shape[0]:
            fig = plot_rws(y_hat)
            self.logger.experiment.add_figure('Rolling windows/Reconstructed',
                                              fig,
                                              global_step=self.global_step)

        errors, true_index, true, predictions = score_anomalies(
            X, y_hat, critic, X_index, rec_error_type='dtw', comb='mult')
        anomalies = find_anomalies(errors,
                                   index,
                                   window_size_portion=0.33,
                                   window_step_size_portion=0.1,
                                   fixed_threshold=True)
        if anomalies.size == 0:
            anomalies = pd.DataFrame(columns=['start', 'end', 'score'])
        else:
            anomalies = pd.DataFrame(anomalies,
                                     columns=['start', 'end', 'score'])

        gt_anomalies = dm.anomalies
        if gt_anomalies is not None:
            fig = plot(dm.df, [('anomalies', anomalies),
                               ('gt_anomalies', gt_anomalies)])
        else:
            fig = plot(dm.df, [('anomalies', anomalies)])
        self.logger.experiment.add_figure('AD output',
                                          fig,
                                          global_step=self.global_step)

        metric_logged = False
        if not anomalies.empty:
            fig = plot_table_anomalies(anomalies)
            self.logger.experiment.add_figure('Anomaly table',
                                              fig,
                                              global_step=self.global_step)

            if gt_anomalies is not None:
                # Workaround to display a PR curve
                if self.weighted:
                    labels, preds, weights = contextual_prepare_weighted(
                        gt_anomalies, anomalies, data=dm.df)
                    self.logger.experiment.add_pr_curve(
                        'PR Curve',
                        np.array(labels),
                        np.array(preds),
                        weights=np.array(weights),
                        global_step=self.global_step)

                acc = contextual_accuracy(gt_anomalies,
                                          anomalies,
                                          data=dm.df,
                                          weighted=self.weighted)
                prec = contextual_precision(gt_anomalies,
                                            anomalies,
                                            data=dm.df,
                                            weighted=self.weighted)
                recall = contextual_recall(gt_anomalies,
                                           anomalies,
                                           data=dm.df,
                                           weighted=self.weighted)
                f1 = contextual_f1_score(gt_anomalies,
                                         anomalies,
                                         data=dm.df,
                                         weighted=self.weighted,
                                         beta=2)
                vals = {
                    'Accuracy': acc,
                    'Precision': prec,
                    'Recall': recall,
                    'F1': f1
                }
                self.log_dict(vals)
                metric_logged = True
        if not metric_logged:
            vals = {'Accuracy': 0, 'Precision': 0, 'Recall': 0, 'F1': 0}
            self.log_dict(vals)

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Zero gradients by setting them to ``None``."""
        optimizer.zero_grad(set_to_none=True)

    def configure_optimizers(self):
        """Return Adam optimizers for encoder+generator, critic_x, critic_z."""
        params = [self.encoder.parameters(), self.generator.parameters()]
        e_g_opt = Adam(itertools.chain(*params),
                       lr=self.lr,
                       weight_decay=self.weight_decay)
        c_x_opt = Adam(self.critic_x.parameters(),
                       lr=self.lr,
                       weight_decay=self.weight_decay)
        c_z_opt = Adam(self.critic_z.parameters(),
                       lr=self.lr,
                       weight_decay=self.weight_decay)

        return [e_g_opt, c_x_opt, c_z_opt]

    @staticmethod
    def _wasserstein_loss(y_true: torch.Tensor, y_pred: torch.Tensor):
        """Mean of ``y_true * y_pred`` (targets are -1 / +1)."""
        return torch.mean(y_true * y_pred)

    def _calculate_gradient_penalty(self, model: torch.nn.Module,
                                    y_true: torch.Tensor,
                                    y_pred: torch.Tensor):
        """Calculate the WGAN-GP gradient penalty.

        Evaluates ``model`` at random points on the segments between
        ``y_true`` and ``y_pred`` and penalises the squared deviation of the
        per-sample gradient L2 norm from 1.
        """
        # Interpolation weight ~ U[0, 1), one per sample, shaped
        # (batch, 1, ..., 1) so it broadcasts over inputs of any rank --
        # (B, T, F) windows as well as (B, L) latents.  A normal sample would
        # extrapolate outside the real/fake segment.
        alpha_shape = (y_true.size(0), ) + (1, ) * (y_true.dim() - 1)
        alpha = torch.rand(alpha_shape, device=self.device)
        # Get random interpolation between real and fake data
        interpolates = (alpha * y_true +
                        ((1 - alpha) * y_pred)).requires_grad_(True)

        model_interpolates = model(interpolates)
        grad_outputs = torch.ones(model_interpolates.size(),
                                  device=self.device,
                                  requires_grad=False)

        # Get gradient w.r.t. interpolates
        gradients = torch.autograd.grad(
            outputs=model_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = torch.mean((gradients.norm(2, dim=1) - 1)**2)
        return gradient_penalty

if __name__ == '__main__':

    # Data: batches of BATCH_SIZE samples drawn endlessly from a shuffled
    # loader built over load_dataset().
    BATCH_SIZE = 1

    dataset = load_dataset()
    base_loader = DataLoader(dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True,
                             drop_last=True)
    loader = cycle(base_loader)

    # Encoder/decoder: construct, initialise, then restore trained weights
    # from the checkpoints directory.
    encoder = Encoder()
    encoder.apply(weights_init)

    decoder = Decoder()
    decoder.apply(weights_init)

    encoder_state = torch.load(os.path.join('checkpoints', ENCODER_SAVE))
    encoder.load_state_dict(encoder_state)
    decoder_state = torch.load(os.path.join('checkpoints', DECODER_SAVE))
    decoder.load_state_dict(decoder_state)

    # Switch the restored networks to evaluation mode.
    encoder.eval()
    decoder.eval()

    # Model trained on top of the restored encoder/decoder.
    prediction_model = Prediction_Model()
    prediction_model.apply(weights_init)
Example #10
0
def training_procedure(FLAGS):
    """Train the style/class encoder-decoder on MNIST with a validation split.

    The encoder and decoder are built and optionally restored from
    ``checkpoints_<batch_size>/``. ``split = 10000`` MNIST training images are
    held out for validation using a seeded shuffle. The model is trained by
    maximising the ELBO returned by ``process``. Every 5 epochs (and at the
    final epoch) it is evaluated with ``eval`` on the held-out samples, and
    the weights and the validation ``monitor`` tensor are saved.

    Args:
        FLAGS: parsed options. Fields read here: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``cuda``, ``initial_learning_rate``, ``beta_1``, ``beta_2``,
            ``batch_size``, ``log_file``, ``start_epoch``, ``end_epoch``.
    """
    # The checkpoint directory depends only on the batch size. It must be
    # defined before the optional restore below, which reads from it.
    savedir = 'checkpoints_%d' % (FLAGS.batch_size)

    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.decoder_save)))

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # training
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs(savedir, exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)

    # Train/validation split from a seeded shuffle, so every run (including
    # a resumed one) uses the same held-out samples. torch's random_split
    # ignores this numpy seed.
    split = 10000
    indices = list(range(len(mnist)))
    np.random.seed(0)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {}
    loader = DataLoader(mnist,
                        batch_size=FLAGS.batch_size,
                        sampler=train_sampler,
                        **loader_kwargs)
    valid_loader = DataLoader(mnist,
                              batch_size=FLAGS.batch_size,
                              sampler=valid_sampler,
                              **loader_kwargs)

    # Row k holds the validation metrics (elbo, rec. proba, KL style,
    # KL content) of epoch FLAGS.start_epoch + k.
    monitor = torch.zeros(FLAGS.end_epoch - FLAGS.start_epoch, 4)
    # initialize summary writer (NOTE(review): not written to below)
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )
        elbo_epoch = 0.
        term1_epoch = 0.
        term2_epoch = 0.
        term3_epoch = 0.
        num_batches = 0
        for image_batch, labels_batch in loader:
            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            # Only move inputs to the GPU when the models live there.
            X = image_batch.cuda() if FLAGS.cuda else image_batch
            X = X.detach().clone()
            elbo, reconstruction_proba, style_kl_divergence_loss, class_kl_divergence_loss = process(
                FLAGS, X, labels_batch, encoder, decoder)
            (-elbo).backward()
            auto_encoder_optimizer.step()
            # float() drops autograd history; summing the tensors directly
            # would chain every iteration's graph for the whole epoch.
            elbo_epoch += float(elbo)
            term1_epoch += float(reconstruction_proba)
            term2_epoch += float(style_kl_divergence_loss)
            term3_epoch += float(class_kl_divergence_loss)
            num_batches += 1

        denominator = max(num_batches, 1)  # avoid division by zero on an empty loader
        print("Elbo epoch %.2f" % (elbo_epoch / denominator))
        print("Rec. Proba %.2f" % (term1_epoch / denominator))
        print("KL style %.2f" % (term2_epoch / denominator))
        print("KL content %.2f" % (term3_epoch / denominator))
        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            # monitor is indexed relative to start_epoch (it has only
            # end_epoch - start_epoch rows).
            row = epoch - FLAGS.start_epoch
            monitor[row, :] = eval(FLAGS, valid_loader, encoder, decoder)
            torch.save(
                encoder.state_dict(),
                os.path.join(savedir, FLAGS.encoder_save + '_e%d' % epoch))
            torch.save(
                decoder.state_dict(),
                os.path.join(savedir, FLAGS.decoder_save + '_e%d' % epoch))
            print("VAL elbo %.2f" % (monitor[row, 0]))
            print("VAL Rec. Proba %.2f" % (monitor[row, 1]))
            print("VAL KL style %.2f" % (monitor[row, 2]))
            print("VAL KL content %.2f" % (monitor[row, 3]))

            torch.save(monitor, os.path.join(savedir, 'monitor_e%d' % epoch))
Example #11
0
    del val_dset
    del val_data
    vutils.save_image(fixed_reals,
                      join(sample_path, '{:03d}_real.jpg'.format(0)),
                      nrow=4,
                      padding=0,
                      normalize=True,
                      range=(-1., 1.))
    vutils.save_image(fixed_annos.float() / n_classes,
                      join(sample_path, '{:03d}_anno.jpg'.format(0)),
                      nrow=4,
                      padding=0)

    # Models
    E = Encoder().to(device)
    E.apply(init_weights)
    # summary(E, (3, 256, 256), device=device)
    G = Generator(n_classes).to(device)
    G.apply(init_weights)
    # summary(G, [(256,), (10, 256, 256)], device=device)
    D = Discriminator(n_classes).to(device)
    D.apply(init_weights)
    # summary(D, (13, 256, 256), device=device)
    vgg = VGG().to(device)

    if args.multi_gpu:
        E = nn.DataParallel(E)
        G = nn.DataParallel(G)
        # G = convert_model(G)
        D = nn.DataParallel(D)
        VGG = nn.DataParallel(VGG)
def _paired_style_kl_loss(style_mu, style_logvar, FLAGS):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, I)) summed over the batch and
    divided by batch_size * num_channels * image_size ** 2.

    Assumes ``style_logvar`` is a log-variance — confirm against the encoder.
    """
    kl = -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) - style_logvar.exp())
    return kl / (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)


def training_procedure(FLAGS):
    """Train encoder/decoder against a pair discriminator on paired MNIST.

    Each iteration has three phases:

    A. Auto-encoder: encode ``X_1`` (style via ``reparameterize``, plus a
       class code) and ``X_2`` (class code). Decode ``X_1``'s style with
       ``X_1``'s and with ``X_2``'s class code, and back-propagate both MSE
       errors against ``X_1`` plus the scaled style KL term.
    B. Generator (``FLAGS.generator_times`` steps): decode ``X_1``'s style,
       and a N(0, 1) style sample, with ``X_3``'s class code. Train
       encoder/decoder so the discriminator labels both (``X_3``, generated)
       pairs as real.
    C. Discriminator (``FLAGS.discriminator_times`` steps): label
       (``X_2``, ``X_3``) from one paired batch as real and (``X_3``,
       style-swapped reconstruction) as fake. Its optimizer steps only while
       accuracy is below ``FLAGS.discriminator_limiting_accuracy``.

    Losses are appended to ``FLAGS.log_file`` and written to TensorBoard
    every iteration. Weights are saved in ``checkpoints/`` every 5 epochs and
    at the final epoch.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    # variable definition: discriminator target classes
    real_domain_labels = 1
    fake_domain_labels = 0

    # Reusable input buffers; every batch is copied into them in place.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # Ground truth for the discriminator accuracy: the real half of the
    # concatenated predictions is class 1, the fake half class 0.
    target_true_labels = torch.cat((torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)),
                                   dim=0).long()

    # loss definitions
    cross_entropy_loss = nn.CrossEntropyLoss()

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        style_latent_space = style_latent_space.cuda()
        target_true_labels = target_true_labels.cuda()

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # training
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    iterations_per_epoch = len(paired_mnist) // FLAGS.batch_size
    # TensorBoard step stride, kept at iterations_per_epoch + 1 so the
    # x-axis matches runs logged by earlier versions of this script.
    tensorboard_stride = iterations_per_epoch + 1

    # initialise variables. The loss placeholders keep logging valid when
    # generator_times or discriminator_times is 0.
    discriminator_accuracy = 0.
    generator_error = torch.zeros(())
    discriminator_error = torch.zeros(())

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_1 = encoder(X_1)
            style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            kl_divergence_loss_1 = _paired_style_kl_loss(style_mu_1, style_logvar_1, FLAGS)
            kl_divergence_loss_1.backward(retain_graph=True)

            _, __, class_2 = encoder(X_2)

            reconstructed_X_1 = decoder(style_1, class_1)
            reconstructed_X_2 = decoder(style_1, class_2)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, X_1)
            reconstruction_error_1.backward(retain_graph=True)

            # X_2's class code with X_1's style should still reproduce X_1
            # (presumably because paired images share a class — confirm).
            reconstruction_error_2 = mse_loss(reconstructed_X_2, X_1)
            reconstruction_error_2.backward()

            # Detached copies are used only for logging, so no autograd
            # history is kept or mutated in place.
            reconstruction_error = reconstruction_error_1.detach() + reconstruction_error_2.detach()
            kl_divergence_error = kl_divergence_loss_1.detach()

            auto_encoder_optimizer.step()

            # B. run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, __ = next(loader)
                image_batch_3, _, __ = next(loader)

                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                style_mu_1, style_logvar_1, _ = encoder(X_1)
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = _paired_style_kl_loss(style_mu_1, style_logvar_1, FLAGS)
                kl_divergence_loss_1.backward(retain_graph=True)

                _, __, class_3 = encoder(X_3)
                reconstructed_X_1_3 = decoder(style_1, class_3)

                output_1 = discriminator(X_3, reconstructed_X_1_3)

                generator_error_1 = cross_entropy_loss(output_1, domain_labels)
                generator_error_1.backward(retain_graph=True)

                style_latent_space.normal_(0., 1.)
                reconstructed_X_latent_3 = decoder(style_latent_space, class_3)

                output_2 = discriminator(X_3, reconstructed_X_latent_3)

                generator_error_2 = cross_entropy_loss(output_2, domain_labels)
                generator_error_2.backward()

                generator_error = generator_error_1.detach() + generator_error_2.detach()
                kl_divergence_error = kl_divergence_error + kl_divergence_loss_1.detach()

                generator_optimizer.step()

            # C. run the discriminator
            for i in range(FLAGS.discriminator_times):

                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, _, __ = next(loader)
                image_batch_2, image_batch_3, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                real_output = discriminator(X_2, X_3)

                discriminator_real_error = cross_entropy_loss(real_output, domain_labels)
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                style_mu_1, style_logvar_1, _ = encoder(X_1)
                style_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

                _, __, class_3 = encoder(X_3)
                reconstructed_X_1_3 = decoder(style_1, class_3)

                fake_output = discriminator(X_3, reconstructed_X_1_3)

                discriminator_fake_error = cross_entropy_loss(fake_output, domain_labels)
                discriminator_fake_error.backward()

                # total discriminator error (logging only)
                discriminator_error = discriminator_real_error.detach() + discriminator_fake_error.detach()

                # calculate discriminator accuracy for this step
                discriminator_predictions = torch.cat((real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions == target_true_labels
                                          ).sum().item() / (FLAGS.batch_size * 2)

                # Freeze the discriminator once it gets too accurate.
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                    discriminator_optimizer.step()

            # Python scalars for console / log file / TensorBoard
            reconstruction_value = reconstruction_error.item()
            kl_divergence_value = kl_divergence_error.item()
            generator_value = generator_error.item()
            discriminator_value = discriminator_error.item()
            global_step = epoch * tensorboard_stride + iteration

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_value))
                print('KL-Divergence loss: ' + str(kl_divergence_value))

                print('')
                print('Generator loss: ' + str(generator_value))
                print('Discriminator loss: ' + str(discriminator_value))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                    epoch,
                    iteration,
                    reconstruction_value,
                    kl_divergence_value,
                    generator_value,
                    discriminator_value,
                    discriminator_accuracy
                ))

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_value, global_step)
            writer.add_scalar('KL-Divergence loss', kl_divergence_value, global_step)
            writer.add_scalar('Generator loss', generator_value, global_step)
            writer.add_scalar('Discriminator loss', discriminator_value, global_step)
            writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100, global_step)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(), os.path.join('checkpoints', FLAGS.discriminator_save))
def training_procedure(FLAGS):
    """Train a varying (``nv``) / common (``nc``) encoder-decoder on paired MNIST.

    Each iteration has three phases:

    A. Auto-encoder: encode a paired batch (``X_1``, ``X_2``). Decode each
       image from its own varying code and its partner's common code, and
       back-propagate both MSE errors. The optimizer steps only if
       ``FLAGS.train_auto_encoder``.
    B. a) Discriminator (``FLAGS.discriminator_times`` steps): label
       (``X_1``, ``X_2``) real and (``X_1``, decoder(nv of ``X_3``, nc of
       ``X_1``)) fake. It steps only while accuracy is below
       ``FLAGS.discriminator_limiting_accuracy`` and
       ``FLAGS.train_discriminator`` is set.
       b) Generator (``FLAGS.generator_times`` steps): train encoder/decoder
       so the same kind of fake pair is classified as real. It steps only
       if ``FLAGS.train_generator``.

    Losses are appended to ``FLAGS.log_file`` and written to TensorBoard
    every iteration. Every 5 epochs (and at the final epoch) the weights are
    saved to ``checkpoints/`` and sample grids are written with
    ``imshow_grid``.
    """
    # model definition
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)

    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    # variable definition: discriminator target classes
    real_domain_labels = 1
    fake_domain_labels = 0

    # Reusable input buffers; every batch is copied into them in place.
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)

    # Ground truth for the discriminator accuracy: the real half of the
    # concatenated predictions is class 1, the fake half class 0.
    target_true_labels = torch.cat(
        (torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)),
        dim=0).long()

    # loss definitions
    cross_entropy_loss = nn.CrossEntropyLoss()

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        target_true_labels = target_true_labels.cuda()

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    generator_optimizer = optim.Adam(list(encoder.parameters()) +
                                     list(decoder.parameters()),
                                     lr=FLAGS.initial_learning_rate,
                                     betas=(FLAGS.beta_1, FLAGS.beta_2))

    # training
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\t')
            log.write(
                'Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    iterations_per_epoch = len(paired_mnist) // FLAGS.batch_size
    # TensorBoard step stride, kept at iterations_per_epoch + 1 so the
    # x-axis matches runs logged by earlier versions of this script.
    tensorboard_stride = iterations_per_epoch + 1

    # initialise variables. The loss placeholders keep logging valid when
    # generator_times or discriminator_times is 0.
    discriminator_accuracy = 0.
    generator_error = torch.zeros(())
    discriminator_error = torch.zeros(())

    # initialize summary writer
    writer = SummaryWriter()

    def _as_rgb_batch(images):
        """NCHW tensor -> NHWC numpy array with its channels tiled 3 times."""
        batch = np.transpose(images.detach().cpu().numpy(), (0, 2, 3, 1))
        return np.concatenate((batch, batch, batch), axis=3)

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            nv_1, nc_1 = encoder(X_1)
            nv_2, nc_2 = encoder(X_2)

            # swap the common codes between the two images of each pair
            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_2 = decoder(nv_2, nc_1)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, X_1)
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, X_2)
            reconstruction_error_2.backward()

            # detached sum, used only for logging
            reconstruction_error = reconstruction_error_1.detach() + reconstruction_error_2.detach()

            if FLAGS.train_auto_encoder:
                auto_encoder_optimizer.step()

            # B. run the adversarial part of the architecture

            # B. a) run the discriminator
            for i in range(FLAGS.discriminator_times):
                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, image_batch_2, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                real_output = discriminator(X_1, X_2)

                discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                    real_output, domain_labels)
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                image_batch_3, _, __ = next(loader)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(X_3)

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(X_1)[1])

                fake_output = discriminator(X_1, reconstructed_X_3_1)

                discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                    fake_output, domain_labels)
                discriminator_fake_error.backward()

                # total discriminator error (logging only)
                discriminator_error = discriminator_real_error.detach() + discriminator_fake_error.detach()

                # calculate discriminator accuracy for this step
                discriminator_predictions = torch.cat(
                    (real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(
                    discriminator_predictions, 1)

                discriminator_accuracy = (
                    discriminator_predictions == target_true_labels
                ).sum().item() / (FLAGS.batch_size * 2)

                # Freeze the discriminator once it gets too accurate.
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy and FLAGS.train_discriminator:
                    discriminator_optimizer.step()

            # B. b) run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, __ = next(loader)
                image_batch_3, _, __ = next(loader)

                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                nv_3, nc_3 = encoder(X_3)

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(X_1)[1])

                output = discriminator(X_1, reconstructed_X_3_1)

                generator_error = FLAGS.gen_coef * cross_entropy_loss(
                    output, domain_labels)
                generator_error.backward()

                if FLAGS.train_generator:
                    generator_optimizer.step()

            # Python scalars for console / log file / TensorBoard
            reconstruction_value = reconstruction_error.item()
            generator_value = generator_error.item()
            discriminator_value = discriminator_error.item()
            global_step = epoch * tensorboard_stride + iteration

            # print progress after 10 iterations
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_value))
                print('Generator loss: ' + str(generator_value))

                print('')
                print('Discriminator loss: ' + str(discriminator_value))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                    epoch, iteration,
                    reconstruction_value,
                    generator_value,
                    discriminator_value,
                    discriminator_accuracy))

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_value,
                              global_step)
            writer.add_scalar('Generator loss', generator_value, global_step)
            writer.add_scalar('Discriminator loss', discriminator_value,
                              global_step)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(),
                       os.path.join('checkpoints', FLAGS.discriminator_save))

            # save reconstructed images and style swapped image generations to check progress
            image_batch_1, image_batch_2, _ = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            # Visualisation only: no autograd graph is needed.
            with torch.no_grad():
                nv_1, nc_1 = encoder(X_1)
                nv_2, nc_2 = encoder(X_2)
                nv_3, nc_3 = encoder(X_3)

                reconstructed_X_1 = decoder(nv_1, nc_2)
                reconstructed_X_3_2 = decoder(nv_3, nc_2)

            # save input image batch
            imshow_grid(_as_rgb_batch(X_1), name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            imshow_grid(_as_rgb_batch(reconstructed_X_1),
                        name=str(epoch) + '_target',
                        save=True)

            # save cross reconstructed batch
            imshow_grid(_as_rgb_batch(X_3), name=str(epoch) + '_style', save=True)

            imshow_grid(_as_rgb_batch(reconstructed_X_3_2),
                        name=str(epoch) + '_style_target',
                        save=True)
Example #14
0
def main(args, dataloader):
    """Train a BiGAN: a generator, an encoder and a joint discriminator.

    ``netD`` scores (image, latent) pairs.  Real images are paired with the
    latent code produced by ``netE``; generated images are paired with the
    random code ``netG`` generated them from.  ``netG`` and ``netE`` share
    one optimizer and are trained to make ``netD`` mislabel both pairs.

    Args:
        args: namespace providing ``ngf``, ``ndf``, ``nz``, ``nc``, ``lr``,
            ``num_epochs``, ``log_dir`` and ``run_name``.
        dataloader: iterable of batches whose first element is an image
            tensor; ``len(dataloader)`` is used for progress logging.

    Side effects:
        Prints the networks and, every 10 iterations, the losses.  Every 500
        iterations writes a sample grid to ``<log_dir>/<run_name>/out`` and a
        real-vs-reconstruction grid to ``<log_dir>/<run_name>/recons``.
        After every epoch saves ``netG``/``netE`` weights to
        ``<log_dir>/<run_name>/weights``.
    """
    # define the networks
    netG = Generator(ngf=args.ngf, nz=args.nz, nc=args.nc).cuda()
    netG.apply(weight_init)
    print(netG)

    netD = Discriminator(ndf=args.ndf, nc=args.nc, nz=args.nz).cuda()
    netD.apply(weight_init)
    print(netD)

    netE = Encoder(nc=args.nc, ngf=args.ngf, nz=args.nz).cuda()
    netE.apply(weight_init)
    print(netE)

    # define the loss criterion
    criterion = nn.BCELoss()

    # Ground-truth labels.  They must be floats: torch.full infers an integer
    # dtype from an int fill value, and BCELoss rejects integer targets.
    real_label = 1.0  # for the real pair
    fake_label = 0.0  # for the fake pair

    # define the optimizers: one for D, one shared by G and E
    netD_optimizer = optim.Adam(netD.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
    netG_optimizer = optim.Adam([{
        'params': netG.parameters()
    }, {
        'params': netE.parameters()
    }],
                                lr=args.lr,
                                betas=(0.5, 0.999))

    # Parameter lists used to take each player's gradients separately.
    d_params = list(netD.parameters())
    ge_params = list(netG.parameters()) + list(netE.parameters())

    # Training loop
    iters = 0

    for epoch in range(args.num_epochs):
        # Instance-noise std added to D's image inputs, annealed linearly
        # from 0.1 (first epoch) towards 0 over training.
        noise_std = 0.1 * (args.num_epochs - epoch) / args.num_epochs

        # iterate through the dataloader
        for i, data in enumerate(dataloader, 0):
            real_images = data[0].cuda()
            bs = real_images.shape[0]
            device = real_images.device

            noise1 = torch.randn(real_images.shape, device=device) * noise_std
            noise2 = torch.randn(real_images.shape, device=device) * noise_std

            # Encode the real images.  The encoder output is split into a
            # mean and a log standard deviation (assumes netE emits 2 * nz
            # values per sample -- TODO confirm against Encoder).
            z_real = netE(real_images).view(bs, -1)
            mu, log_sigma = z_real[:, :args.nz], z_real[:, args.nz:]
            sigma = torch.exp(log_sigma)
            epsilon = torch.randn(bs, args.nz, device=device)
            # reparameterization trick
            output_z = (mu + epsilon * sigma).view(bs, -1, 1, 1)

            # generate images from random latent codes
            z_fake = torch.randn(bs, args.nz, 1, 1, device=device)
            d_fake = netG(z_fake)

            # discriminator scores for the real (x, E(x)) pair ...
            out_real_pair = netD(real_images + noise1, output_z)
            # ... and for the fake (G(z), z) pair
            out_fake_pair = netD(d_fake + noise2, z_fake)

            real_labels = torch.full((bs, ), real_label, device=device)
            fake_labels = torch.full((bs, ), fake_label, device=device)

            # compute the losses
            d_loss = criterion(out_real_pair, real_labels) + criterion(
                out_fake_pair, fake_labels)
            g_loss = criterion(out_real_pair, fake_labels) + criterion(
                out_fake_pair, real_labels)

            # Update weights.  Both losses share one forward graph, so every
            # gradient is taken BEFORE any optimizer step: stepping netD in
            # place and then back-propagating g_loss through the old graph
            # fails autograd's in-place version check.  autograd.grad also
            # keeps d_loss gradients out of G/E and g_loss gradients out of D.
            # The D update is skipped when g_loss >= 3.5 (D already far ahead).
            update_d = g_loss.item() < 3.5
            if update_d:
                d_grads = torch.autograd.grad(d_loss,
                                              d_params,
                                              retain_graph=True,
                                              allow_unused=True)
            ge_grads = torch.autograd.grad(g_loss, ge_params, allow_unused=True)

            if update_d:
                for param, grad in zip(d_params, d_grads):
                    param.grad = grad
                netD_optimizer.step()

            for param, grad in zip(ge_params, ge_grads):
                param.grad = grad
            netG_optimizer.step()

            # print the training losses
            if iters % 10 == 0:
                print(
                    '[%3d/%d][%3d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x, z): %.4f\tD(G(z), z): %.4f'
                    %
                    (epoch, args.num_epochs, i, len(dataloader), d_loss.item(),
                     g_loss.item(), out_real_pair.mean().item(),
                     out_fake_pair.mean().item()))

            # visualize the samples generated by G and E's reconstructions
            if iters % 500 == 0:
                run_dir = os.path.join(args.log_dir, args.run_name)
                image_name = str(iters).zfill(7) + '.png'

                out_dir = os.path.join(run_dir, 'out')
                os.makedirs(out_dir, exist_ok=True)
                save_image(d_fake.detach()[:64].cpu(),
                           os.path.join(out_dir, image_name),
                           nrow=8,
                           normalize=True)

                # Reconstructions G(E(x)), shown beside the real inputs.
                # Eval mode + no_grad so visualization neither tracks
                # gradients nor updates normalization running statistics.
                was_training = netG.training
                netG.eval()
                with torch.no_grad():
                    recons = netG(output_z.detach()[:8])
                netG.train(was_training)

                recons_dir = os.path.join(run_dir, 'recons')
                os.makedirs(recons_dir, exist_ok=True)
                save_image(torch.cat([real_images[:8], recons], dim=3).cpu(),
                           os.path.join(recons_dir, image_name),
                           nrow=1,
                           normalize=True)

            iters += 1

        # Save weights.  os.path.join keeps absolute log_dir paths intact
        # (a './' prefix would turn them into paths relative to the cwd).
        save_dir = os.path.join(args.log_dir, args.run_name, 'weights')
        os.makedirs(save_dir, exist_ok=True)
        save_weights(netG, os.path.join(save_dir, 'netG.pth'))
        save_weights(netE, os.path.join(save_dir, 'netE.pth'))