class LSGANs_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Build the networks, their optimizers/schedulers and bookkeeping.

        Args:
            hyperparameters: mapping read with ``[]``. Keys used directly
                here: ``'lr'``, ``'input_dim_a'``, ``'gen'`` (which must
                contain ``'style_dim'``), ``'beta1'``, ``'beta2'``,
                ``'weight_decay'``, ``'init'`` and, optionally, ``'vgg_w'``
                together with ``'vgg_model_path'``. The mapping is also
                handed to ``get_scheduler``, which may read further keys.
        """
        super(LSGANs_Trainer, self).__init__()
        gen_config = hyperparameters['gen']
        learning_rate = hyperparameters['lr']

        # Networks: one shared encoder/decoder pair, a discriminator and an
        # interpolator per translation direction.
        self.encoder = Encoder(hyperparameters['input_dim_a'], gen_config)
        self.decoder = Decoder(hyperparameters['input_dim_a'], gen_config)
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_config['style_dim']

        # Optimizers: every network gets its own Adam with identical settings.
        adam_betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']

        def make_adam(network):
            # Only parameters that still require gradients are optimized.
            trainable = [
                param for param in network.parameters() if param.requires_grad
            ]
            return torch.optim.Adam(trainable,
                                    lr=learning_rate,
                                    betas=adam_betas,
                                    weight_decay=weight_decay)

        self.enc_opt = make_adam(self.encoder)
        self.dec_opt = make_adam(self.decoder)
        self.dis_a_opt = make_adam(self.dis_a)
        self.dis_b_opt = make_adam(self.dis_b)
        self.interp_ab_opt = make_adam(self.interp_net_ab)
        self.interp_ba_opt = make_adam(self.interp_net_ba)

        # One learning-rate scheduler per optimizer.
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters)

        # Weight initialization: the configured scheme for every submodule,
        # after which both discriminators are re-initialized with 'gaussian'.
        self.apply(weights_init(hyperparameters['init']))
        for discriminator in (self.dis_a, self.dis_b):
            discriminator.apply(weights_init('gaussian'))

        # The frozen VGG network is only loaded when a positive perceptual
        # weight is configured.
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for vgg_param in self.vgg.parameters():
                vgg_param.requires_grad = False

        # Running sum of all losses seen so far and the iteration counter
        # that is stored in / restored from checkpoints.
        self.total_loss = 0
        self.best_iter = 0
        self.perceptural_loss = Perceptural_loss()

    def recon_criterion(self, input, target):
        """Return the mean absolute difference between ``input`` and ``target``."""
        abs_error = (input - target).abs()
        return abs_error.mean()

    def forward(self, x_a, x_b):
        """Translate two batches across domains in evaluation mode.

        Both inputs go through the shared encoder, their style codes are fed
        to the two interpolation networks, and each content code is decoded
        with the interpolated style belonging to the other direction. The
        module is switched to ``eval()`` for the computation and back to
        ``train()`` before returning.

        Args:
            x_a: input batch for the A side.
            x_b: input batch for the B side.

        Returns:
            Tuple ``(x_ab, x_ba)``: the content of ``x_a`` decoded with the
            output of ``interp_net_ab``, and the content of ``x_b`` decoded
            with the output of ``interp_net_ba``.
        """
        self.eval()

        c_a, s_a_fake = self.encoder(x_a)
        c_b, s_b_fake = self.encoder(x_b)

        # Use the interpolation weights stored on the trainer when there are
        # any; otherwise fall back to the all-ones tensor that the update and
        # sampling methods build, so forward() also works on a fresh trainer.
        v = getattr(self, 'v', None)
        if v is None:
            v = torch.ones(s_a_fake.size())

        # decode (cross domain); same naming as in dis_update/gen_update:
        # interp_net_ab yields the B-side style, interp_net_ba the A-side one.
        s_b_interp = self.interp_net_ab(s_a_fake, s_b_fake, v)
        s_a_interp = self.interp_net_ba(s_b_fake, s_a_fake, v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)
        self.train()
        return x_ab, x_ba

    def zero_grad(self):
        """Clear the accumulated gradients through all six optimizers."""
        optimizer_names = (
            'dis_a_opt',
            'dis_b_opt',
            'dec_opt',
            'enc_opt',
            'interp_ab_opt',
            'interp_ba_opt',
        )
        for optimizer_name in optimizer_names:
            getattr(self, optimizer_name).zero_grad()

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimization step for both discriminators.

        The critic loss of each discriminator is the mean difference between
        its output on a translated image and on a real image, plus the value
        returned by its ``calculate_gradient_penalty``. All four terms are
        weighted by ``hyperparameters['gan_w']``. The scalar total is added
        to ``self.total_loss`` and only the two discriminator optimizers are
        stepped.

        Args:
            x_a: real batch for the A side.
            x_b: real batch for the B side.
            hyperparameters: mapping providing ``'gan_w'``.
        """
        self.zero_grad()
        gan_w = hyperparameters['gan_w']

        # encode
        content_a, style_a = self.encoder(x_a)
        content_b, style_b = self.encoder(x_b)

        # decode (cross domain) with all-ones interpolation weights
        self.v = torch.ones(style_a.size())
        style_a_interp = self.interp_net_ba(style_b, style_a, self.v)
        style_b_interp = self.interp_net_ab(style_a, style_b, self.v)
        x_ba = self.decoder(content_b, style_a_interp)
        x_ab = self.decoder(content_a, style_b_interp)
        # NOTE(review): x_ba / x_ab are not detached, so backward() below also
        # runs through encoder, decoder and interpolators even though only the
        # discriminators are stepped -- confirm whether that is intended.

        # critic outputs on real and translated images
        real_a_score = self.dis_a(x_a)
        fake_a_score = self.dis_a(x_ba)
        real_b_score = self.dis_b(x_b)
        fake_b_score = self.dis_b(x_ab)
        self.loss_dis_a = (fake_a_score - real_a_score).mean()
        self.loss_dis_b = (fake_b_score - real_b_score).mean()

        # gradient penalty
        self.loss_dis_a_gp = self.dis_a.calculate_gradient_penalty(x_ba, x_a)
        self.loss_dis_b_gp = self.dis_b.calculate_gradient_penalty(x_ab, x_b)

        self.loss_dis_total = (gan_w * self.loss_dis_a +
                               gan_w * self.loss_dis_b +
                               gan_w * self.loss_dis_a_gp +
                               gan_w * self.loss_dis_b_gp)

        self.loss_dis_total.backward()
        self.total_loss += self.loss_dis_total.item()
        for optimizer in (self.dis_a_opt, self.dis_b_opt):
            optimizer.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimization step for encoder, decoder and interpolators.

        The total loss combines the adversarial terms, within-domain image
        reconstruction, style/content code reconstruction, the optional
        cycle reconstruction and the optional perceptual terms. The scalar
        total is added to ``self.total_loss``; the decoder, encoder and both
        interpolator optimizers are stepped.

        Args:
            x_a: real batch for the A side.
            x_b: real batch for the B side.
            hyperparameters: mapping providing ``'gan_w'``, ``'recon_x_w'``,
                ``'recon_s_w'``, ``'recon_c_w'`` and ``'recon_x_cyc_w'``.
                ``'vgg_w'`` is optional; a missing entry disables the
                perceptual terms, matching how ``__init__`` treats it.
        """
        self.zero_grad()

        gan_w = hyperparameters['gan_w']
        recon_x_w = hyperparameters['recon_x_w']
        recon_s_w = hyperparameters['recon_s_w']
        recon_c_w = hyperparameters['recon_c_w']
        cyc_w = hyperparameters['recon_x_cyc_w']
        vgg_w = hyperparameters['vgg_w'] if 'vgg_w' in hyperparameters else 0
        use_cycle = cyc_w > 0
        use_vgg = vgg_w > 0

        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)

        # decode (within domain)
        x_a_recon = self.decoder(c_a, s_a)
        x_b_recon = self.decoder(c_b, s_b)

        # decode (cross domain) with all-ones interpolation weights
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)

        # encode again
        c_b_recon, s_a_recon = self.encoder(x_ba)
        c_a_recon, s_b_recon = self.encoder(x_ab)

        # decode again; the cycle images feed both the cycle loss and the
        # perceptual loss, so they are needed when either of them is enabled
        # (previously they were None whenever the cycle weight was zero, which
        # broke the perceptual terms).
        need_cycle_images = use_cycle or use_vgg
        x_aa = self.decoder(c_a_recon, s_a) if need_cycle_images else None
        x_bb = self.decoder(c_b_recon, s_b) if need_cycle_images else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aa, x_a) if use_cycle else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bb, x_b) if use_cycle else 0

        # perceptual loss
        self.loss_gen_vgg_a = self.perceptural_loss(
            x_a_recon, x_a) if use_vgg else 0
        self.loss_gen_vgg_b = self.perceptural_loss(
            x_b_recon, x_b) if use_vgg else 0
        self.loss_gen_vgg_aa = self.perceptural_loss(
            x_aa, x_a) if use_vgg else 0
        self.loss_gen_vgg_bb = self.perceptural_loss(
            x_bb, x_b) if use_vgg else 0

        # GAN loss
        x_ba_feature = self.dis_a(x_ba)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_gen_adv_a = -x_ba_feature.mean()
        self.loss_gen_adv_b = -x_ab_feature.mean()

        # total loss
        self.loss_gen_total = gan_w * self.loss_gen_adv_a + \
                              gan_w * self.loss_gen_adv_b + \
                              recon_x_w * self.loss_gen_recon_x_a + \
                              recon_s_w * self.loss_gen_recon_s_a + \
                              recon_c_w * self.loss_gen_recon_c_a + \
                              recon_x_w * self.loss_gen_recon_x_b + \
                              recon_s_w * self.loss_gen_recon_s_b + \
                              recon_c_w * self.loss_gen_recon_c_b + \
                              cyc_w * self.loss_gen_cycrecon_x_a + \
                              cyc_w * self.loss_gen_cycrecon_x_b + \
                              vgg_w * self.loss_gen_vgg_aa + \
                              vgg_w * self.loss_gen_vgg_bb + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b

        self.loss_gen_total.backward()
        self.total_loss += self.loss_gen_total.item()
        self.dec_opt.step()
        self.enc_opt.step()
        self.interp_ab_opt.step()
        self.interp_ba_opt.step()

    def sample(self, x_a, x_b):
        """Produce reconstructions, translations and cycle images per sample.

        Each pair ``(x_a[i], x_b[i])`` is processed on its own as a batch of
        one while the module is in ``eval()`` mode; ``train()`` is restored
        before returning.

        Args:
            x_a: batch for the A side; its first dimension sets the number
                of samples processed.
            x_b: batch for the B side, indexed with the same positions.

        Returns:
            Tuple ``(x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba,
            x_bb)`` where every generated entry is the concatenation of the
            per-sample decoder outputs.
        """
        self.eval()
        recon_a, recon_b = [], []
        cross_ab, cross_ba = [], []
        cycle_aa, cycle_bb = [], []

        for idx in range(x_a.size(0)):
            c_a, s_a = self.encoder(x_a[idx].unsqueeze(0))
            c_b, s_b = self.encoder(x_b[idx].unsqueeze(0))
            recon_a.append(self.decoder(c_a, s_a))
            recon_b.append(self.decoder(c_b, s_b))

            # all-ones interpolation weights, as in the update steps
            self.v = torch.ones(s_a.size())
            s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
            s_b_interp = self.interp_net_ab(s_a, s_b, self.v)

            # The translated images are re-encoded; only the recovered
            # content codes are used afterwards.
            c_a_recon, _ = self.encoder(self.decoder(c_a, s_b_interp))
            c_b_recon, _ = self.encoder(self.decoder(c_b, s_a_interp))

            # NOTE(review): these decodes pass styles with an extra leading
            # dimension (unsqueeze(0)), unlike the calls above -- confirm the
            # decoder treats both shapes the same.
            cross_ab.append(self.decoder(c_a, s_b_interp.unsqueeze(0)))
            cross_ba.append(self.decoder(c_b, s_a_interp.unsqueeze(0)))
            cycle_aa.append(self.decoder(c_a_recon, s_a.unsqueeze(0)))
            cycle_bb.append(self.decoder(c_b_recon, s_b.unsqueeze(0)))

        x_a_recon = torch.cat(recon_a)
        x_b_recon = torch.cat(recon_b)
        x_ab = torch.cat(cross_ab)
        x_aa = torch.cat(cycle_aa)
        x_ba = torch.cat(cross_ba)
        x_bb = torch.cat(cycle_bb)

        self.train()

        return x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb

    def update_learning_rate(self):
        """Advance every available learning-rate scheduler by one step.

        Schedulers that are missing or ``None`` are skipped instead of
        raising ``AttributeError``: ``gen_scheduler`` is not created by this
        class, and the interpolator schedulers are stored as
        ``interp_*_scheduler`` by ``__init__`` but as
        ``interpo_*_scheduler`` by the resume methods. For each interpolator
        only one scheduler is stepped, preferring the ``interpo_*`` name
        because that is the one rebuilt from a checkpoint.
        """
        # Each entry lists the attribute names one scheduler may live under,
        # in order of preference.
        scheduler_slots = (
            ('dis_a_scheduler',),
            ('dis_b_scheduler',),
            ('gen_scheduler',),
            ('enc_scheduler',),
            ('dec_scheduler',),
            ('interpo_ab_scheduler', 'interp_ab_scheduler'),
            ('interpo_ba_scheduler', 'interp_ba_scheduler'),
        )
        for candidate_names in scheduler_slots:
            for attr_name in candidate_names:
                scheduler = getattr(self, attr_name, None)
                if scheduler is not None:
                    scheduler.step()
                    # Never step two schedulers bound to the same optimizer.
                    break

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from a checkpoint dir.

        Args:
            checkpoint_dir: directory holding the checkpoint files. The file
                of each network is located with ``get_model(checkpoint_dir,
                prefix)`` (presumably returning a path -- TODO confirm); the
                optimizer states are read from ``optimizer.pt``.
            hyperparameters: mapping forwarded to ``get_scheduler``.

        Returns:
            Tuple ``(best_iter, total_loss)`` as stored in ``optimizer.pt``.
        """
        # Network weights, restored in a fixed order.
        networks = (
            ('encoder', self.encoder),
            ('decoder', self.decoder),
            ('dis_a', self.dis_a),
            ('dis_b', self.dis_b),
            ('interp_ab', self.interp_net_ab),
            ('interp_ba', self.interp_net_ba),
        )
        for prefix, network in networks:
            model_name = get_model(checkpoint_dir, prefix)
            network.load_state_dict(torch.load(model_name))

        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.enc_opt.load_state_dict(state_dict['enc_opt'])
        self.dec_opt.load_state_dict(state_dict['dec_opt'])
        self.dis_a_opt.load_state_dict(state_dict['dis_a_opt'])
        self.dis_b_opt.load_state_dict(state_dict['dis_b_opt'])
        self.interp_ab_opt.load_state_dict(state_dict['interp_ab_opt'])
        self.interp_ba_opt.load_state_dict(state_dict['interp_ba_opt'])

        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']

        # Reinitialize schedulers, passing the restored iteration counter.
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             self.best_iter)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             self.best_iter)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           self.best_iter)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           self.best_iter)
        # __init__ stores the interpolator schedulers as interp_*_scheduler,
        # so replace those instead of leaving stale ones behind; the
        # interpo_* spelling is kept as an alias for existing readers.
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        self.interpo_ab_scheduler = self.interp_ab_scheduler
        self.interpo_ba_scheduler = self.interp_ba_scheduler
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def resume_iter(self, checkpoint_dir, surfix, hyperparameters):
        """Restore the trainer from the checkpoint files of one iteration.

        Files are looked up as ``<name><surfix>.pt`` inside
        ``checkpoint_dir`` for ``encoder``, ``decoder``, ``dis_a``,
        ``dis_b`` and ``optimizer``. Interpolator weights are accepted in two
        layouts: a combined ``interp<surfix>.pt`` holding the keys ``'ab'``
        and ``'ba'``, and/or separate ``interp_ab<surfix>.pt`` /
        ``interp_ba<surfix>.pt`` files. When both layouts exist the combined
        file is loaded first and the separate files override it.

        Args:
            checkpoint_dir: directory holding the checkpoint files.
            surfix: string inserted between the file stem and ``'.pt'``.
            hyperparameters: mapping forwarded to ``get_scheduler``.

        Returns:
            Tuple ``(best_iter, total_loss)`` as stored in the optimizer file.

        Raises:
            FileNotFoundError: if a required file is missing, including the
                case where neither interpolator layout is available.
        """

        def checkpoint_path(stem):
            return os.path.join(checkpoint_dir, stem + surfix + '.pt')

        # Encoder, decoder and both discriminators.
        networks = (
            ('encoder', self.encoder),
            ('decoder', self.decoder),
            ('dis_a', self.dis_a),
            ('dis_b', self.dis_b),
        )
        for stem, network in networks:
            network.load_state_dict(torch.load(checkpoint_path(stem)))

        # Interpolators: the combined file is optional because it is only one
        # of the two supported layouts; previously its absence aborted the
        # resume even when the separate files were present.
        combined_path = checkpoint_path('interp')
        has_combined = os.path.isfile(combined_path)
        if has_combined:
            state_dict = torch.load(combined_path)
            self.interp_net_ab.load_state_dict(state_dict['ab'])
            self.interp_net_ba.load_state_dict(state_dict['ba'])

        interpolators = (
            ('interp_ab', self.interp_net_ab),
            ('interp_ba', self.interp_net_ba),
        )
        for stem, network in interpolators:
            split_path = checkpoint_path(stem)
            if has_combined and not os.path.isfile(split_path):
                # Already restored from the combined file.
                continue
            network.load_state_dict(torch.load(split_path))

        # Load optimizers
        state_dict = torch.load(checkpoint_path('optimizer'))
        self.enc_opt.load_state_dict(state_dict['enc_opt'])
        self.dec_opt.load_state_dict(state_dict['dec_opt'])
        self.dis_a_opt.load_state_dict(state_dict['dis_a_opt'])
        self.dis_b_opt.load_state_dict(state_dict['dis_b_opt'])
        self.interp_ab_opt.load_state_dict(state_dict['interp_ab_opt'])
        self.interp_ba_opt.load_state_dict(state_dict['interp_ba_opt'])

        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']

        # Reinitialize schedulers, passing the restored iteration counter.
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             self.best_iter)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             self.best_iter)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           self.best_iter)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           self.best_iter)
        # __init__ stores the interpolator schedulers as interp_*_scheduler,
        # so replace those instead of leaving stale ones behind; the
        # interpo_* spelling is kept as an alias for existing readers.
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters,
                                                 self.best_iter)
        self.interpo_ab_scheduler = self.interp_ab_scheduler
        self.interpo_ba_scheduler = self.interp_ba_scheduler
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def save_better_model(self, snapshot_dir):
        """Replace the contents of ``snapshot_dir`` with the current state.

        Every regular file directly inside ``snapshot_dir`` is deleted
        first (sub-directories are left alone). Then the weights of the
        encoder, decoder, both interpolators and both discriminators are
        written to ``<name>_<total_loss:.4f>.pt`` and the optimizer states,
        together with ``best_iter`` and ``total_loss``, to ``optimizer.pt``.

        Args:
            snapshot_dir: existing directory to write the checkpoint into.
        """
        # remove sub-optimal models; os.remove cannot delete directories, so
        # only files (and links) are removed.
        for stale_path in glob.glob(snapshot_dir + '/*'):
            if os.path.isfile(stale_path) or os.path.islink(stale_path):
                os.remove(stale_path)

        # Network weights. The discriminator files hold the discriminator
        # weights (not their optimizer state) so that they can be fed to
        # ``load_state_dict`` of the discriminators when resuming.
        networks = (
            ('encoder', self.encoder),
            ('decoder', self.decoder),
            ('interp_ab', self.interp_net_ab),
            ('interp_ba', self.interp_net_ba),
            ('dis_a', self.dis_a),
            ('dis_b', self.dis_b),
        )
        for stem, network in networks:
            file_name = os.path.join(
                snapshot_dir, stem + '_%.4f.pt' % (self.total_loss))
            torch.save(network.state_dict(), file_name)

        # Optimizer states plus bookkeeping values.
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, opt_name)

    def save_at_iter(self, snapshot_dir, iterations):
        """Write a checkpoint tagged with the iteration number.

        The weights of the encoder, decoder, both interpolators and both
        discriminators go to ``<name>_<iterations + 1:08d>.pt``; the
        optimizer states, ``best_iter`` and ``total_loss`` go to
        ``optimizer_<iterations + 1:08d>.pt``.

        Args:
            snapshot_dir: existing directory to write the checkpoint into.
            iterations: zero-based iteration index; the file names use
                ``iterations + 1``.
        """
        tag = '_%08d.pt' % (iterations + 1)

        # Network weights. The discriminator files hold the discriminator
        # weights (not their optimizer state) so that they can be fed to
        # ``load_state_dict`` of the discriminators when resuming.
        networks = (
            ('encoder', self.encoder),
            ('decoder', self.decoder),
            ('interp_ab', self.interp_net_ab),
            ('interp_ba', self.interp_net_ba),
            ('dis_a', self.dis_a),
            ('dis_b', self.dis_b),
        )
        for stem, network in networks:
            torch.save(network.state_dict(),
                       os.path.join(snapshot_dir, stem + tag))

        # Optimizer states plus bookkeeping values.
        opt_name = os.path.join(snapshot_dir, 'optimizer' + tag)
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, opt_name)
# Example #2
# 0
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on MNIST and log the three losses.

    Per mini-batch the style KL term, the class (grouped) KL term and the
    reconstruction term are back-propagated one after another before a
    single optimizer step. The losses are appended to ``FLAGS.log_file`` and
    written to a tensorboard ``SummaryWriter`` on every iteration, and
    printed every 50 iterations. Model weights are stored under
    ``checkpoints/`` every 5 epochs and after the final epoch.

    Args:
        FLAGS: configuration object. Attributes read here: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``,
            ``decoder_save``, ``batch_size``, ``image_size``, ``cuda``,
            ``initial_learning_rate``, ``beta_1``, ``beta_2``, ``log_file``,
            ``start_epoch``, ``end_epoch``, ``kl_divergence_coef``,
            ``num_channels`` and ``reconstruction_coef``.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # variable definition: reusable input buffer that each batch is copied into
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X = X.cuda()

    # optimizer definition: one Adam over the parameters of both networks
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # training
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    iterations_per_epoch = int(len(mnist) / FLAGS.batch_size)
    # Both KL terms are divided by the number of values in one input batch.
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    # initialize summary writer
    writer = SummaryWriter()

    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            for iteration in range(iterations_per_epoch):
                # load a mini-batch
                image_batch, labels_batch = next(loader)

                # set zero_grad for the optimizer
                auto_encoder_optimizer.zero_grad()

                X.copy_(image_batch)

                style_mu, style_logvar, class_mu, class_logvar = encoder(
                    Variable(X))
                grouped_mu, grouped_logvar = accumulate_group_evidence(
                    class_mu.data, class_logvar.data, labels_batch,
                    FLAGS.cuda)

                # kl-divergence error for style latent space
                style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                     style_logvar.exp()))
                style_kl_divergence_loss /= kl_normalizer
                # the graph is reused by the backward passes below
                style_kl_divergence_loss.backward(retain_graph=True)

                # kl-divergence error for class latent space
                class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                     grouped_logvar.exp()))
                class_kl_divergence_loss /= kl_normalizer
                class_kl_divergence_loss.backward(retain_graph=True)

                # reconstruct samples
                # sampling from group mu and logvar for each image in
                # mini-batch differently makes the decoder consider class
                # latent embeddings as random noise and ignore them
                style_latent_embeddings = reparameterize(training=True,
                                                         mu=style_mu,
                                                         logvar=style_logvar)
                class_latent_embeddings = group_wise_reparameterize(
                    training=True,
                    mu=grouped_mu,
                    logvar=grouped_logvar,
                    labels_batch=labels_batch,
                    cuda=FLAGS.cuda)

                reconstructed_images = decoder(style_latent_embeddings,
                                               class_latent_embeddings)

                reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_images, Variable(X))
                reconstruction_error.backward()

                auto_encoder_optimizer.step()

                # Convert the scalar losses to Python numbers once; they are
                # used for console, log file and tensorboard output.
                reconstruction_value = reconstruction_error.item()
                style_kl_value = style_kl_divergence_loss.item()
                class_kl_value = class_kl_divergence_loss.item()

                if (iteration + 1) % 50 == 0:
                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))

                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('Style KL-Divergence loss: ' + str(style_kl_value))
                    print('Class KL-Divergence loss: ' + str(class_kl_value))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        epoch, iteration, reconstruction_value,
                        style_kl_value, class_kl_value))

                # write to tensorboard
                global_step = epoch * (iterations_per_epoch + 1) + iteration
                writer.add_scalar('Reconstruction loss', reconstruction_value,
                                  global_step)
                writer.add_scalar('Style KL-Divergence loss', style_kl_value,
                                  global_step)
                writer.add_scalar('Class KL-Divergence loss', class_kl_value,
                                  global_step)

            # save checkpoints after every 5 epochs
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.decoder_save))
    finally:
        # Flush pending events and release the event file even when training
        # is interrupted.
        writer.close()
# Example #3
# 0
class VPAF(torch.nn.Module):
    """Encoder/decoder pair trained on two information-theoretic upper bounds.

    The per-batch training loss (see ``train_step``) is
    ``I(X;Y)_ub + param * H(output | S, Y)_ub`` where ``param`` is ``gamma``
    when ``problem == 'privacy'`` and ``beta`` otherwise. Both bounds are
    expressed in bits.
    """

    def __init__(self,
                 input_type='image',
                 representation_type='image',
                 output_type=['image'],
                 s_type='classes',
                 input_dim=104,
                 representation_dim=8,
                 output_dim=[1],
                 s_dim=1,
                 problem='privacy',
                 beta=1.0,
                 gamma=1.0,
                 prior_type='Gaussian'):
        """Store the configuration and build the encoder and decoder.

        ``output_type`` and ``output_dim`` are parallel sequences: entry *i*
        of each describes the *i*-th output head (see
        ``get_H_output_given_SY_ub``). Their list defaults are shared between
        instances; this class only reads them and never mutates them.
        """
        super(VPAF, self).__init__()

        self.problem = problem
        # Trade-off weight applied to the conditional-entropy term of the loss.
        self.param = gamma if self.problem == 'privacy' else beta
        self.input_type = input_type
        self.representation_type = representation_type
        self.output_type = output_type
        self.output_dim = output_dim
        self.s_type = s_type
        self.prior_type = prior_type

        self.encoder = Encoder(input_type, representation_type, input_dim,
                               representation_dim)
        self.decoder = Decoder(representation_type,
                               output_type,
                               representation_dim,
                               output_dim,
                               s_dim=s_dim)

    @staticmethod
    def _flatten_on_cpu(batch):
        """Return ``batch`` on the CPU with all but the first axis flattened."""
        return batch.cpu().view(batch.size(0), -1)

    def get_IXY_ub(self, y_mean, mode='Gaussian'):
        """Return the upper bound on I(X;Y), in bits, summed over ``y_mean``.

        With ``mode == 'Gaussian'`` this is the closed-form KL term built from
        ``y_mean`` and ``self.encoder.y_logvar_theta``; any other mode defers
        to ``KDE_IXY_estimation``.
        """
        if mode == 'Gaussian':
            logvar = self.encoder.y_logvar_theta
            bound_nats = -0.5 * torch.sum(1.0 + logvar -
                                          torch.pow(y_mean, 2) -
                                          torch.exp(logvar))
        else:  # MoG
            bound_nats = KDE_IXY_estimation(self.encoder.y_logvar_theta,
                                            y_mean)

        return bound_nats / math.log(2)  # nats -> bits

    def get_H_output_given_SY_ub(self, decoder_output, t):
        """Return the summed reconstruction loss over all output heads, in bits.

        The heads listed in ``self.output_type`` / ``self.output_dim`` are
        walked in order, each consuming its own slice of columns from
        ``decoder_output`` and ``t``:

        * ``'classes'``: cross-entropy over ``output_dim_`` logits, one
          target column.
        * ``'binary'``: BCE-with-logits on one column of each.
        * ``'image'``: BCE on the whole tensors (no slicing).
        * anything else: Gaussian negative log-likelihood using
          ``self.decoder.out_logvar_phi``.
        """
        if len(t.shape) == 1:
            t = t.view(-1, 1)

        H_output_given_SY_ub = 0
        dim_start_out = 0
        dim_start_t = 0
        reg_start = 0

        for output_type_, output_dim_ in zip(self.output_type,
                                             self.output_dim):

            if output_type_ == 'classes':
                so = dim_start_out
                eo = dim_start_out + output_dim_
                st = dim_start_t
                et = dim_start_t + 1
                CE = torch.nn.functional.cross_entropy(
                    decoder_output[:, so:eo],
                    t[:, st:et].long().view(-1),
                    reduction='sum')
            elif output_type_ == 'binary':
                so = dim_start_out
                eo = dim_start_out + 1
                st = dim_start_t
                et = dim_start_t + 1
                CE = torch.nn.functional.binary_cross_entropy_with_logits(
                    decoder_output[:, so:eo].view(-1),
                    t[:, st:et].view(-1),
                    reduction='sum')
            elif output_type_ == 'image':
                # The image head uses the full tensors, so the column
                # offsets are reset rather than advanced.
                eo = et = 0
                CE = torch.nn.functional.binary_cross_entropy(decoder_output,
                                                              t,
                                                              reduction='sum')
            else:  # regression
                so = dim_start_out
                eo = dim_start_out + output_dim_
                st = dim_start_t
                et = dim_start_t + output_dim_
                sr = reg_start
                er = reg_start + output_dim_
                reg_start = er
                logvar = self.decoder.out_logvar_phi[sr:er]
                squared_error = torch.pow(
                    decoder_output[:, so:eo] - t[:, st:et], 2)
                # 1e-10 keeps the division finite for very negative logvar.
                CE = 0.5 * torch.sum(
                    math.log(2 * math.pi) + logvar +
                    squared_error / (torch.exp(logvar) + 1e-10))

            H_output_given_SY_ub += CE / math.log(2)  # in bits

            dim_start_out = eo
            dim_start_t = et

        return H_output_given_SY_ub

    def evaluate_privacy(self, dataloader, device, N, batch_size, figs_dir,
                         verbose):
        """Evaluate both bounds over ``dataloader`` and write diagnostic figures.

        Args:
            dataloader: yields ``(x, t, s)`` batches.
            device: device the batches are moved to.
            N: divisor applied to the accumulated bounds.
            batch_size: kept for call compatibility; the actual size of each
                batch is used for reshaping so a short batch cannot break
                the ``view`` calls.
            figs_dir: directory receiving the saved images / embedding plots.
            verbose: print the two averaged bounds when truthy.

        Returns:
            ``(IXY_ub, H_X_given_SY_ub)``, each divided by ``N``.
        """
        IXY_ub = 0
        H_X_given_SY_ub = 0

        save_images = (self.input_type == 'image'
                       and self.representation_type == 'image'
                       and 'image' in self.output_type)
        by_class = self.s_type == 'classes'

        with torch.no_grad():
            for it, (x, t, s) in enumerate(dataloader):

                x = x.to(device).float()
                t = t.to(device).float()
                s = s.to(device).float()

                y, y_mean = self.encoder(x)
                output = self.decoder(y, s)

                # Sample grids are written from the second batch only.
                if save_images and it == 1:
                    for images, file_name in ((x, 'x.eps'),
                                              (y_mean, 'y.eps'),
                                              (output, 'x_hat.eps')):
                        torchvision.utils.save_image(
                            images[:12 * 8],
                            os.path.join(figs_dir, file_name),
                            nrow=12)

                IXY_ub += self.get_IXY_ub(y_mean, self.prior_type)
                H_X_given_SY_ub += self.get_H_output_given_SY_ub(output, t)

                if self.representation_type != 'image':
                    continue

                if it == 0 and by_class:
                    # Supervised reducers are fitted on the first batch and
                    # applied to the second batch below.
                    labels = s.cpu()
                    reducer_y = umap.UMAP(random_state=0)
                    reducer_y.fit(self._flatten_on_cpu(y_mean), y=labels)
                    reducer_x = umap.UMAP(random_state=0)
                    reducer_x.fit(self._flatten_on_cpu(x), y=labels)
                elif it == 1:
                    x_flat = self._flatten_on_cpu(x)
                    y_flat = self._flatten_on_cpu(y_mean)
                    if by_class:
                        embedding_s_y = reducer_y.transform(y_flat)
                        embedding_s_x = reducer_x.transform(x_flat)
                    embedding_y = umap.UMAP(
                        random_state=0).fit_transform(y_flat)
                    embedding_x = umap.UMAP(
                        random_state=0).fit_transform(x_flat)
                    if by_class:
                        labels = s.cpu().view(x.size(0)).long()
                        plot_embeddings(embedding_y, embedding_s_y, labels,
                                        figs_dir, 'y')
                        plot_embeddings(embedding_x, embedding_s_x, labels,
                                        figs_dir, 'x')
                    else:
                        plot_embeddings(embedding_y, embedding_y, -1,
                                        figs_dir, 'y')
                        # The 'x' figure must show the input embedding, not
                        # the representation embedding a second time.
                        plot_embeddings(embedding_x, embedding_x, -1,
                                        figs_dir, 'x')

        IXY_ub /= N
        H_X_given_SY_ub /= N
        if verbose:
            print(f'IXY: {IXY_ub.item()}')
            print(f'HX_given_SY: {H_X_given_SY_ub.item()}')
        return IXY_ub, H_X_given_SY_ub

    def evaluate_fairness(self, dataloader, device, N, target_vals,
                          H_T_given_S, verbose):
        """Evaluate the I(X;Y) bound and the I(Y;T|S) lower bound.

        Args:
            dataloader: yields ``(x, t, s)`` batches.
            device: device the batches are moved to.
            N: divisor applied to the accumulated quantities.
            target_vals: forwarded to ``metrics.get_accuracy``.
            H_T_given_S: conditional entropy the accumulated bound is
                subtracted from.
            verbose: print the results when truthy.

        Returns:
            ``(IXY_ub, IYT_given_S_lb)``.
        """
        IXY_ub = 0
        H_T_given_SY_ub = 0
        accuracy = 0

        with torch.no_grad():
            for x, t, s in dataloader:

                x = x.to(device).float()
                t = t.to(device).float()
                s = s.to(device).float()

                y, y_mean = self.encoder(x)
                output = self.decoder(y, s)

                IXY_ub += self.get_IXY_ub(y_mean, self.prior_type)
                H_T_given_SY_ub += self.get_H_output_given_SY_ub(output, t)
                # Weighted by batch length so the final division by N
                # averages over samples.
                accuracy += metrics.get_accuracy(output, t,
                                                 target_vals) * len(x)

        IXY_ub /= N
        H_T_given_SY_ub /= N
        accuracy /= N
        IYT_given_S_lb = H_T_given_S - H_T_given_SY_ub.item()
        if verbose:
            print(H_T_given_SY_ub)
            print(f'I(X;Y) = {IXY_ub.item()}')
            print(f'I(Y;T|S) = {IYT_given_S_lb}')
            print(f'Accuracy (network): {accuracy}')
        return IXY_ub, IYT_given_S_lb

    def evaluate(self, dataset, verbose, figs_dir):
        """Evaluate on ``dataset`` with the routine matching ``self.problem``.

        Batches of at most 2048 samples are drawn in shuffled order. Returns
        whatever ``evaluate_privacy`` or ``evaluate_fairness`` returns.
        """
        # Use the exact device holding the encoder weights so models placed
        # on a non-default GPU are evaluated where they live.
        device = next(self.encoder.parameters()).device
        batch_size = min(2048, len(dataset))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 shuffle=True)
        if self.problem == 'privacy':
            return self.evaluate_privacy(dataloader, device, len(dataset),
                                         batch_size, figs_dir, verbose)

        # fairness
        H_T_given_S = get_conditional_entropy(dataset.targets,
                                              dataset.hidden,
                                              dataset.target_vals,
                                              dataset.hidden_vals)
        return self.evaluate_fairness(dataloader, device, len(dataset),
                                      dataset.target_vals, H_T_given_S,
                                      verbose)

    def train_step(self, batch_size, learning_rate, dataloader, optimizer,
                   verbose):
        """Run one optimisation pass over ``dataloader``.

        ``batch_size``, ``learning_rate`` and ``verbose`` are accepted for
        call compatibility and are not used here.
        """
        device = next(self.encoder.parameters()).device

        for x, t, s in progressbar(dataloader):

            x = x.to(device).float()
            t = t.to(device).float()
            s = s.to(device).float()

            optimizer.zero_grad()
            y, y_mean = self.encoder(x)
            output = self.decoder(y, s)
            IXY_ub = self.get_IXY_ub(y_mean, self.prior_type)
            H_output_given_SY_ub = self.get_H_output_given_SY_ub(output, t)
            loss = IXY_ub + self.param * H_output_given_SY_ub

            loss.backward()
            optimizer.step()

    def fit(self,
            dataset_train,
            dataset_val,
            epochs=1000,
            learning_rate=1e-3,
            batch_size=1024,
            eval_rate=15,
            verbose=True,
            logs_dir='../results/logs/',
            figs_dir='../results/images/'):
        """Train with Adam for ``epochs`` epochs, evaluating periodically.

        Every ``eval_rate``-th epoch the training set is evaluated; for any
        problem other than ``'privacy'`` the validation set is evaluated as
        well. ``logs_dir`` is accepted but not used by this method.
        """
        dataloader = torch.utils.data.DataLoader(dataset_train,
                                                 batch_size=batch_size,
                                                 shuffle=True)

        params = list(self.encoder.parameters()) + list(
            self.decoder.parameters())
        optimizer = torch.optim.Adam(params, lr=learning_rate)

        for epoch in range(epochs):
            print(f'Epoch # {epoch+1}')
            self.train_step(batch_size, learning_rate, dataloader, optimizer,
                            verbose)

            if epoch % eval_rate != eval_rate - 1:
                continue

            if verbose:
                print('Evaluating TRAIN')
            self.evaluate(dataset_train, verbose, figs_dir)

            if self.problem != 'privacy':  # fairness
                if verbose:
                    print('Evaluating VALIDATION/TEST')
                self.evaluate(dataset_val, verbose, figs_dir)
# Script fragment: restores a generator/discriminator pair from disk and then
# trains `encoder` and `discriminator` (the two networks given optimizers
# below). `generator`, `encoder`, `discriminator` and `opt` are defined
# outside this fragment.
generator.load_state_dict(torch.load(opt.G_path))
discriminator.load_state_dict(torch.load(opt.D_path))
generator.to(opt.device)
encoder.to(opt.device)
discriminator.to(opt.device)

# Training mode for the two networks that are optimised.
encoder.train()
discriminator.train()

dataloader = load_data(opt)

# The generator is put in eval mode and receives no optimizer here.
generator.eval()

# Tensor constructor matching opt.device; not referenced in the visible lines.
Tensor = torch.cuda.FloatTensor if  opt.device == 'cuda' else torch.FloatTensor

optimizer_E = torch.optim.Adam(encoder.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# NOTE(review): initialised but not updated in the visible lines; presumably
# tracks the best validation AUC further down — confirm.
max_auc = 0
for epoch in range(opt.n_epochs):

    # train
    for i, (imgs, _) in enumerate(dataloader.train):
        optimizer_E.zero_grad()
        optimizer_D.zero_grad()

        imgs = imgs.to(opt.device)
        # Also clear gradients stored on the generator's parameters, since
        # no optimizer does that for it.
        generator.zero_grad()
        # Encode the real images, then decode the codes with the generator.
        z = encoder(imgs)

        fake_imgs = generator(z)
        # NOTE(review): the loop body is cut off here in this file; the loss
        # computation and optimizer steps are not visible.
Exemple #5
0
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', ENCODER_SAVE)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', DECODER_SAVE)))

    # loss definition
    mse_loss = nn.MSELoss()

    # add option to run on gpu
    if (CUDA):
        encoder.cuda()
        decoder.cuda()
        mse_loss.cuda()

    # optimizer
    optimizer = torch.optim.Adam(list(encoder.parameters()) +
                                 list(decoder.parameters()),
                                 lr=LR,
                                 betas=(BETA1, BETA2))

    if torch.cuda.is_available() and not CUDA:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # creating directories
    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')
Exemple #6
0
def training_procedure(FLAGS):
    """Train the style/class encoder-decoder pair on the selected datasets.

    Every entry of ``./data`` whose third ``_``-separated name field is
    exactly ``theta=-1`` is loaded with ``DoubleMulNormal`` and trained on for
    epochs ``FLAGS.start_epoch`` .. ``FLAGS.end_epoch - 1``. Per-iteration
    losses are appended to ``FLAGS.log_file`` and sent to a ``SummaryWriter``;
    weights are saved under ``checkpoints/encoder_<name>`` and
    ``checkpoints/decoder_<name>``.

    Args:
        FLAGS: configuration object; the attributes read here are
            ``style_dim``, ``class_dim``, ``load_saved``, ``encoder_save``,
            ``decoder_save``, ``batch_size``, ``initial_learning_rate``,
            ``beta_1``, ``beta_2``, ``cuda``, ``log_file``, ``start_epoch``,
            ``end_epoch``, ``kl_divergence_coef``, ``reconstruction_coef``,
            ``num_channels`` and ``image_size``.

    NOTE(review): ``mse_loss`` is not defined inside this function; it is
    assumed to be a module-level loss callable — confirm.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # variable definition: input buffer refilled in place for every batch
    X = torch.FloatTensor(FLAGS.batch_size, 784)

    # run on GPU if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    def save_checkpoints(dataset_name):
        """Write the current encoder/decoder weights for ``dataset_name``."""
        torch.save(encoder.state_dict(),
                   os.path.join('checkpoints', 'encoder_' + dataset_name))
        torch.save(decoder.state_dict(),
                   os.path.join('checkpoints', 'decoder_' + dataset_name))

    # Both KL terms are divided by the same constant.
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        params = dsname.split('_')
        # Exact match on the third name field. The previous
        # `params[2] in ('theta=-1')` was a substring test against a plain
        # string (so e.g. 'theta' or '' matched) and raised IndexError for
        # names with fewer than three fields.
        if len(params) < 3 or params[2] != 'theta=-1':
            continue

        print('Running dataset ', dsname)
        ds = DoubleMulNormal(dsname)
        # NOTE(review): if `cycle` is itertools.cycle it replays the batches
        # of the first pass, so reshuffling only happens once — confirm.
        loader = cycle(
            DataLoader(ds,
                       batch_size=FLAGS.batch_size,
                       shuffle=True,
                       drop_last=True))

        # initialize summary writer
        writer = SummaryWriter()

        iterations_per_epoch = int(len(ds) / FLAGS.batch_size)

        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print()
            print(
                'Epoch #' + str(epoch) +
                '........................................................')

            # Running sum of the three loss values for this epoch, kept as a
            # plain float so no autograd graph is retained across iterations
            # and an epoch with zero iterations still prints cleanly.
            total_loss = 0.0

            for iteration in range(iterations_per_epoch):
                # load a mini-batch
                image_batch, labels_batch = next(loader)

                # set zero_grad for the optimizer
                auto_encoder_optimizer.zero_grad()

                X.copy_(image_batch)

                style_mu, style_logvar, class_mu, class_logvar = encoder(
                    Variable(X))
                grouped_mu, grouped_logvar = accumulate_group_evidence(
                    class_mu.data, class_logvar.data, labels_batch,
                    FLAGS.cuda)

                # kl-divergence error for style latent space
                style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                     style_logvar.exp()))
                style_kl_divergence_loss /= kl_normalizer
                # retain_graph: the encoder graph is reused by the
                # reconstruction backward pass below.
                style_kl_divergence_loss.backward(retain_graph=True)

                # kl-divergence error for class latent space
                class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 *
                    torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                              grouped_logvar.exp()))
                class_kl_divergence_loss /= kl_normalizer
                class_kl_divergence_loss.backward(retain_graph=True)

                # reconstruct samples
                # sampling from group mu and logvar for each image in
                # mini-batch differently makes the decoder consider class
                # latent embeddings as random noise and ignore them
                style_latent_embeddings = reparameterize(
                    training=True, mu=style_mu, logvar=style_logvar)
                class_latent_embeddings = group_wise_reparameterize(
                    training=True,
                    mu=grouped_mu,
                    logvar=grouped_logvar,
                    labels_batch=labels_batch,
                    cuda=FLAGS.cuda)

                reconstructed_images = decoder(style_latent_embeddings,
                                               class_latent_embeddings)

                reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_images, Variable(X))
                reconstruction_error.backward()

                auto_encoder_optimizer.step()

                # Scalar values used by every reporting sink below.
                reconstruction_value = reconstruction_error.item()
                style_kl_value = style_kl_divergence_loss.item()
                class_kl_value = class_kl_divergence_loss.item()
                total_loss += (style_kl_value + class_kl_value +
                               reconstruction_value)

                if (iteration + 1) % 50 == 0:
                    print('\tIteration #' + str(iteration))
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('Style KL loss: ' + str(style_kl_value))
                    print('Class KL loss: ' + str(class_kl_value))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        epoch, iteration, reconstruction_value,
                        style_kl_value, class_kl_value))

                # write to tensorboard
                global_step = epoch * (iterations_per_epoch + 1) + iteration
                writer.add_scalar('Reconstruction loss',
                                  reconstruction_value, global_step)
                writer.add_scalar('Style KL-Divergence loss', style_kl_value,
                                  global_step)
                writer.add_scalar('Class KL-Divergence loss', class_kl_value,
                                  global_step)

                # During epoch 0 a checkpoint is also written every 50
                # iterations.
                if epoch == 0 and (iteration + 1) % 50 == 0:
                    save_checkpoints(dsname)

            # save checkpoints after every 10 epochs
            if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                save_checkpoints(dsname)

            print('Total loss at current epoch: ', total_loss)

        # Flush and release this dataset's event file.
        writer.close()
Exemple #7
0
class BiGAN(object):
    """Generator / encoder / discriminator triple trained adversarially.

    ``train`` performs one discriminator update followed by one joint
    generator+encoder update on a single batch.

    NOTE(review): the losses use binary cross-entropy on the raw
    discriminator output, so ``Discriminator(x, z)`` is assumed to return
    probabilities in ``[0, 1]`` — confirm against its definition.
    """

    def __init__(self, args):
        """Build the networks and optimizers from ``args``.

        ``args`` must provide ``z_dim``, ``decay_rate``, ``learning_rate``,
        ``model_name`` and ``batch_size``.
        """
        self.z_dim = args.z_dim
        self.decay_rate = args.decay_rate
        self.learning_rate = args.learning_rate
        self.model_name = args.model_name
        self.batch_size = args.batch_size

        # Use CUDA when present; fall back to CPU instead of failing.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        #initialize networks
        self.Generator = Generator(self.z_dim).to(self.device)
        self.Encoder = Encoder(self.z_dim).to(self.device)
        self.Discriminator = Discriminator().to(self.device)

        #set optimizers for all networks
        self.optimizer_G_E = torch.optim.Adam(
            list(self.Generator.parameters()) +
            list(self.Encoder.parameters()),
            lr=self.learning_rate,
            betas=(0.5, 0.999))

        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=self.learning_rate,
                                            betas=(0.5, 0.999))

        #initialize network weights
        self.Generator.apply(weights_init)
        self.Encoder.apply(weights_init)
        self.Discriminator.apply(weights_init)

    def train(self, data):
        """Run one adversarial training step on the batch ``data``.

        ``data`` must already live on the same device as the networks; the
        latent noise is created on ``data.device``. After the call
        ``self.D_loss`` and ``self.G_E_loss`` hold the (detached) scalar
        losses of the two updates.
        """
        self.Generator.train()
        self.Encoder.train()
        self.Discriminator.train()

        bce = nn.functional.binary_cross_entropy

        #get fake z_data for generator (on the same device as the real batch)
        self.z_fake = torch.randn((self.batch_size, self.z_dim),
                                  device=data.device)

        #send fake z_data through generator to get fake x_data
        self.x_fake = self.Generator(self.z_fake)

        #send real data through encoder to get real z_data
        self.z_real = self.Encoder(data)

        # ---- discriminator update -------------------------------------
        # Inputs are detached so this step only produces gradients for the
        # discriminator.
        self.optimizer_D.zero_grad()
        self.out_real = self.Discriminator(data, self.z_real.detach())
        self.out_fake = self.Discriminator(self.x_fake.detach(),
                                           self.z_fake)
        d_loss = (bce(self.out_real, torch.ones_like(self.out_real)) +
                  bce(self.out_fake, torch.zeros_like(self.out_fake)))
        d_loss.backward()
        self.optimizer_D.step()

        # ---- generator/encoder update ---------------------------------
        # Fresh forward pass through the just-updated discriminator, this
        # time keeping the graph back to the generator and encoder, with the
        # labels swapped.
        self.optimizer_G_E.zero_grad()
        out_real = self.Discriminator(data, self.z_real)
        out_fake = self.Discriminator(self.x_fake, self.z_fake)
        g_e_loss = (bce(out_real, torch.zeros_like(out_real)) +
                    bce(out_fake, torch.ones_like(out_fake)))
        g_e_loss.backward()
        self.optimizer_G_E.step()

        # Keep only the values so no autograd graph outlives the step.
        self.D_loss = d_loss.detach()
        self.G_E_loss = g_e_loss.detach()
Exemple #8
0
class TadGAN(pl.LightningModule):
    def __init__(self,
                 in_size: int,
                 ts_size: int = 100,
                 latent_dim: int = 20,
                 lr: float = 0.0005,
                 weight_decay: float = 1e-6,
                 iterations_critic: int = 5,
                 gamma: float = 10,
                 weighted: bool = True,
                 use_gru=False):
        """Store the hyper-parameters and build the four sub-networks.

        An encoder, a generator and two critics (``critic_x``, ``critic_z``)
        are created and each has ``init_weights`` applied to it. Three empty
        lists are set up to collect validation outputs.
        """
        super(TadGAN, self).__init__()

        # Plain configuration values.
        self.in_size = in_size
        self.latent_dim = latent_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.iterations_critic = iterations_critic
        self.gamma = gamma
        self.weighted = weighted

        # Subset of the configuration handed to the logger below.
        self.hparams = dict(lr=self.lr,
                            weight_decay=self.weight_decay,
                            iterations_critic=self.iterations_critic,
                            gamma=self.gamma)

        # Sub-networks, registered in this fixed order.
        self.encoder = Encoder(in_size,
                               ts_size=ts_size,
                               out_size=self.latent_dim,
                               batch_first=True,
                               use_gru=use_gru)
        self.generator = Generator(use_gru=use_gru)
        self.critic_x = CriticX(in_size=in_size)
        self.critic_z = CriticZ()

        for network in (self.encoder, self.generator, self.critic_x,
                        self.critic_z):
            network.apply(init_weights)

        if self.logger is not None:
            self.logger.log_hyperparams(self.hparams)

        # Per-epoch buffers filled by validation_step.
        self.y_hat = []
        self.index = []
        self.critic = []

    def on_fit_start(self):
        """Log a figure of the datamodule's ``X`` windows when a logger exists."""
        if self.logger is None:
            return

        windows = self.trainer.datamodule.X.cpu().numpy()
        fig = plot_rws(windows)
        self.logger.experiment.add_figure('Rolling windows/GT',
                                          fig,
                                          global_step=self.global_step)

    def forward(self, x):
        """Return ``(generator(encoder(x)), critic_x(x))``."""
        latent = self.encoder(x)
        reconstruction = self.generator(latent)
        critic_score = self.critic_x(x)
        return reconstruction, critic_score

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute and log the loss for the optimizer selected by ``optimizer_idx``.

        * ``0``: encoder + generator. Skipped (returns ``None``) unless
          ``batch_idx + 1`` is a multiple of ``self.iterations_critic``.
        * ``1``: ``critic_x``, with gradient penalty.
        * ``2``: ``critic_z``, with gradient penalty.

        Any other index raises ``NotImplementedError``.
        """
        real_x = batch[0]
        n_samples = real_x.size(0)
        z = torch.randn(n_samples, self.latent_dim, device=self.device)
        # Targets passed to the Wasserstein loss: -1 for "valid", +1 for "fake".
        valid = -torch.ones(n_samples, 1, device=self.device)
        fake = torch.ones(n_samples, 1, device=self.device)

        if optimizer_idx == 0:
            if (batch_idx + 1) % self.iterations_critic != 0:
                return None

            latent = self.encoder(real_x)
            reconstruction = self.generator(latent)
            score_latent = self.critic_z(latent)
            score_generated = self.critic_x(self.generator(z))

            wx_loss = self._wasserstein_loss(valid, score_generated)
            wz_loss = self._wasserstein_loss(valid, score_latent)
            rec_loss = F.mse_loss(reconstruction, real_x)
            loss = wx_loss + wz_loss + self.gamma * rec_loss
            self.log_dict({
                'train/Encoder_Generator/loss': loss,
                'train/Encoder_Generator/Wasserstein_x_loss': wx_loss,
                'train/Encoder_Generator/Wasserstein_z_loss': wz_loss,
                'train/Encoder_Generator/Reconstruction_loss': rec_loss
            })
            return loss

        if optimizer_idx == 1:
            score_real = self.critic_x(real_x)
            # Detached so this step does not backpropagate into the generator.
            generated_x = self.generator(z).detach()
            score_generated = self.critic_x(generated_x)

            wv_loss = self._wasserstein_loss(valid, score_real)
            wf_loss = self._wasserstein_loss(fake, score_generated)
            gp_loss = self._calculate_gradient_penalty(
                self.critic_x, real_x, generated_x)
            loss = wv_loss + wf_loss + self.gamma * gp_loss
            self.log_dict({
                'train/Critic_x/loss': loss,
                'train/Critic_x/Wasserstein_valid_loss': wv_loss,
                'train/Critic_x/Wasserstein_fake_loss': wf_loss,
                'train/Critic_x/gradient_penalty': gp_loss
            })
            return loss

        if optimizer_idx == 2:
            score_prior = self.critic_z(z)
            # Detached so this step does not backpropagate into the encoder.
            encoded_z = self.encoder(real_x).detach()
            score_encoded = self.critic_z(encoded_z)

            wv_loss = self._wasserstein_loss(valid, score_prior)
            wf_loss = self._wasserstein_loss(fake, score_encoded)
            gp_loss = self._calculate_gradient_penalty(
                self.critic_z, z, encoded_z)
            loss = wv_loss + wf_loss + self.gamma * gp_loss
            self.log_dict({
                'train/Critic_z/loss': loss,
                'train/Critic_z/Wasserstein_valid_loss': wv_loss,
                'train/Critic_z/Wasserstein_fake_loss': wf_loss,
                'train/Critic_z/gradient_penalty': gp_loss
            })
            return loss

        raise NotImplementedError()

    def validation_step(self, batch, batch_idx):
        """Run the model on one validation batch and buffer its outputs.

        ``batch`` unpacks into ``(x, index)``. The reconstruction, the index
        and the critic score are appended to ``self.y_hat``, ``self.index``
        and ``self.critic`` respectively. Always returns ``None``.
        """
        windows, positions = batch
        reconstruction, critic_score = self(windows)

        buffers = ((self.y_hat, reconstruction),
                   (self.index, positions),
                   (self.critic, critic_score))
        for store, value in buffers:
            store.append(value)

    def validation_epoch_end(self, validation_step_outputs):
        if self.logger is None:
            return

        for net_name, net in zip(
            ['Encoder', 'Generator', 'Critic_X', 'Critic_Z'],
            [self.encoder, self.generator, self.critic_x, self.critic_z]):
            for m in net.modules():
                for name, param in m.named_parameters():
                    self.logger.experiment.add_histogram(
                        net_name + '/' + name, param.data)

        y_hat = torch.cat(self.y_hat)
        critic = torch.cat(self.critic)
        index = torch.cat(self.index)

        self.index = []
        self.y_hat = []
        self.critic = []

        n_batches = self.all_gather(y_hat.size(0))
        max_n_batches = n_batches.max()
        if y_hat.size(0) < max_n_batches:
            diff = max_n_batches - y_hat.size(0)
            add_cols = torch.full((diff, *y_hat.shape[1:]),
                                  fill_value=float('nan'),
                                  dtype=y_hat.dtype,
                                  device=y_hat.device)
            y_hat = torch.cat((y_hat, add_cols))
            add_cols = torch.full((diff, *critic.shape[1:]),
                                  fill_value=float('nan'),
                                  dtype=critic.dtype,
                                  device=critic.device)
            critic = torch.cat((critic, add_cols))
            add_cols = torch.full((diff, *index.shape[1:]),
                                  fill_value=-1,
                                  dtype=index.dtype,
                                  device=index.device)
            index = torch.cat((index, add_cols))

        y_hat, critic, index = self.all_gather((y_hat, critic, index))

        if len(y_hat.shape) == 4:
            y_hat = torch.flatten(y_hat, 0, 1)
            critic = torch.flatten(critic, 0, 1)
            index = torch.flatten(index, 0, 1)
        dm = self.trainer.datamodule

        y_shape = y_hat.shape[1:]
        critic_shape = critic.shape[1:]
        index_shape = index.shape[1:]
        mask = ~torch.any(torch.flatten(y_hat, 1, -1).isnan(), dim=1)
        y_hat = y_hat[mask]
        y_hat = y_hat.view(y_hat.size(0), *y_shape)
        critic = critic[mask]
        critic = critic.view(critic.size(0), *critic_shape)
        index = index[mask]

        idx = torch.argsort(index)
        y_hat = y_hat[idx]
        critic = critic[idx]

        assert y_hat.size(0) == critic.size(0)

        max_idx = min(y_hat.shape[0], dm.X.shape[0])

        y_hat = y_hat[:max_idx].cpu().numpy()
        critic = critic = critic[:max_idx].cpu().numpy()
        X = dm.X[:max_idx].cpu().numpy()
        X_index = dm.X_index[:max_idx].cpu().numpy()
        index = dm.index

        self.y_hat = []
        self.critic = []

        # flatten the predicted windows
        # plot the time series
        fig = plot_ts([dm.y, unroll_ts(y_hat)],
                      labels=['original', 'reconstructed'])
        self.logger.experiment.add_figure('TS reconstruction',
                                          fig,
                                          global_step=self.global_step)
        if y_hat.shape[0] == dm.X.shape[0]:
            fig = plot_rws(y_hat)
            self.logger.experiment.add_figure('Rolling windows/Reconstructed',
                                              fig,
                                              global_step=self.global_step)

        errors, true_index, true, predictions = score_anomalies(
            X, y_hat, critic, X_index, rec_error_type='dtw', comb='mult')
        anomalies = find_anomalies(errors,
                                   index,
                                   window_size_portion=0.33,
                                   window_step_size_portion=0.1,
                                   fixed_threshold=True)
        if anomalies.size == 0:
            anomalies = pd.DataFrame(columns=['start', 'end', 'score'])
        else:
            anomalies = pd.DataFrame(anomalies,
                                     columns=['start', 'end', 'score'])

        gt_anomalies = dm.anomalies
        if gt_anomalies is not None:
            fig = plot(dm.df, [('anomalies', anomalies),
                               ('gt_anomalies', gt_anomalies)])
        else:
            fig = plot(dm.df, [('anomalies', anomalies)])
        self.logger.experiment.add_figure('AD output',
                                          fig,
                                          global_step=self.global_step)

        metric_logged = False
        if not anomalies.empty:
            fig = plot_table_anomalies(anomalies)
            self.logger.experiment.add_figure('Anomaly table',
                                              fig,
                                              global_step=self.global_step)

            if gt_anomalies is not None:
                # Workaround to dispay PR Curve
                if self.weighted:
                    labels, preds, weights = contextual_prepare_weighted(
                        gt_anomalies, anomalies, data=dm.df)
                    self.logger.experiment.add_pr_curve(
                        'PR Curve',
                        np.array(labels),
                        np.array(preds),
                        weights=np.array(weights),
                        global_step=self.global_step)

                acc = contextual_accuracy(gt_anomalies,
                                          anomalies,
                                          data=dm.df,
                                          weighted=self.weighted)
                prec = contextual_precision(gt_anomalies,
                                            anomalies,
                                            data=dm.df,
                                            weighted=self.weighted)
                recall = contextual_recall(gt_anomalies,
                                           anomalies,
                                           data=dm.df,
                                           weighted=self.weighted)
                f1 = contextual_f1_score(gt_anomalies,
                                         anomalies,
                                         data=dm.df,
                                         weighted=self.weighted,
                                         beta=2)
                vals = {
                    'Accuracy': acc,
                    'Precision': prec,
                    'Recall': recall,
                    'F1': f1
                }
                self.log_dict(vals)
                metric_logged = True
        if not metric_logged:
            vals = {'Accuracy': 0, 'Precision': 0, 'Recall': 0, 'F1': 0}
            self.log_dict(vals)

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Clear the gradients held by ``optimizer``.

        ``epoch``, ``batch_idx`` and ``optimizer_idx`` are accepted but not
        used by this implementation.
        NOTE(review): the signature looks like a PyTorch Lightning hook
        override -- confirm against the base class.
        """
        # set_to_none=True asks the optimizer to drop the gradient tensors
        # (set them to None) instead of filling them with zeros.
        optimizer.zero_grad(set_to_none=True)

    def configure_optimizers(self):
        """Create the Adam optimizers for the model's sub-networks.

        Returns:
            A list of three Adam optimizers, in this order: one over the
            chained encoder and generator parameters, one over ``critic_x``
            and one over ``critic_z``. All of them use ``self.lr`` and
            ``self.weight_decay``.
        """
        encoder_generator_params = itertools.chain(
            self.encoder.parameters(), self.generator.parameters())
        parameter_sets = (
            encoder_generator_params,
            self.critic_x.parameters(),
            self.critic_z.parameters(),
        )
        return [
            Adam(parameters, lr=self.lr, weight_decay=self.weight_decay)
            for parameters in parameter_sets
        ]

    @staticmethod
    def _wasserstein_loss(y_true: torch.Tensor, y_pred: torch.Tensor):
        """Return the mean of the element-wise product of both tensors."""
        weighted_scores = y_true * y_pred
        return weighted_scores.mean()

    def _calculate_gradient_penalty(self, model: torch.nn.Module,
                                    y_true: torch.Tensor,
                                    y_pred: torch.Tensor):
        """Calculate the gradient penalty loss for WGAN-GP.

        Each real sample is mixed with its fake counterpart using one random
        coefficient per sample, ``model`` is evaluated on the mixed points,
        and the penalty is the mean of ``(||grad||_2 - 1) ** 2`` over the
        batch, where ``grad`` is the gradient of the model output with
        respect to the mixed input.

        Args:
            model: Callable (critic) applied to the mixed samples.
            y_true: Batch of real samples; dimension 0 is the batch.
            y_pred: Batch of fake samples, broadcastable with ``y_true``.

        Returns:
            Scalar tensor holding the penalty. It is built with
            ``create_graph=True`` so it can be back-propagated.
        """
        # One mixing coefficient per sample, drawn uniformly from [0, 1) so
        # every mixed point lies on the segment between the real and the
        # fake sample. A normal draw would fall outside that segment.
        # Shaped (batch, 1, ..., 1) to broadcast over any trailing dims.
        alpha_shape = (y_true.size(0), ) + (1, ) * (y_true.dim() - 1)
        alpha = torch.rand(alpha_shape, device=self.device)
        # Get random interpolation between real and fake data
        interpolates = (alpha * y_true +
                        (1 - alpha) * y_pred).requires_grad_(True)

        model_interpolates = model(interpolates)
        # Weight every output element by one, i.e. differentiate the sum.
        grad_outputs = torch.ones(model_interpolates.size(),
                                  device=self.device,
                                  requires_grad=False)

        # Get gradient w.r.t. interpolates
        gradients = torch.autograd.grad(
            outputs=model_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]
        # reshape (not view) so non-contiguous gradients are handled too.
        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_penalty = torch.mean((gradients.norm(2, dim=1) - 1)**2)
        return gradient_penalty
Exemple #9
0
def _train_log_losses(record_fh, loss_names, losses, epoch, n_epochs, i):
    """Append the loss values to the record file and print them as tables."""
    record_fh.write(' '.join('{}'.format(loss) for loss in losses) + '\n')
    print('Epoch: %3d/%3dIter: %9d--------------------------+' %
          (epoch, n_epochs, i))
    half = len(loss_names) // 2
    head, tail = loss_names[:half], loss_names[half:]
    # Two tables: the first half of the names with the leading values, the
    # second half with the trailing values.
    for field_names, row in ((head, losses[:len(head)]),
                             (tail, losses[-len(tail):])):
        table = PrettyTable(field_names=field_names)
        for l_n in field_names:
            table.align[l_n] = 'm'
        table.add_row(row)
        print(table.get_string(reversesort=True))


def _train_save_preview(panels, out_path):
    """Draw the first image of every batch on a 2x4 grid and save it."""
    for position, (title, batch) in enumerate(panels, start=241):
        plt.subplot(position)
        plt.title(title)
        # Detach and move to CPU unconditionally so tensors that still carry
        # autograd history can be converted regardless of the device.
        plt.imshow(tensor2image_RGB(batch[0, ...].detach().cpu()))
    plt.savefig(out_path)
    # Clear the figure; otherwise every preview stacks new images on the
    # axes of the previous one.
    plt.clf()


def _train_save_checkpoints(networks, checkpoints_pth, epoch, device, on_gpu):
    """Save the state dict of every network, from CPU copies of the weights."""
    if on_gpu:
        for net in networks.values():
            net.cpu()
    try:
        for tag, net in networks.items():
            # NOTE(review): '%3d' pads the epoch with spaces, so the file
            # names contain blanks; kept unchanged because loading code may
            # rely on these exact names.
            torch.save(
                net.state_dict(),
                os.path.join(checkpoints_pth,
                             'epoch_%3d-%s.pth' % (epoch, tag)))
    finally:
        # Always put the networks back on the training device.
        if on_gpu:
            for net in networks.values():
                net.to(device)


def train(opt):
    """Run the full training loop configured by ``opt``.

    Builds the unaligned data loader, two encoder / decoder / latent
    translator / discriminator sets (one per direction), their optimizers
    and schedulers, and then alternates generator and discriminator updates
    for every batch. Loss values are written to ``records.txt`` and printed
    every ``opt.log_freq`` iterations, preview images are saved every
    ``opt.vis_freq`` iterations, and state dicts are saved every
    ``opt.ckp_freq`` epochs. All outputs go under
    ``os.path.join(opt.checkpoints, opt.name)``.

    Args:
        opt: Options object. ``opt.gpu_id >= 0`` selects ``cuda:<gpu_id>``,
            any other value selects the CPU.
    """
    #### device
    on_gpu = opt.gpu_id >= 0
    device = torch.device('cuda:{}'.format(opt.gpu_id) if on_gpu else 'cpu')

    #### dataset
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()
    print("The number of training images = %d." % len(data_set))

    #### initialize models
    ## declaration
    E_a2Zb = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Zb2b = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Zb2Za = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_b = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    E_b2Za = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Za2a = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Za2Zb = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_a = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    ## initialization
    E_a2Zb = init_net(E_a2Zb, init_type=opt.init_type).to(device)
    G_Zb2b = init_net(G_Zb2b, init_type=opt.init_type).to(device)
    T_Zb2Za = init_net(T_Zb2Za, init_type=opt.init_type).to(device)
    D_b = init_net(D_b, init_type=opt.init_type).to(device)

    E_b2Za = init_net(E_b2Za, init_type=opt.init_type).to(device)
    G_Za2a = init_net(G_Za2a, init_type=opt.init_type).to(device)
    T_Za2Zb = init_net(T_Za2Zb, init_type=opt.init_type).to(device)
    D_a = init_net(D_a, init_type=opt.init_type).to(device)

    # Checkpoint tag -> network, in the order the files are written.
    checkpoint_nets = {
        'E_a2b': E_a2Zb,
        'G_b': G_Zb2b,
        'T_b2a': T_Zb2Za,
        'D_b': D_b,
        'E_b2a': E_b2Za,
        'G_a': G_Za2a,
        'T_a2b': T_Za2Zb,
        'D_a': D_a,
    }
    print(
        "+------------------------------------------------------+\nFinish initializing networks."
    )

    #### optimizer and criterion
    ## criterion
    criterionGAN = GANLoss(opt.gan_mode).to(device)
    criterionZId = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    criterionCTC = nn.L1Loss()
    criterionZCyc = nn.L1Loss()

    ## optimizer
    optimizer_G = torch.optim.Adam(itertools.chain(E_a2Zb.parameters(),
                                                   G_Zb2b.parameters(),
                                                   T_Zb2Za.parameters(),
                                                   E_b2Za.parameters(),
                                                   G_Za2a.parameters(),
                                                   T_Za2Zb.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(itertools.chain(D_a.parameters(),
                                                   D_b.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))

    ## scheduler
    scheduler = [
        get_scheduler(optimizer_G, opt),
        get_scheduler(optimizer_D, opt)
    ]

    print(
        "+------------------------------------------------------+\nFinish initializing the optimizers and criterions."
    )

    #### global variables
    checkpoints_pth = os.path.join(opt.checkpoints, opt.name)
    # Creates the checkpoint directory (and missing parents) together with
    # its 'images' sub-directory; a no-op when they already exist.
    os.makedirs(os.path.join(checkpoints_pth, 'images'), exist_ok=True)
    loss_names = [
        'GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B',
        'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B'
    ]

    # image buffers that store previously generated images
    fake_A_pool = ImagePool(opt.pool_size)
    fake_B_pool = ImagePool(opt.pool_size)

    print(
        "+------------------------------------------------------+\nFinish preparing the other works."
    )
    print(
        "+------------------------------------------------------+\nNow training is beginning .."
    )
    #### training
    cur_iter = 0
    # The record file is closed even when training raises.
    with open(os.path.join(checkpoints_pth, 'records.txt'),
              'w',
              encoding='utf-8') as record_fh:
        for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
            epoch_start_time = time.time()  # timer for entire epoch

            for i, data in enumerate(data_set):
                ## setup inputs
                real_A = data['A'].to(device)
                real_B = data['B'].to(device)

                ## forward
                # image cycle / GAN
                latent_B = E_a2Zb(real_A)  #-> a -> Zb     : E_a2b(a)
                fake_B = G_Zb2b(latent_B)  #-> Zb -> b'    : G_b(E_a2b(a))
                latent_A = E_b2Za(real_B)  #-> b -> Za     : E_b2a(b)
                fake_A = G_Za2a(latent_A)  #-> Za -> a'    : G_a(E_b2a(b))

                # Idt
                idt_latent_A = E_b2Za(real_A)  #-> a -> Za      : E_b2a(a)
                idt_A = G_Za2a(idt_latent_A)  #-> Za -> idt_a  : G_a(E_b2a(a))
                idt_latent_B = E_a2Zb(real_B)  #-> b -> Zb      : E_a2b(b)
                idt_B = G_Zb2b(idt_latent_B)  #-> Zb -> idt_b  : G_b(E_a2b(b))

                # ZIdt
                T_latent_A = T_Zb2Za(latent_B)  #-> Zb -> Za'' : T_b2a(E_a2b(a))
                T_rec_A = G_Za2a(T_latent_A)  #-> G_a(T_b2a(E_a2b(a)))
                T_latent_B = T_Za2Zb(latent_A)  #-> Za -> Zb'' : T_a2b(E_b2a(b))
                T_rec_B = G_Zb2b(T_latent_B)  #-> G_b(T_a2b(E_b2a(b)))

                # CTC
                T_idt_latent_B = T_Za2Zb(idt_latent_A)  #-> T_a2b(E_b2a(a))
                T_idt_latent_A = T_Zb2Za(idt_latent_B)  #-> T_b2a(E_a2b(b))

                # ZCyc
                TT_latent_B = T_Za2Zb(T_latent_A)  #-> T_a2b(T_b2a(E_a2b(a)))
                TT_latent_A = T_Zb2Za(T_latent_B)  #-> T_b2a(T_a2b(E_b2a(b)))

                ### optimize parameters
                ## Generator updating
                # discriminators need no gradient during the generator step
                set_requires_grad([D_b, D_a], False)
                optimizer_G.zero_grad()
                # GAN loss
                loss_G_A = criterionGAN(D_b(fake_B), True)
                loss_G_B = criterionGAN(D_a(fake_A), True)
                loss_GAN = loss_G_A + loss_G_B
                # Idt loss
                loss_idt_A = criterionIdt(idt_A, real_A)
                loss_idt_B = criterionIdt(idt_B, real_B)
                loss_Idt = loss_idt_A + loss_idt_B
                # Latent cross-identity loss
                loss_Zid_A = criterionZId(T_rec_A, real_A)
                loss_Zid_B = criterionZId(T_rec_B, real_B)
                loss_Zid = loss_Zid_A + loss_Zid_B
                # Latent cross-translation consistency
                loss_CTC_A = criterionCTC(T_idt_latent_A, latent_A)
                loss_CTC_B = criterionCTC(T_idt_latent_B, latent_B)
                loss_CTC = loss_CTC_B + loss_CTC_A
                # Latent cycle consistency
                loss_ZCyc_A = criterionZCyc(TT_latent_A, latent_A)
                loss_ZCyc_B = criterionZCyc(TT_latent_B, latent_B)
                loss_ZCyc = loss_ZCyc_B + loss_ZCyc_A

                loss_G = (opt.lambda_gan * loss_GAN +
                          opt.lambda_idt * loss_Idt +
                          opt.lambda_zid * loss_Zid +
                          opt.lambda_ctc * loss_CTC +
                          opt.lambda_zcyc * loss_ZCyc)

                # backward and gradient updating
                loss_G.backward()
                optimizer_G.step()

                ## Discriminator updating
                set_requires_grad([D_b, D_a], True)
                optimizer_D.zero_grad()

                # backward D_b
                # The generated images are detached: the generator graph was
                # already consumed by loss_G.backward() and the discriminator
                # step must not reach back into it.
                fake_B_ = fake_B_pool.query(fake_B)
                pred_real_B = D_b(real_B)
                loss_D_real_B = criterionGAN(pred_real_B, True)

                pred_fake_B = D_b(fake_B_.detach())
                loss_D_fake_B = criterionGAN(pred_fake_B, False)

                loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
                loss_D_B.backward()

                # backward D_a
                fake_A_ = fake_A_pool.query(fake_A)
                pred_real_A = D_a(real_A)
                loss_D_real_A = criterionGAN(pred_real_A, True)

                pred_fake_A = D_a(fake_A_.detach())
                loss_D_fake_A = criterionGAN(pred_fake_A, False)

                loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
                loss_D_A.backward()

                # update the gradients
                optimizer_D.step()

                ### validate here, both qualitively and quantitatively
                ## record the losses
                if cur_iter % opt.log_freq == 0:
                    # same order as loss_names
                    losses = [
                        loss_G_A.item(),
                        loss_D_A.item(),
                        loss_idt_A.item(),
                        loss_CTC_A.item(),
                        loss_Zid_A.item(),
                        loss_ZCyc_A.item(),
                        loss_G_B.item(),
                        loss_D_B.item(),
                        loss_idt_B.item(),
                        loss_CTC_B.item(),
                        loss_Zid_B.item(),
                        loss_ZCyc_B.item()
                    ]
                    # NOTE(review): `opt.epoch` is not read anywhere else in
                    # this function -- confirm the option exists.
                    _train_log_losses(record_fh, loss_names, losses, epoch,
                                      opt.epoch, i)

                ## visualize
                if cur_iter % opt.vis_freq == 0:
                    panels = [
                        ('real_A', real_A),
                        ('fake_B', fake_B),
                        ('idt_A', idt_A),
                        ('L_idt_A', T_rec_A),
                        ('real_B', real_B),
                        ('fake_A', fake_A),
                        ('idt_B', idt_B),
                        ('L_idt_B', T_rec_B),
                    ]
                    _train_save_preview(
                        panels,
                        os.path.join(checkpoints_pth, 'images',
                                     '%03d_%09d.jpg' % (epoch, i)))

                cur_iter += 1

            ## till now, we finish one epoch, try to update the learning rate
            update_learning_rate(schedulers=scheduler,
                                 opt=opt,
                                 optimizer=optimizer_D)
            ## save the model
            if epoch % opt.ckp_freq == 0:
                _train_save_checkpoints(checkpoint_nets, checkpoints_pth,
                                        epoch, device, on_gpu)
                print(
                    "+Successfully saving models in epoch: %3d.-------------+"
                    % epoch)
    print("≧◔◡◔≦ Congratulation! Finishing the training!")
Exemple #10
0
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on MNIST and checkpoint periodically.

    The MNIST training set is split into a training part and 10000
    validation images. For every epoch in
    ``range(FLAGS.start_epoch, FLAGS.end_epoch)`` the negative ELBO returned
    by ``process`` is minimised with Adam and the epoch averages are
    printed. Every 5 epochs, and after the last one, the validation metrics
    are stored in ``monitor`` and the encoder, decoder and ``monitor`` are
    saved under ``checkpoints_<batch_size>``.

    Args:
        FLAGS: Options object. Attributes read here: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``,
            ``decoder_save``, ``cuda``, ``initial_learning_rate``,
            ``beta_1``, ``beta_2``, ``batch_size``, ``log_file``,
            ``start_epoch`` and ``end_epoch``.
    """
    # Defined first: both loading and saving of checkpoints use it.
    savedir = 'checkpoints_%d' % (FLAGS.batch_size)

    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        # The models are still on the CPU here, so map the stored tensors to
        # the CPU; they are moved to the GPU below when requested.
        encoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.encoder_save),
                       map_location='cpu'))
        decoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.decoder_save),
                       map_location='cpu'))
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs(savedir, exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    # Creating data indices for training and validation splits:
    dataset_size = len(mnist)
    split = 10000
    # The shuffled list itself is not used (the split comes from
    # random_split below); seeding and shuffling are kept so the global
    # NumPy random state is the same as before for any later consumer.
    indices = list(range(dataset_size))
    np.random.seed(0)
    np.random.shuffle(indices)
    train_mnist, val_mnist = torch.utils.data.random_split(
        mnist, [dataset_size - split, split])

    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_mnist.indices)
    valid_sampler = SubsetRandomSampler(val_mnist.indices)
    kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {}
    loader = DataLoader(mnist,
                        batch_size=FLAGS.batch_size,
                        sampler=train_sampler,
                        **kwargs)
    valid_loader = DataLoader(mnist,
                              batch_size=FLAGS.batch_size,
                              sampler=valid_sampler,
                              **kwargs)
    # One row per trained epoch, addressed relative to start_epoch.
    monitor = torch.zeros(FLAGS.end_epoch - FLAGS.start_epoch, 4)
    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )
        elbo_epoch = 0.0
        term1_epoch = 0.0
        term2_epoch = 0.0
        term3_epoch = 0.0
        n_batches = 0
        for image_batch, labels_batch in loader:
            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            # Only move the batch to the GPU when CUDA was requested.
            X = image_batch.cuda() if FLAGS.cuda else image_batch
            X = X.detach().clone()
            elbo, reconstruction_proba, style_kl_divergence_loss, class_kl_divergence_loss = process(
                FLAGS, X, labels_batch, encoder, decoder)
            (-elbo).backward()
            auto_encoder_optimizer.step()
            # Accumulate plain floats so the running sums do not keep
            # autograd graphs alive for the whole epoch.
            elbo_epoch += float(elbo)
            term1_epoch += float(reconstruction_proba)
            term2_epoch += float(style_kl_divergence_loss)
            term3_epoch += float(class_kl_divergence_loss)
            n_batches += 1

        # Guard against an empty loader.
        denom = max(n_batches, 1)
        print("Elbo epoch %.2f" % (elbo_epoch / denom))
        print("Rec. Proba %.2f" % (term1_epoch / denom))
        print("KL style %.2f" % (term2_epoch / denom))
        print("KL content %.2f" % (term3_epoch / denom))
        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            row = epoch - FLAGS.start_epoch
            # NOTE(review): `eval` is presumably a project-level evaluation
            # helper shadowing the builtin -- confirm.
            monitor[row, :] = eval(FLAGS, valid_loader, encoder, decoder)
            torch.save(
                encoder.state_dict(),
                os.path.join(savedir, FLAGS.encoder_save + '_e%d' % epoch))
            torch.save(
                decoder.state_dict(),
                os.path.join(savedir, FLAGS.decoder_save + '_e%d' % epoch))
            print("VAL elbo %.2f" % (monitor[row, 0]))
            print("VAL Rec. Proba %.2f" % (monitor[row, 1]))
            print("VAL KL style %.2f" % (monitor[row, 2]))
            print("VAL KL content %.2f" % (monitor[row, 3]))

            torch.save(monitor, os.path.join(savedir, 'monitor_e%d' % epoch))

    writer.close()
Exemple #11
0
    G.apply(init_weights)
    # summary(G, [(256,), (10, 256, 256)], device=device)
    D = Discriminator(n_classes).to(device)
    D.apply(init_weights)
    # summary(D, (13, 256, 256), device=device)
    vgg = VGG().to(device)

    if args.multi_gpu:
        E = nn.DataParallel(E)
        G = nn.DataParallel(G)
        # G = convert_model(G)
        D = nn.DataParallel(D)
        VGG = nn.DataParallel(VGG)

    # Optimizers
    G_opt = optim.Adam(itertools.chain(G.parameters(), E.parameters()),
                       lr=args.lr_G,
                       betas=(args.beta1, args.beta2))
    D_opt = optim.Adam(D.parameters(),
                       lr=args.lr_D,
                       betas=(args.beta1, args.beta2))

    # Load weights from a specific epoch
    start_ep = 0
    if args.load_epoch is not None:
        if args.load_from_experiment is None:
            load_checkpoint_path = checkpoint_path
        else:
            load_checkpoint_path = join('results', args.load_from_experiment,
                                        'checkpoint')
        load_ep = args.load_epoch
def training_procedure(FLAGS):
    """
    Train the encoder, decoder and pair discriminator on the paired MNIST set.

    Every outer iteration runs three phases:

    A. auto-encoder step: encode ``X_1``, sample its style with
       ``reparameterize`` and decode it twice (once with the class code of
       ``X_1``, once with the class code of its paired image ``X_2``); both
       decodings are regressed onto ``X_1`` with ``mse_loss`` and a KL term
       is applied to the style posterior.
    B. generator step (``FLAGS.generator_times`` times): the encoder/decoder
       are updated so that the discriminator labels (X_3, generated image)
       pairs as real, both for an encoded style and for a style drawn from
       N(0, 1).
    C. discriminator step (``FLAGS.discriminator_times`` times): the
       discriminator is shown a real pair and a generated pair; its
       optimizer only steps while its accuracy on that batch is below
       ``FLAGS.discriminator_limiting_accuracy``.

    Losses are printed every 50 iterations and written to ``FLAGS.log_file``
    and TensorBoard every iteration.  Weights are saved into ``checkpoints/``
    every 5 epochs and after the final epoch.

    If phase B or C runs zero times, the corresponding logged loss stays at
    its initial value of 0.

    Args:
        FLAGS: option namespace.  Attributes read: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``,
            ``decoder_save``, ``discriminator_save``, ``batch_size``,
            ``num_channels``, ``image_size``, ``cuda``,
            ``initial_learning_rate``, ``beta_1``, ``beta_2``, ``log_file``,
            ``start_epoch``, ``end_epoch``, ``generator_times``,
            ``discriminator_times``, ``discriminator_limiting_accuracy``.
    """
    # ------------------------------------------------------------------
    # model definition
    # ------------------------------------------------------------------
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        saved_networks = (
            (encoder, FLAGS.encoder_save),
            (decoder, FLAGS.decoder_save),
            (discriminator, FLAGS.discriminator_save),
        )
        for network, file_name in saved_networks:
            # The networks still live on the CPU here, so map the stored
            # tensors there; this also allows a checkpoint written on a GPU
            # machine to be resumed on a CPU-only one.
            network.load_state_dict(
                torch.load(os.path.join('checkpoints', file_name),
                           map_location='cpu'))

    # ------------------------------------------------------------------
    # variable definition
    # ------------------------------------------------------------------
    device = torch.device('cuda' if FLAGS.cuda else 'cpu')

    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated batch buffers; every step copies the loader output into
    # them (the loader uses drop_last=True, so the batch size is constant).
    image_shape = (FLAGS.batch_size, FLAGS.num_channels,
                   FLAGS.image_size, FLAGS.image_size)
    X_1 = torch.empty(image_shape, dtype=torch.float32, device=device)
    X_2 = torch.empty(image_shape, dtype=torch.float32, device=device)
    X_3 = torch.empty(image_shape, dtype=torch.float32, device=device)

    domain_labels = torch.empty(FLAGS.batch_size, dtype=torch.long, device=device)
    style_latent_space = torch.empty(FLAGS.batch_size, FLAGS.style_dim,
                                     dtype=torch.float32, device=device)

    # Ground truth for the accuracy estimate: the real half of the stacked
    # discriminator outputs comes first, the fake half second.
    target_true_labels = torch.cat(
        (torch.ones(FLAGS.batch_size, dtype=torch.long, device=device),
         torch.zeros(FLAGS.batch_size, dtype=torch.long, device=device)),
        dim=0)

    # ------------------------------------------------------------------
    # loss definitions / device placement
    # ------------------------------------------------------------------
    cross_entropy_loss = nn.CrossEntropyLoss()

    encoder.to(device)
    decoder.to(device)
    discriminator.to(device)
    cross_entropy_loss.to(device)

    # ------------------------------------------------------------------
    # optimizer definition
    # ------------------------------------------------------------------
    # The auto-encoder and generator optimizers cover the same parameters
    # but keep separate Adam statistics for the two objectives.
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # ------------------------------------------------------------------
    # training
    # ------------------------------------------------------------------
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True,
                              num_workers=0, drop_last=True))

    iterations_per_epoch = len(paired_mnist) // FLAGS.batch_size
    kl_normaliser = (FLAGS.batch_size * FLAGS.num_channels
                     * FLAGS.image_size * FLAGS.image_size)

    # Values reported in the logs.  They are initialised up front so that
    # logging still works when the generator or discriminator loop is
    # configured to run zero times.
    discriminator_accuracy = 0.
    generator_loss = 0.
    discriminator_loss = 0.

    # initialize summary writer
    writer = SummaryWriter()

    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print('Epoch #' + str(epoch) +
                  '..........................................................................')

            for iteration in range(iterations_per_epoch):
                # A. run the auto-encoder reconstruction
                image_batch_1, image_batch_2, _ = next(loader)

                auto_encoder_optimizer.zero_grad()

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                style_mu_1, style_logvar_1, class_1 = encoder(X_1)
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = -0.5 * torch.sum(
                    1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
                kl_divergence_loss_1 /= kl_normaliser
                # the encoder graph is reused by the reconstruction losses below
                kl_divergence_loss_1.backward(retain_graph=True)

                _, _, class_2 = encoder(X_2)

                reconstructed_X_1 = decoder(style_1, class_1)
                reconstructed_X_2 = decoder(style_1, class_2)

                reconstruction_error_1 = mse_loss(reconstructed_X_1, X_1)
                reconstruction_error_1.backward(retain_graph=True)

                # Style of X_1 with the class code of X_2 is also regressed
                # onto X_1 (presumably the pair shares a class -- verify
                # against MNIST_Paired).
                reconstruction_error_2 = mse_loss(reconstructed_X_2, X_1)
                reconstruction_error_2.backward()

                # Detached copies: these are only reported, never optimised.
                reconstruction_error = (reconstruction_error_1.detach()
                                        + reconstruction_error_2.detach())
                kl_divergence_error = kl_divergence_loss_1.detach()

                auto_encoder_optimizer.step()

                # B. run the generator
                for _ in range(FLAGS.generator_times):

                    generator_optimizer.zero_grad()

                    image_batch_1, _, _ = next(loader)
                    image_batch_3, _, _ = next(loader)

                    # the generator wants its pairs to be classified as real
                    domain_labels.fill_(real_domain_labels)
                    X_1.copy_(image_batch_1)
                    X_3.copy_(image_batch_3)

                    style_mu_1, style_logvar_1, _ = encoder(X_1)
                    style_1 = reparameterize(training=True, mu=style_mu_1,
                                             logvar=style_logvar_1)

                    kl_divergence_loss_1 = -0.5 * torch.sum(
                        1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
                    kl_divergence_loss_1 /= kl_normaliser
                    kl_divergence_loss_1.backward(retain_graph=True)

                    _, _, class_3 = encoder(X_3)
                    reconstructed_X_1_3 = decoder(style_1, class_3)

                    output_1 = discriminator(X_3, reconstructed_X_1_3)

                    generator_error_1 = cross_entropy_loss(output_1, domain_labels)
                    # class_3's graph is needed again for the prior sample below
                    generator_error_1.backward(retain_graph=True)

                    # same objective with a style drawn from the N(0, 1) prior
                    style_latent_space.normal_(0., 1.)
                    reconstructed_X_latent_3 = decoder(style_latent_space, class_3)

                    output_2 = discriminator(X_3, reconstructed_X_latent_3)

                    generator_error_2 = cross_entropy_loss(output_2, domain_labels)
                    generator_error_2.backward()

                    generator_loss = (generator_error_1.detach()
                                      + generator_error_2.detach()).item()
                    # out-of-place on detached values: no autograd history is
                    # kept alive just for the running KL total
                    kl_divergence_error = kl_divergence_error + kl_divergence_loss_1.detach()

                    generator_optimizer.step()

                # C. run the discriminator
                for _ in range(FLAGS.discriminator_times):

                    discriminator_optimizer.zero_grad()

                    # train discriminator on real data
                    domain_labels.fill_(real_domain_labels)

                    image_batch_1, _, _ = next(loader)
                    image_batch_2, image_batch_3, _ = next(loader)

                    X_1.copy_(image_batch_1)
                    X_2.copy_(image_batch_2)
                    X_3.copy_(image_batch_3)

                    real_output = discriminator(X_2, X_3)

                    discriminator_real_error = cross_entropy_loss(real_output, domain_labels)
                    discriminator_real_error.backward()

                    # train discriminator on fake data
                    domain_labels.fill_(fake_domain_labels)

                    style_mu_1, style_logvar_1, _ = encoder(X_1)
                    style_1 = reparameterize(training=False, mu=style_mu_1,
                                             logvar=style_logvar_1)

                    _, _, class_3 = encoder(X_3)
                    reconstructed_X_1_3 = decoder(style_1, class_3)

                    fake_output = discriminator(X_3, reconstructed_X_1_3)

                    discriminator_fake_error = cross_entropy_loss(fake_output, domain_labels)
                    discriminator_fake_error.backward()

                    # total discriminator error
                    discriminator_loss = (discriminator_real_error.detach()
                                          + discriminator_fake_error.detach()).item()

                    # calculate discriminator accuracy for this step
                    discriminator_predictions = torch.cat((real_output, fake_output), dim=0)
                    _, discriminator_predictions = torch.max(discriminator_predictions, 1)

                    discriminator_accuracy = (
                        discriminator_predictions == target_true_labels
                    ).sum().item() / (FLAGS.batch_size * 2)

                    # hold the discriminator back once it is accurate enough
                    if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                        discriminator_optimizer.step()

                reconstruction_loss = reconstruction_error.item()
                kl_divergence_loss = kl_divergence_error.item()

                if (iteration + 1) % 50 == 0:
                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))

                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_loss))
                    print('KL-Divergence loss: ' + str(kl_divergence_loss))

                    print('')
                    print('Generator loss: ' + str(generator_loss))
                    print('Discriminator loss: ' + str(discriminator_loss))
                    print('Discriminator accuracy: ' + str(discriminator_accuracy))

                    print('..........')

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                        epoch,
                        iteration,
                        reconstruction_loss,
                        kl_divergence_loss,
                        generator_loss,
                        discriminator_loss,
                        discriminator_accuracy
                    ))

                # write to tensorboard
                global_step = epoch * (iterations_per_epoch + 1) + iteration
                writer.add_scalar('Reconstruction loss', reconstruction_loss, global_step)
                writer.add_scalar('KL-Divergence loss', kl_divergence_loss, global_step)
                writer.add_scalar('Generator loss', generator_loss, global_step)
                writer.add_scalar('Discriminator loss', discriminator_loss, global_step)
                writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100,
                                  global_step)

            # save model after every 5 epochs
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.decoder_save))
                torch.save(discriminator.state_dict(),
                           os.path.join('checkpoints', FLAGS.discriminator_save))
    finally:
        # flush pending TensorBoard events even if training is interrupted
        writer.close()
def _tile_channels_to_rgb(images):
    """Return an NCHW tensor as an NHWC numpy array with its channel axis repeated 3 times."""
    channels_last = np.transpose(images.detach().cpu().numpy(), (0, 2, 3, 1))
    return np.concatenate((channels_last, channels_last, channels_last), axis=3)


def training_procedure(FLAGS):
    """
    Train the (nv, nc) encoder/decoder and the pair discriminator on paired MNIST.

    The encoder returns two codes per image, ``nv`` and ``nc`` (called the
    "varying" and "common" factors in the inline comments).  Every outer
    iteration runs:

    A. auto-encoder step: two paired images are encoded, their ``nc`` codes
       are swapped before decoding and each decoding is regressed onto its
       own input with ``mse_loss``.  The optimizer only steps when
       ``FLAGS.train_auto_encoder`` is set.
    B. a) discriminator step (``FLAGS.discriminator_times`` times): real
          pair (X_1, X_2) versus generated pair (X_1, decode(nv_3, nc_1)).
          The optimizer steps only while batch accuracy is below
          ``FLAGS.discriminator_limiting_accuracy`` and
          ``FLAGS.train_discriminator`` is set.
       b) generator step (``FLAGS.generator_times`` times): the
          encoder/decoder are pushed to make the generated pair be
          classified as real; steps only when ``FLAGS.train_generator``.

    Losses are printed every 10 iterations and written to ``FLAGS.log_file``
    and TensorBoard every iteration.  Every 5 epochs (and after the last
    one) the weights are saved into ``checkpoints/`` and four preview grids
    are rendered through ``imshow_grid``.

    If loop B.a or B.b runs zero times, the corresponding logged loss stays
    at its initial value of 0.

    Args:
        FLAGS: option namespace.  Attributes read: ``nv_dim``, ``nc_dim``,
            ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``discriminator_save``, ``batch_size``, ``num_channels``,
            ``image_size``, ``cuda``, ``initial_learning_rate``, ``beta_1``,
            ``beta_2``, ``log_file``, ``start_epoch``, ``end_epoch``,
            ``train_auto_encoder``, ``discriminator_times``, ``disc_coef``,
            ``discriminator_limiting_accuracy``, ``train_discriminator``,
            ``generator_times``, ``gen_coef``, ``train_generator``.
    """
    # ------------------------------------------------------------------
    # model definition
    # ------------------------------------------------------------------
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)

    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        saved_networks = (
            (encoder, FLAGS.encoder_save),
            (decoder, FLAGS.decoder_save),
            (discriminator, FLAGS.discriminator_save),
        )
        for network, file_name in saved_networks:
            # The networks still live on the CPU here, so map the stored
            # tensors there; this also allows a checkpoint written on a GPU
            # machine to be resumed on a CPU-only one.
            network.load_state_dict(
                torch.load(os.path.join('checkpoints', file_name),
                           map_location='cpu'))

    # ------------------------------------------------------------------
    # variable definition
    # ------------------------------------------------------------------
    device = torch.device('cuda' if FLAGS.cuda else 'cpu')

    real_domain_labels = 1
    fake_domain_labels = 0

    # Pre-allocated batch buffers; every step copies the loader output into
    # them (the loader uses drop_last=True, so the batch size is constant).
    image_shape = (FLAGS.batch_size, FLAGS.num_channels,
                   FLAGS.image_size, FLAGS.image_size)
    X_1 = torch.empty(image_shape, dtype=torch.float32, device=device)
    X_2 = torch.empty(image_shape, dtype=torch.float32, device=device)
    X_3 = torch.empty(image_shape, dtype=torch.float32, device=device)

    domain_labels = torch.empty(FLAGS.batch_size, dtype=torch.long, device=device)

    # Ground truth for the accuracy estimate: the real half of the stacked
    # discriminator outputs comes first, the fake half second.
    target_true_labels = torch.cat(
        (torch.ones(FLAGS.batch_size, dtype=torch.long, device=device),
         torch.zeros(FLAGS.batch_size, dtype=torch.long, device=device)),
        dim=0)

    # ------------------------------------------------------------------
    # loss definitions / device placement
    # ------------------------------------------------------------------
    cross_entropy_loss = nn.CrossEntropyLoss()

    encoder.to(device)
    decoder.to(device)
    discriminator.to(device)
    cross_entropy_loss.to(device)

    # ------------------------------------------------------------------
    # optimizer definition
    # ------------------------------------------------------------------
    # The auto-encoder and generator optimizers cover the same parameters
    # but keep separate Adam statistics for the two objectives.
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    generator_optimizer = optim.Adam(list(encoder.parameters()) +
                                     list(decoder.parameters()),
                                     lr=FLAGS.initial_learning_rate,
                                     betas=(FLAGS.beta_1, FLAGS.beta_2))

    # ------------------------------------------------------------------
    # training
    # ------------------------------------------------------------------
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\t')
            log.write(
                'Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    iterations_per_epoch = len(paired_mnist) // FLAGS.batch_size

    # Values reported in the logs.  They are initialised up front so that
    # logging still works when the generator or discriminator loop is
    # configured to run zero times.
    discriminator_accuracy = 0.
    generator_loss = 0.
    discriminator_loss = 0.

    # initialize summary writer
    writer = SummaryWriter()

    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            for iteration in range(iterations_per_epoch):
                # A. run the auto-encoder reconstruction
                image_batch_1, image_batch_2, _ = next(loader)

                auto_encoder_optimizer.zero_grad()

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                nv_1, nc_1 = encoder(X_1)
                nv_2, nc_2 = encoder(X_2)

                # the nc codes of the pair are swapped before decoding
                reconstructed_X_1 = decoder(nv_1, nc_2)
                reconstructed_X_2 = decoder(nv_2, nc_1)

                reconstruction_error_1 = mse_loss(reconstructed_X_1, X_1)
                # both reconstructions share the encoder graph
                reconstruction_error_1.backward(retain_graph=True)

                reconstruction_error_2 = mse_loss(reconstructed_X_2, X_2)
                reconstruction_error_2.backward()

                reconstruction_loss = (reconstruction_error_1.detach()
                                       + reconstruction_error_2.detach()).item()

                if FLAGS.train_auto_encoder:
                    auto_encoder_optimizer.step()

                # B. run the adversarial part of the architecture

                # B. a) run the discriminator
                for _ in range(FLAGS.discriminator_times):
                    discriminator_optimizer.zero_grad()

                    # train discriminator on real data
                    domain_labels.fill_(real_domain_labels)

                    image_batch_1, image_batch_2, _ = next(loader)

                    X_1.copy_(image_batch_1)
                    X_2.copy_(image_batch_2)

                    real_output = discriminator(X_1, X_2)

                    discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                        real_output, domain_labels)
                    discriminator_real_error.backward()

                    # train discriminator on fake data
                    domain_labels.fill_(fake_domain_labels)

                    image_batch_3, _, _ = next(loader)
                    X_3.copy_(image_batch_3)

                    nv_3, _ = encoder(X_3)

                    # reconstruction is taking common factor from X_1 and
                    # varying factor from X_3
                    reconstructed_X_3_1 = decoder(nv_3, encoder(X_1)[1])

                    fake_output = discriminator(X_1, reconstructed_X_3_1)

                    discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                        fake_output, domain_labels)
                    discriminator_fake_error.backward()

                    # total discriminator error
                    discriminator_loss = (discriminator_real_error.detach()
                                          + discriminator_fake_error.detach()).item()

                    # calculate discriminator accuracy for this step
                    discriminator_predictions = torch.cat(
                        (real_output, fake_output), dim=0)
                    _, discriminator_predictions = torch.max(
                        discriminator_predictions, 1)

                    discriminator_accuracy = (
                        discriminator_predictions == target_true_labels
                    ).sum().item() / (FLAGS.batch_size * 2)

                    # hold the discriminator back once it is accurate enough
                    if (discriminator_accuracy < FLAGS.discriminator_limiting_accuracy
                            and FLAGS.train_discriminator):
                        discriminator_optimizer.step()

                # B. b) run the generator
                for _ in range(FLAGS.generator_times):

                    generator_optimizer.zero_grad()

                    image_batch_1, _, _ = next(loader)
                    image_batch_3, _, _ = next(loader)

                    # the generator wants its pairs to be classified as real
                    domain_labels.fill_(real_domain_labels)
                    X_1.copy_(image_batch_1)
                    X_3.copy_(image_batch_3)

                    nv_3, _ = encoder(X_3)

                    # reconstruction is taking common factor from X_1 and
                    # varying factor from X_3
                    reconstructed_X_3_1 = decoder(nv_3, encoder(X_1)[1])

                    output = discriminator(X_1, reconstructed_X_3_1)

                    generator_error = FLAGS.gen_coef * cross_entropy_loss(
                        output, domain_labels)
                    generator_error.backward()
                    generator_loss = generator_error.item()

                    if FLAGS.train_generator:
                        generator_optimizer.step()

                # print progress after 10 iterations
                if (iteration + 1) % 10 == 0:
                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))

                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_loss))
                    print('Generator loss: ' + str(generator_loss))

                    print('')
                    print('Discriminator loss: ' + str(discriminator_loss))
                    print('Discriminator accuracy: ' + str(discriminator_accuracy))

                    print('..........')

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                        epoch, iteration,
                        reconstruction_loss,
                        generator_loss,
                        discriminator_loss,
                        discriminator_accuracy))

                # write to tensorboard
                global_step = epoch * (iterations_per_epoch + 1) + iteration
                writer.add_scalar('Reconstruction loss', reconstruction_loss, global_step)
                writer.add_scalar('Generator loss', generator_loss, global_step)
                writer.add_scalar('Discriminator loss', discriminator_loss, global_step)

            # save model after every 5 epochs
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.decoder_save))
                torch.save(discriminator.state_dict(),
                           os.path.join('checkpoints', FLAGS.discriminator_save))

                # save reconstructed images and style swapped image
                # generations to check progress
                image_batch_1, image_batch_2, _ = next(loader)
                image_batch_3, _, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                # These forward passes are only rendered to images, so no
                # autograd graph needs to be built for them.
                with torch.no_grad():
                    nv_1, _ = encoder(X_1)
                    _, nc_2 = encoder(X_2)
                    nv_3, _ = encoder(X_3)

                    reconstructed_X_1 = decoder(nv_1, nc_2)
                    reconstructed_X_3_2 = decoder(nv_3, nc_2)

                # save input image batch
                imshow_grid(_tile_channels_to_rgb(X_1),
                            name=str(epoch) + '_original',
                            save=True)

                # save reconstructed batch
                imshow_grid(_tile_channels_to_rgb(reconstructed_X_1),
                            name=str(epoch) + '_target',
                            save=True)

                # save cross reconstructed batch
                imshow_grid(_tile_channels_to_rgb(X_3),
                            name=str(epoch) + '_style',
                            save=True)
                imshow_grid(_tile_channels_to_rgb(reconstructed_X_3_2),
                            name=str(epoch) + '_style_target',
                            save=True)
    finally:
        # flush pending TensorBoard events even if training is interrupted
        writer.close()
Exemple #14
0
def main(args, dataloader):
    """Adversarially train a generator, an encoder and a joint (image, latent) discriminator.

    For every batch the discriminator scores two pairs: (noisy real image,
    latent sampled from the encoder output) and (noisy generated image, the
    prior latent it was generated from).  The discriminator is trained to
    label the first pair 1 and the second 0; the generator and encoder share
    one optimizer and are trained on the same outputs with the labels
    swapped.

    Losses are printed every 10 iterations, sample grids are written every
    500 iterations, and generator/encoder weights are saved once per epoch.

    Args:
        args: option namespace.  Attributes read in the visible code:
            ``ngf``, ``ndf``, ``nz``, ``nc``, ``lr``, ``num_epochs``,
            ``log_dir``, ``run_name``.
        dataloader: iterable of batches; element 0 of each batch is used as
            the image tensor.  ``len(dataloader)`` is used for progress
            output.

    Note:
        All networks and tensors are moved with ``.cuda()``; there is no
        CPU fallback.
    """
    # define the networks
    netG = Generator(ngf=args.ngf, nz=args.nz, nc=args.nc).cuda()
    netG.apply(weight_init)
    print(netG)

    netD = Discriminator(ndf=args.ndf, nc=args.nc, nz=args.nz).cuda()
    netD.apply(weight_init)
    print(netD)

    netE = Encoder(nc=args.nc, ngf=args.ngf, nz=args.nz).cuda()
    netE.apply(weight_init)
    print(netE)

    # define the loss criterion
    # BCELoss is applied directly to the discriminator output, so netD
    # presumably ends in a sigmoid -- verify against Discriminator.
    criterion = nn.BCELoss()

    # define the ground truth labels.
    real_label = 1  # for the real pair
    fake_label = 0  # for the fake pair

    # define the optimizers: one for the discriminator and one shared by the
    # generator and the encoder (two parameter groups, same hyper-parameters)
    netD_optimizer = optim.Adam(netD.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
    netG_optimizer = optim.Adam([{
        'params': netG.parameters()
    }, {
        'params': netE.parameters()
    }],
                                lr=args.lr,
                                betas=(0.5, 0.999))

    # Training loop
    # global batch counter across epochs; drives printing and image dumps
    iters = 0

    for epoch in range(args.num_epochs):
        # iterate through the dataloader
        for i, data in enumerate(dataloader, 0):
            real_images = data[0].cuda()
            bs = real_images.shape[0]

            # Gaussian noise added to both discriminator image inputs; its
            # standard deviation decays linearly from 0.1 towards 0 as the
            # epochs progress.
            noise1 = torch.Tensor(real_images.size()).normal_(
                0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()
            noise2 = torch.Tensor(real_images.size()).normal_(
                0, 0.1 * (args.num_epochs - epoch) / args.num_epochs).cuda()

            # get the output from the encoder
            # flattened per sample, then split at column nz into two halves
            z_real = netE(real_images).view(bs, -1)
            mu, sigma = z_real[:, :args.nz], z_real[:, args.nz:]
            # NOTE(review): the names look swapped -- `sigma` is exponentiated
            # and the result scales the noise below, so `sigma` presumably
            # holds a log standard deviation and `log_sigma` the standard
            # deviation itself.
            log_sigma = torch.exp(sigma)
            epsilon = torch.randn(bs, args.nz).cuda()
            # reparameterization trick
            output_z = mu + epsilon * log_sigma
            output_z = output_z.view(bs, -1, 1, 1)

            # get the output from the generator
            # (z_fake is a fresh sample from a standard normal prior)
            z_fake = torch.randn(bs, args.nz, 1, 1).cuda()
            d_fake = netG(z_fake)

            # get the output from the discriminator for the real pair
            out_real_pair = netD(real_images + noise1, output_z)

            # get the output from the discriminator for the fake pair
            out_fake_pair = netD(d_fake + noise2, z_fake)

            # NOTE(review): the fill values are Python ints; depending on the
            # PyTorch version torch.full may infer an integer dtype here,
            # while BCELoss expects floating-point targets -- confirm.
            real_labels = torch.full((bs, ), real_label).cuda()
            fake_labels = torch.full((bs, ), fake_label).cuda()

            # compute the losses
            # g_loss is d_loss with the target labels swapped
            d_loss = criterion(out_real_pair, real_labels) + criterion(
                out_fake_pair, fake_labels)
            g_loss = criterion(out_real_pair, fake_labels) + criterion(
                out_fake_pair, real_labels)

            # update weights
            # The discriminator update is skipped whenever the
            # generator/encoder loss is 3.5 or higher.
            if g_loss.item() < 3.5:
                netD_optimizer.zero_grad()
                # retain_graph: g_loss.backward() below walks the same graph
                d_loss.backward(retain_graph=True)
                # NOTE(review): this step changes netD's weights in place
                # before g_loss is back-propagated through the retained
                # graph; newer PyTorch versions may raise an in-place
                # modification error here -- confirm on the target version.
                netD_optimizer.step()

            netG_optimizer.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            # print the training losses
            if iters % 10 == 0:
                print(
                    '[%3d/%d][%3d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x, z): %.4f\tD(G(z), z): %.4f'
                    %
                    (epoch, args.num_epochs, i, len(dataloader), d_loss.item(),
                     g_loss.item(), out_real_pair.mean().item(),
                     out_fake_pair.mean().item()))

            # visualize the samples generated by the G.
            if iters % 500 == 0:
                out_dir = os.path.join(args.log_dir, args.run_name, 'out/')
                os.makedirs(out_dir, exist_ok=True)
                # up to 64 generated samples in a grid with 8 images per row
                save_image(d_fake.cpu()[:64, ],
                           os.path.join(out_dir,
                                        str(iters).zfill(7) + '.png'),
                           nrow=8,
                           normalize=True)
                # save reconstructions
                # NOTE(review): d_fake is generated from the random prior
                # sample z_fake, not from the encoding of real_images, so
                # these side-by-side images are not reconstructions of the
                # real images shown next to them.
                recons_dir = os.path.join(args.log_dir, args.run_name,
                                          'recons/')
                os.makedirs(recons_dir, exist_ok=True)
                save_image(torch.cat(
                    [real_images.cpu()[:8],
                     d_fake.cpu()[:8, ]], dim=3),
                           os.path.join(recons_dir,
                                        str(iters).zfill(7) + '.png'),
                           nrow=1,
                           normalize=True)

            iters += 1

        # save weights (once per epoch; each epoch writes the same file names)
        save_dir = os.path.join(args.log_dir, args.run_name, 'weights')
        os.makedirs(save_dir, exist_ok=True)
        # NOTE(review): './%s' prefixes save_dir with './', so an absolute
        # args.log_dir would be turned into a relative path that differs
        # from the directory created above -- confirm log_dir is relative.
        save_weights(netG, './%s/netG.pth' % (save_dir))
        save_weights(netE, './%s/netE.pth' % (save_dir))