Пример #1
0
def main(FLAGS):
    """Project the class latents of the MNIST test set to 2-D with t-SNE.

    Builds the encoder/decoder, optionally restores them from the
    ``checkpoints`` directory, encodes every test image one at a time, groups
    the class latents by digit label and stores the t-SNE embedding in
    ``s_2d.npz`` under the key ``s_2d`` (rows ordered by label 0..9).

    Args:
        FLAGS: Namespace providing ``style_dim``, ``class_dim``,
            ``load_saved``, ``encoder_save`` and ``decoder_save``.
    """
    # Use the first GPU when one is present; otherwise stay on the CPU instead
    # of crashing on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        # map_location lets checkpoints written on a GPU load on any device.
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save),
                       map_location=device))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save),
                       map_location=device))

    decoder.to(device)
    encoder.to(device)
    # Inference only: switch off dropout / batch-statistics updates.
    encoder.eval()

    tsne = TSNE(2)

    # Default DataLoader batch size (1), so int(label) below is well defined.
    mnist = DataLoader(
        datasets.MNIST(root='mnist',
                       download=True,
                       train=False,
                       transform=transform_config))

    latents_by_label = {}
    with torch.no_grad():
        for i, (image, label) in enumerate(mnist):
            label = int(label)
            print(i, label)
            # Only the third encoder output (class latent) is used here.
            _, _, class_latent = encoder(image.to(device))
            # Move each latent to host memory right away so the whole test
            # set does not accumulate on the GPU.
            latents_by_label.setdefault(label, []).append(class_latent.cpu())

    # Concatenate in label order so rows of the output are grouped by digit.
    s_all = []
    for label in range(10):
        s_all.extend(latents_by_label[label])

    s_all = torch.cat(s_all)
    s_all = s_all.view(s_all.shape[0], -1)

    s_2d = tsne.fit_transform(s_all.numpy())

    np.savez('s_2d.npz', s_2d=s_2d)
Пример #2
0
class LSGANs_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Build the networks, their Adam optimizers and LR schedulers.

        Args:
            hyperparameters: Mapping read for ``lr``, ``beta1``, ``beta2``,
                ``weight_decay``, ``init``, ``input_dim_a`` and ``gen`` (which
                must contain ``style_dim``). ``vgg_w`` / ``vgg_model_path``
                are optional. The mapping is also forwarded to
                ``get_scheduler``.
        """
        super(LSGANs_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.encoder = Encoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.decoder = Decoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Setup the optimizers: one Adam instance per network, all sharing the
        # same learning rate, betas and weight decay.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']

        def make_optimizer(network):
            """Return an Adam optimizer over the trainable parameters."""
            trainable = [p for p in network.parameters() if p.requires_grad]
            return torch.optim.Adam(trainable,
                                    lr=lr,
                                    betas=betas,
                                    weight_decay=weight_decay)

        self.enc_opt = make_optimizer(self.encoder)
        self.dec_opt = make_optimizer(self.decoder)
        self.dis_a_opt = make_optimizer(self.dis_a)
        self.dis_b_opt = make_optimizer(self.dis_b)
        self.interp_ab_opt = make_optimizer(self.interp_net_ab)
        self.interp_ba_opt = make_optimizer(self.interp_net_ba)

        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters)
        # update_learning_rate(), resume() and resume_iter() all use the
        # ``interpo_*_scheduler`` spelling, so that is the canonical name.
        self.interpo_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                  hyperparameters)
        self.interpo_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                  hyperparameters)
        # Backward-compatible aliases for the spelling used here previously.
        self.interp_ab_scheduler = self.interpo_ab_scheduler
        self.interp_ba_scheduler = self.interpo_ba_scheduler

        # Network weight initialization: everything first, then the
        # discriminators are re-initialized with the gaussian scheme.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (after the init above, so its weights are
        # not overwritten by weights_init).
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # Running sum of the losses and the iteration recorded in checkpoints.
        self.total_loss = 0
        self.best_iter = 0
        self.perceptural_loss = Perceptural_loss()

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()

        c_a, s_a_fake = self.encoder(x_a)
        c_b, s_b_fake = self.encoder(x_b)

        # decode (cross domain)
        s_ab_interp = self.interp_net_ab(s_a_fake, s_b_fake, self.v)
        s_ba_interp = self.interp_net_ba(s_b_fake, s_a_fake, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = selfdecoder(c_a, s_b_interp)
        self.train()
        return x_ab, x_ba

    def zero_grad(self):
        self.dis_a_opt.zero_grad()
        self.dis_b_opt.zero_grad()
        self.dec_opt.zero_grad()
        self.enc_opt.zero_grad()
        self.interp_ab_opt.zero_grad()
        self.interp_ba_opt.zero_grad()

    def dis_update(self, x_a, x_b, hyperparameters):
        self.zero_grad()

        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)

        # decode (cross domain)
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)

        x_a_feature = self.dis_a(x_a)
        x_ba_feature = self.dis_a(x_ba)
        x_b_feature = self.dis_b(x_b)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_dis_a = (x_ba_feature - x_a_feature).mean()
        self.loss_dis_b = (x_ab_feature - x_b_feature).mean()

        # gradient penality
        self.loss_dis_a_gp = self.dis_a.calculate_gradient_penalty(x_ba, x_a)
        self.loss_dis_b_gp = self.dis_b.calculate_gradient_penalty(x_ab, x_b)


        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b + \
                              hyperparameters['gan_w'] * self.loss_dis_a_gp + \
                              hyperparameters['gan_w'] * self.loss_dis_b_gp

        self.loss_dis_total.backward()
        self.total_loss += self.loss_dis_total.item()
        self.dis_a_opt.step()
        self.dis_b_opt.step()

    def gen_update(self, x_a, x_b, hyperparameters):
        self.zero_grad()

        # encode
        c_a, s_a = self.encoder(x_a)
        c_b, s_b = self.encoder(x_b)

        # decode (within domain)
        x_a_recon = self.decoder(c_a, s_a)
        x_b_recon = self.decoder(c_b, s_b)

        # decode (cross domain)
        self.v = torch.ones(s_a.size())
        s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
        s_b_interp = self.interp_net_ab(s_a, s_b, self.v)
        x_ba = self.decoder(c_b, s_a_interp)
        x_ab = self.decoder(c_a, s_b_interp)

        # encode again
        c_b_recon, s_a_recon = self.encoder(x_ba)
        c_a_recon, s_b_recon = self.encoder(x_ab)

        # decode again
        x_aa = self.decoder(
            c_a_recon, s_a) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bb = self.decoder(
            c_b_recon, s_b) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aa, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bb, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # perceptual loss
        self.loss_gen_vgg_a = self.perceptural_loss(
            x_a_recon, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.perceptural_loss(
            x_b_recon, x_b) if hyperparameters['vgg_w'] > 0 else 0

        self.loss_gen_vgg_aa = self.perceptural_loss(
            x_aa, x_a) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_bb = self.perceptural_loss(
            x_bb, x_b) if hyperparameters['vgg_w'] > 0 else 0

        # GAN loss
        x_ba_feature = self.dis_a(x_ba)
        x_ab_feature = self.dis_b(x_ab)
        self.loss_gen_adv_a = -x_ba_feature.mean()
        self.loss_gen_adv_b = -x_ab_feature.mean()

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_aa + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_bb + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b

        self.loss_gen_total.backward()
        self.total_loss += self.loss_gen_total.item()
        self.dec_opt.step()
        self.enc_opt.step()
        self.interp_ab_opt.step()
        self.interp_ba_opt.step()

    def sample(self, x_a, x_b):
        self.eval()
        x_a_recon, x_b_recon, x_ab, x_ba, x_aa, x_bb = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a = self.encoder(x_a[i].unsqueeze(0))
            c_b, s_b = self.encoder(x_b[i].unsqueeze(0))
            x_a_recon.append(self.decoder(c_a, s_a))
            x_b_recon.append(self.decoder(c_b, s_b))

            self.v = torch.ones(s_a.size())
            s_a_interp = self.interp_net_ba(s_b, s_a, self.v)
            s_b_interp = self.interp_net_ab(s_a, s_b, self.v)

            x_ab_i = self.decoder(c_a, s_b_interp)
            x_ba_i = self.decoder(c_b, s_a_interp)

            c_a_recon, s_b_recon = self.encoder(x_ab_i)
            c_b_recon, s_a_recon = self.encoder(x_ba_i)

            x_ab.append(self.decoder(c_a, s_b_interp.unsqueeze(0)))
            x_ba.append(self.decoder(c_b, s_a_interp.unsqueeze(0)))
            x_aa.append(self.decoder(c_a_recon, s_a.unsqueeze(0)))
            x_bb.append(self.decoder(c_b_recon, s_b.unsqueeze(0)))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ab, x_aa = torch.cat(x_ab), torch.cat(x_aa)
        x_ba, x_bb = torch.cat(x_ba), torch.cat(x_bb)

        self.train()

        return x_a, x_a_recon, x_ab, x_aa, x_b, x_b_recon, x_ba, x_bb

    def update_learning_rate(self):
        if self.dis_a_scheduler is not None:
            self.dis_a_scheduler.step()
        if self.dis_b_scheduler is not None:
            self.dis_b_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.enc_scheduler is not None:
            self.enc_scheduler.step()
        if self.dec_scheduler is not None:
            self.dec_scheduler.step()
        if self.interpo_ab_scheduler is not None:
            self.interpo_ab_scheduler.step()
        if self.interpo_ba_scheduler is not None:
            self.interpo_ba_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from a checkpoint dir.

        The weight file of each network is located with ``get_model``; the
        optimizer states together with ``best_iter`` and ``total_loss`` are
        read from ``optimizer.pt``.

        Args:
            checkpoint_dir: Directory holding the checkpoint files.
            hyperparameters: Forwarded to ``get_scheduler``.

        Returns:
            Tuple ``(best_iter, total_loss)`` as stored in the checkpoint.
        """
        # Networks, in load order: encoder, decoder, discriminators a and b,
        # interpolators ab and ba.
        networks = (
            ("encoder", self.encoder),
            ("decoder", self.decoder),
            ("dis_a", self.dis_a),
            ("dis_b", self.dis_b),
            ("interp_ab", self.interp_net_ab),
            ("interp_ba", self.interp_net_ba),
        )
        for prefix, network in networks:
            weights_path = get_model(checkpoint_dir, prefix)
            network.load_state_dict(torch.load(weights_path))

        # Optimizers and bookkeeping all live in one file.
        saved = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        optimizers = (
            ('enc_opt', self.enc_opt),
            ('dec_opt', self.dec_opt),
            ('dis_a_opt', self.dis_a_opt),
            ('dis_b_opt', self.dis_b_opt),
            ('interp_ab_opt', self.interp_ab_opt),
            ('interp_ba_opt', self.interp_ba_opt),
        )
        for key, optimizer in optimizers:
            optimizer.load_state_dict(saved[key])

        self.best_iter = saved['best_iter']
        self.total_loss = saved['total_loss']

        # Rebuild the schedulers so they continue from the restored iteration.
        start = self.best_iter
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             start)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             start)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           start)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           start)
        self.interpo_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                  hyperparameters, start)
        self.interpo_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                  hyperparameters, start)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def resume_iter(self, checkpoint_dir, surfix, hyperparameters):
        """Restore the state saved under a specific file-name suffix.

        Files are expected at ``<checkpoint_dir>/<name><surfix>.pt`` for
        ``encoder``, ``decoder``, ``dis_a``, ``dis_b`` and ``optimizer``.
        Interpolator weights may come from a combined ``interp<surfix>.pt``
        file (a dict with keys ``'ab'`` and ``'ba'``) and/or from the
        per-network ``interp_ab`` / ``interp_ba`` files written by
        ``save_at_iter``; when both exist the per-network files are loaded
        last and therefore win.

        Args:
            checkpoint_dir: Directory holding the checkpoint files.
            surfix: Suffix inserted between the base name and ``.pt``.
            hyperparameters: Forwarded to ``get_scheduler``.

        Returns:
            Tuple ``(best_iter, total_loss)`` as stored in the checkpoint.

        Raises:
            FileNotFoundError: If a required checkpoint file is missing
                (raised by ``torch.load``).
        """

        def checkpoint_path(prefix):
            """Return the checkpoint path for ``prefix`` and this suffix."""
            return os.path.join(checkpoint_dir, prefix + surfix + '.pt')

        # Encoder, decoder and both discriminators.
        for prefix, network in (('encoder', self.encoder),
                                ('decoder', self.decoder),
                                ('dis_a', self.dis_a),
                                ('dis_b', self.dis_b)):
            network.load_state_dict(torch.load(checkpoint_path(prefix)))

        # Combined interpolator file: save_at_iter never writes it, so it is
        # optional instead of mandatory.
        combined_path = checkpoint_path('interp')
        has_combined = os.path.isfile(combined_path)
        if has_combined:
            combined = torch.load(combined_path)
            self.interp_net_ab.load_state_dict(combined['ab'])
            self.interp_net_ba.load_state_dict(combined['ba'])

        # Per-network interpolator files: mandatory unless the combined file
        # already supplied the weights.
        for prefix, network in (('interp_ab', self.interp_net_ab),
                                ('interp_ba', self.interp_net_ba)):
            path = checkpoint_path(prefix)
            if has_combined and not os.path.isfile(path):
                continue
            network.load_state_dict(torch.load(path))

        # Load optimizers
        state_dict = torch.load(checkpoint_path('optimizer'))
        self.enc_opt.load_state_dict(state_dict['enc_opt'])
        self.dec_opt.load_state_dict(state_dict['dec_opt'])
        self.dis_a_opt.load_state_dict(state_dict['dis_a_opt'])
        self.dis_b_opt.load_state_dict(state_dict['dis_b_opt'])
        self.interp_ab_opt.load_state_dict(state_dict['interp_ab_opt'])
        self.interp_ba_opt.load_state_dict(state_dict['interp_ba_opt'])

        self.best_iter = state_dict['best_iter']
        self.total_loss = state_dict['total_loss']

        # Reinitialize schedulers so they continue from the restored iteration
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters,
                                             self.best_iter)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters,
                                             self.best_iter)
        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters,
                                           self.best_iter)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters,
                                           self.best_iter)
        self.interpo_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                  hyperparameters,
                                                  self.best_iter)
        self.interpo_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                  hyperparameters,
                                                  self.best_iter)
        print('Resume from iteration %d' % self.best_iter)
        return self.best_iter, self.total_loss

    def save_better_model(self, snapshot_dir):
        # remove sub_optimal models
        files = glob.glob(snapshot_dir + '/*')
        for f in files:
            os.remove(f)
        # Save encoder, decoder, interpolator, discriminators, and optimizers
        encoder_name = os.path.join(snapshot_dir,
                                    'encoder_%.4f.pt' % (self.total_loss))
        decoder_name = os.path.join(snapshot_dir,
                                    'decoder_%.4f.pt' % (self.total_loss))
        interp_ab_name = os.path.join(snapshot_dir,
                                      'interp_ab_%.4f.pt' % (self.total_loss))
        interp_ba_name = os.path.join(snapshot_dir,
                                      'interp_ba_%.4f.pt' % (self.total_loss))
        dis_a_name = os.path.join(snapshot_dir,
                                  'dis_a_%.4f.pt' % (self.total_loss))
        dis_b_name = os.path.join(snapshot_dir,
                                  'dis_b_%.4f.pt' % (self.total_loss))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')

        torch.save(self.encoder.state_dict(), encoder_name)
        torch.save(self.decoder.state_dict(), decoder_name)
        torch.save(self.interp_net_ab.state_dict(), interp_ab_name)
        torch.save(self.interp_net_ba.state_dict(), interp_ba_name)
        torch.save(self.dis_a_opt.state_dict(), dis_a_name)
        torch.save(self.dis_b_opt.state_dict(), dis_b_name)
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, opt_name)

    def save_at_iter(self, snapshot_dir, iterations):

        encoder_name = os.path.join(snapshot_dir,
                                    'encoder_%08d.pt' % (iterations + 1))
        decoder_name = os.path.join(snapshot_dir,
                                    'decoder_%08d.pt' % (iterations + 1))
        interp_ab_name = os.path.join(snapshot_dir,
                                      'interp_ab_%08d.pt' % (iterations + 1))
        interp_ba_name = os.path.join(snapshot_dir,
                                      'interp_ba_%08d.pt' % (iterations + 1))
        dis_a_name = os.path.join(snapshot_dir,
                                  'dis_a_%08d.pt' % (iterations + 1))
        dis_b_name = os.path.join(snapshot_dir,
                                  'dis_b_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir,
                                'optimizer_%08d.pt' % (iterations + 1))

        torch.save(self.encoder.state_dict(), encoder_name)
        torch.save(self.decoder.state_dict(), decoder_name)
        torch.save(self.interp_net_ab.state_dict(), interp_ab_name)
        torch.save(self.interp_net_ba.state_dict(), interp_ba_name)
        torch.save(self.dis_a_opt.state_dict(), dis_a_name)
        torch.save(self.dis_b_opt.state_dict(), dis_b_name)
        torch.save(
            {
                'enc_opt': self.enc_opt.state_dict(),
                'dec_opt': self.dec_opt.state_dict(),
                'dis_a_opt': self.dis_a_opt.state_dict(),
                'dis_b_opt': self.dis_b_opt.state_dict(),
                'interp_ab_opt': self.interp_ab_opt.state_dict(),
                'interp_ba_opt': self.interp_ba_opt.state_dict(),
                'best_iter': self.best_iter,
                'total_loss': self.total_loss
            }, opt_name)
Пример #3
0
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on the MNIST training split.

    Every iteration back-propagates three losses in turn (style KL
    divergence, grouped class KL divergence, reconstruction error) and then
    takes a single Adam step over the encoder and decoder parameters. The
    scalar losses are appended to ``FLAGS.log_file`` and written to
    TensorBoard each iteration; weights are saved under ``checkpoints`` every
    fifth epoch and after the last one.

    Args:
        FLAGS: Namespace providing ``style_dim``, ``class_dim``,
            ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``batch_size``, ``image_size``, ``num_channels``, ``cuda``,
            ``initial_learning_rate``, ``beta_1``, ``beta_2``, ``log_file``,
            ``start_epoch``, ``end_epoch``, ``kl_divergence_coef`` and
            ``reconstruction_coef``.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # variable definition: reusable input buffer, overwritten with every
    # mini-batch through X.copy_() below
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X = X.cuda()

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # training
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs('checkpoints', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    # Loop invariants, hoisted out of the training loop.
    iterations_per_epoch = int(len(mnist) / FLAGS.batch_size)
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(iterations_per_epoch):
            # load a mini-batch
            image_batch, labels_batch = next(loader)

            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            X.copy_(image_batch)

            style_mu, style_logvar, class_mu, class_logvar = encoder(
                Variable(X))
            grouped_mu, grouped_logvar = accumulate_group_evidence(
                class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

            # kl-divergence error for style latent space
            style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                 style_logvar.exp()))
            style_kl_divergence_loss /= kl_normalizer
            # retain_graph: later backward passes reuse the same graph
            style_kl_divergence_loss.backward(retain_graph=True)

            # kl-divergence error for class latent space
            class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                 grouped_logvar.exp()))
            class_kl_divergence_loss /= kl_normalizer
            class_kl_divergence_loss.backward(retain_graph=True)

            # reconstruct samples
            # sampling from group mu and logvar for each image in mini-batch
            # differently makes the decoder consider class latent embeddings
            # as random noise and ignore them
            style_latent_embeddings = reparameterize(training=True,
                                                     mu=style_mu,
                                                     logvar=style_logvar)
            class_latent_embeddings = group_wise_reparameterize(
                training=True,
                mu=grouped_mu,
                logvar=grouped_logvar,
                labels_batch=labels_batch,
                cuda=FLAGS.cuda)

            reconstructed_images = decoder(style_latent_embeddings,
                                           class_latent_embeddings)

            reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_images, Variable(X))
            reconstruction_error.backward()

            auto_encoder_optimizer.step()

            # Read each scalar once as a Python float; .item() replaces the
            # deprecated .data.storage().tolist()[0] access.
            reconstruction_value = reconstruction_error.item()
            style_kl_value = style_kl_divergence_loss.item()
            class_kl_value = class_kl_divergence_loss.item()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_value))
                print('Style KL-Divergence loss: ' + str(style_kl_value))
                print('Class KL-Divergence loss: ' + str(class_kl_value))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration, reconstruction_value, style_kl_value,
                    class_kl_value))

            # write to tensorboard
            global_step = epoch * (iterations_per_epoch + 1) + iteration
            writer.add_scalar('Reconstruction loss', reconstruction_value,
                              global_step)
            writer.add_scalar('Style KL-Divergence loss', style_kl_value,
                              global_step)
            writer.add_scalar('Class KL-Divergence loss', class_kl_value,
                              global_step)

        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
Пример #4
0
    loss = torch.sum(l1 + l2 + torch.log(det_p) - torch.log(det_q), dim=1)
    return loss


if (__name__ == '__main__'):

    # model definition
    encoder = Encoder()
    encoder.apply(weights_init)

    decoder = Decoder()
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if LOAD_SAVED:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', ENCODER_SAVE)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', DECODER_SAVE)))

    # loss definition
    mse_loss = nn.MSELoss()

    # add option to run on gpu
    if (CUDA):
        encoder.cuda()
        decoder.cuda()
        mse_loss.cuda()

    # optimizer
    optimizer = torch.optim.Adam(list(encoder.parameters()) +
                                 list(decoder.parameters()),
Пример #5
0
    likelihood = torch.sum(summand) / summand.size(0)
    



# Evaluated at module level, outside the __main__ guard below, so FLAGS is a
# module-wide global and parse_args() also runs whenever this file is imported.
FLAGS = parser.parse_args()

if __name__ == '__main__':
    """
    model definitions
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)

    encoder.load_state_dict(
        torch.load(os.path.join('checkpoints', FLAGS.encoder_save), map_location=lambda storage, loc: storage))
    decoder.load_state_dict(
        torch.load(os.path.join('checkpoints', FLAGS.decoder_save), map_location=lambda storage, loc: storage))

    encoder.cuda()
    decoder.cuda()

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load data set and create data loader instance
    '''
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=False, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))
    image_array = []
Пример #6
0
def training_procedure(FLAGS):
    """Train the encoder/decoder pair on every selected dataset under ``./data``.

    A directory entry is selected when the third ``_``-separated field of its
    name is exactly ``theta=-1``.  For each selected entry a
    ``DoubleMulNormal`` dataset is built and epochs
    ``FLAGS.start_epoch .. FLAGS.end_epoch - 1`` are run.  Every mini-batch
    back-propagates three terms (style KL, grouped class KL, MSE
    reconstruction) and then takes one Adam step.  Per-iteration losses are
    appended to ``FLAGS.log_file`` and sent to a ``SummaryWriter``; weights
    are written to ``checkpoints/encoder_<dsname>`` and
    ``checkpoints/decoder_<dsname>``.

    The same encoder, decoder and optimizer objects are reused for every
    dataset, so a later dataset starts from the weights reached on the
    previous one.

    Args:
        FLAGS: parsed options.  Attributes read here: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``batch_size``, ``initial_learning_rate``, ``beta_1``, ``beta_2``,
            ``cuda``, ``log_file``, ``start_epoch``, ``end_epoch``,
            ``kl_divergence_coef``, ``reconstruction_coef``, ``num_channels``
            and ``image_size``.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    # variable definition: buffer every mini-batch is copied into
    X = torch.FloatTensor(FLAGS.batch_size, 784)

    # run on GPU if GPU is available
    # NOTE(review): the device is auto-detected here, while the grouping
    # helpers below are still driven by FLAGS.cuda -- confirm the two agree.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        params = dsname.split('_')
        # Exact comparison on the third name field.  The former
        # `in ('theta=-1')` tested against a plain string (the parentheses do
        # not make a tuple), so any substring such as '' or '1' was accepted,
        # and names with fewer than three fields raised IndexError.
        if len(params) < 3 or params[2] != 'theta=-1':
            continue

        print('Running dataset ', dsname)
        ds = DoubleMulNormal(dsname)
        loader = cycle(
            DataLoader(ds,
                       batch_size=FLAGS.batch_size,
                       shuffle=True,
                       drop_last=True))

        # loop invariants for this dataset
        steps_per_epoch = len(ds) // FLAGS.batch_size
        kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                         FLAGS.image_size * FLAGS.image_size)
        encoder_path = os.path.join('checkpoints', 'encoder_' + dsname)
        decoder_path = os.path.join('checkpoints', 'decoder_' + dsname)

        # initialize summary writer
        writer = SummaryWriter()

        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print()
            print(
                'Epoch #' + str(epoch) +
                '........................................................')

            # Running sum of the three loss terms, kept as a Python float so
            # it neither holds autograd graphs alive nor breaks the final
            # print when an epoch has zero iterations.
            total_loss = 0.0

            for iteration in range(steps_per_epoch):
                # load a mini-batch
                image_batch, labels_batch = next(loader)

                # set zero_grad for the optimizer
                auto_encoder_optimizer.zero_grad()

                X.copy_(image_batch)

                style_mu, style_logvar, class_mu, class_logvar = encoder(
                    Variable(X))
                grouped_mu, grouped_logvar = accumulate_group_evidence(
                    class_mu.data, class_logvar.data, labels_batch,
                    FLAGS.cuda)

                # kl-divergence error for style latent space
                style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                     style_logvar.exp()))
                style_kl_divergence_loss /= kl_normalizer
                style_kl_divergence_loss.backward(retain_graph=True)

                # kl-divergence error for class latent space
                class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                    -0.5 *
                    torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                              grouped_logvar.exp()))
                class_kl_divergence_loss /= kl_normalizer
                class_kl_divergence_loss.backward(retain_graph=True)

                # reconstruct samples
                # sampling from group mu and logvar for each image in the
                # mini-batch differently makes the decoder consider class
                # latent embeddings as random noise and ignore them
                style_latent_embeddings = reparameterize(
                    training=True, mu=style_mu, logvar=style_logvar)
                class_latent_embeddings = group_wise_reparameterize(
                    training=True,
                    mu=grouped_mu,
                    logvar=grouped_logvar,
                    labels_batch=labels_batch,
                    cuda=FLAGS.cuda)

                reconstructed_images = decoder(style_latent_embeddings,
                                               class_latent_embeddings)

                reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_images, Variable(X))
                reconstruction_error.backward()

                auto_encoder_optimizer.step()

                # Read each scalar once and reuse it for console, log file
                # and tensorboard.
                reconstruction_value = reconstruction_error.item()
                style_kl_value = style_kl_divergence_loss.item()
                class_kl_value = class_kl_divergence_loss.item()
                total_loss += (style_kl_value + class_kl_value +
                               reconstruction_value)

                report_now = (iteration + 1) % 50 == 0
                if report_now:
                    print('\tIteration #' + str(iteration))
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('Style KL loss: ' + str(style_kl_value))
                    print('Class KL loss: ' + str(class_kl_value))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        epoch, iteration, reconstruction_value,
                        style_kl_value, class_kl_value))

                # write to tensorboard
                global_step = epoch * (steps_per_epoch + 1) + iteration
                writer.add_scalar('Reconstruction loss', reconstruction_value,
                                  global_step)
                writer.add_scalar('Style KL-Divergence loss', style_kl_value,
                                  global_step)
                writer.add_scalar('Class KL-Divergence loss', class_kl_value,
                                  global_step)

                # during epoch 0, checkpoint every 50 iterations as well
                if epoch == 0 and report_now:
                    torch.save(encoder.state_dict(), encoder_path)
                    torch.save(decoder.state_dict(), decoder_path)

            # save checkpoints after every 10 epochs
            if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(), encoder_path)
                torch.save(decoder.state_dict(), decoder_path)

            print('Total loss at current epoch: ', total_loss)

        # one writer is created per dataset; release it before the next one
        writer.close()
Пример #7
0
def test(opt):
    """Translate the test set with the newest saved checkpoint and save figures.

    Checkpoints are looked up in ``<opt.checkpoints>/<opt.name>``.  File names
    are expected to carry a three-character epoch number at positions 6..8
    (e.g. after an ``epoch_`` prefix); the highest epoch is loaded.  For every
    sample a 2x2 figure (real_A, fake_B, real_B, fake_A) is written to
    ``results/<opt.name>/<batch index>-<sample index>.jpg``.

    Args:
        opt: parsed options.  Attributes read here: ``name``, ``checkpoints``,
            ``gpu_id``, ``input_nc``, ``output_nc``, ``ngf``, ``norm_type``,
            ``no_dropout`` and ``batch_size``; ``opt`` is also handed to
            ``UnAlignedDataLoader.initialize``.

    Raises:
        FileNotFoundError: if the checkpoint directory holds no ``.pth`` file.
    """
    #### mkdir
    des_pth = os.path.join('results', opt.name)
    # makedirs also creates a missing 'results' parent, which os.mkdir did not
    os.makedirs(des_pth, exist_ok=True)
    src_pth = os.path.join(opt.checkpoints, opt.name)

    # Keep only checkpoint files instead of removing 'images' and
    # 'records.txt' by name, which raised ValueError when either was absent.
    models_name = [m for m in os.listdir(src_pth) if m.endswith('.pth')]
    if not models_name:
        raise FileNotFoundError(
            'no .pth checkpoint found in {}'.format(src_pth))
    models_name.sort(key=lambda x: int(x[6:9]))
    # Reuse the on-disk prefix of the newest checkpoint.  Rebuilding it with
    # '%3d' pads with spaces and cannot match a zero-padded file name.
    latest_prefix = models_name[-1][:9]

    #### device
    device = torch.device(
        'cuda:{}'.format(opt.gpu_id) if opt.gpu_id >= 0 else 'cpu')

    #### data
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()

    #### networks
    ## initialize
    E_a2b = Encoder(input_nc=opt.input_nc, ngf=opt.ngf, norm_type=opt.norm_type, use_dropout=not opt.no_dropout, n_blocks=9)
    G_b = Decoder(output_nc=opt.output_nc, ngf=opt.ngf, norm_type=opt.norm_type)
    E_b2a = Encoder(input_nc=opt.input_nc, ngf=opt.ngf, norm_type=opt.norm_type, use_dropout=not opt.no_dropout, n_blocks=9)
    G_a = Decoder(output_nc=opt.output_nc, ngf=opt.ngf, norm_type=opt.norm_type)

    ## load in models
    def _restore(net, suffix):
        # Deserialize onto the CPU first so GPU-saved weights also load on a
        # CPU-only run, then move the network to the selected device.
        state = torch.load(os.path.join(src_pth, latest_prefix + suffix),
                           map_location=lambda storage, loc: storage)
        net.load_state_dict(state)
        return net.to(device)

    E_a2b = _restore(E_a2b, '-E_a2b.pth')
    G_b = _restore(G_b, '-G_b.pth')
    E_b2a = _restore(E_b2a, '-E_b2a.pth')
    G_a = _restore(G_a, '-G_a.pth')

    # no gradients are needed for inference
    with torch.no_grad():
        for i, data in enumerate(data_set):
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)

            ## visualize: everything is moved to host memory for plotting
            fake_B = G_b(E_a2b(real_A)).detach().cpu()
            fake_A = G_a(E_b2a(real_B)).detach().cpu()
            real_A = real_A.cpu()
            real_B = real_B.cpu()

            # a final, shorter batch must not be indexed past its end
            n_samples = min(opt.batch_size, real_A.size(0), real_B.size(0))
            for j in range(n_samples):
                panels = (
                    (221, "real_A", tensor2image_RGB(real_A[j, ...])),
                    (222, "fake_B", tensor2image_RGB(fake_B[j, ...])),
                    (223, "real_B", tensor2image_RGB(real_B[j, ...])),
                    (224, "fake_A", tensor2image_RGB(fake_A[j, ...])),
                )
                for position, title, image in panels:
                    plt.subplot(position)
                    plt.title(title)
                    plt.imshow(image)

                plt.savefig(os.path.join(des_pth, '%06d-%02d.jpg' % (i, j)))
                # start the next sample from an empty figure so images do
                # not pile up inside the same axes
                plt.clf()
            #break #-> debug

    print("≧◔◡◔≦ Congratulation! Successfully finishing the testing!")
Пример #8
0
class DoubleDQN:
    """Q-learning agent with an online network and a slowly updated target.

    Both networks are ``DQN`` instances built around the *same* encoder
    object.  The bootstrap target in ``compute_loss`` is the maximum of the
    target network's Q-values over its last two dimensions.

    The module-level names ``args``, ``DEVICE``, ``MODEL_FILE`` and
    ``UPDATE_STEPS`` are read from the enclosing file.
    """

    def __init__(self, env, tau=0.1, gamma=0.9, epsilon=1.0):
        """Build the networks and optimizer, restoring ``MODEL_FILE`` if present.

        Args:
            env: environment; ``get_obs()``, ``all_questions``,
                ``held_out_questions`` and ``sample_random_action()`` are used.
            tau: interpolation factor of the soft target update.
            gamma: discount factor.
            epsilon: probability of taking a random action; replaced by the
                stored value when a checkpoint is loaded.
        """
        self.env = env
        self.tau = tau
        self.gamma = gamma
        self.embedding_size = 30
        self.hidden_size = 30
        self.obs_shape = self.env.get_obs().shape
        self.action_shape = 40 // 5
        if args.encoding == "onehot":
            self.encoder = OneHot(
                args.bins,
                self.env.all_questions + self.env.held_out_questions,
                self.hidden_size).to(DEVICE)
        else:
            self.encoder = Encoder(self.embedding_size,
                                   self.hidden_size).to(DEVICE)

        self.model = DQN(self.obs_shape, self.action_shape,
                         self.encoder).to(DEVICE)
        self.target_model = DQN(self.obs_shape, self.action_shape,
                                self.encoder).to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.epsilon = epsilon
        if os.path.exists(MODEL_FILE):
            checkpoint = torch.load(MODEL_FILE)
            self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.target_model.load_state_dict(
                checkpoint['target_model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epsilon = checkpoint['epsilon']

        # hard copy model parameters to target model parameters
        self._hard_update_target()

    def _hard_update_target(self):
        """Overwrite every target-network parameter with the online one.

        The original loop zipped the two parameter lists in the opposite
        order, so it copied the target's weights into the online model and
        discarded the online weights.
        """
        for target_param, param in zip(self.target_model.parameters(),
                                       self.model.parameters()):
            target_param.data.copy_(param.data)

    def get_action(self, state, goal):
        """Pick an action epsilon-greedily.

        Returns:
            ``(obj_selection, direction_selection)`` as Python ints.  On the
            greedy path the flat argmax index is split as ``index // 8`` and
            ``index % 8``.
        """
        assert len(state.shape
                   ) == 2  # This function should not be called during update

        if np.random.rand() > self.epsilon:
            # no graph is needed to select an action
            with torch.no_grad():
                q_values = self.model(state, goal)
            # plain-int arithmetic avoids floor-division on tensors
            flat_index = int(torch.argmax(q_values))
            obj_selection, direction_selection = divmod(flat_index, 8)
        else:
            action = self.env.sample_random_action()
            obj_selection = action[0]
            direction_selection = action[1]

        return int(obj_selection), int(direction_selection)

    def compute_loss(self, batch):
        """Return the MSE between taken-action Q-values and bootstrap targets.

        Args:
            batch: 7-tuple ``(states, actions, goals, rewards, next_states,
                satisfied_goals, dones)``; ``satisfied_goals`` is not used.
        """
        states, actions, goals, rewards, next_states, _satisfied_goals, dones = batch

        rewards = torch.FloatTensor(rewards).to(DEVICE)
        dones = torch.FloatTensor(dones).to(DEVICE)

        curr_Q = self.model(states, goals)

        # Q-value of the (object, direction) pair actually taken per sample
        curr_Q_prev_actions = torch.stack([
            curr_Q[row, actions[row][0], actions[row][1]]
            for row in range(len(states))
        ])  # TODO: Use pytorch gather

        next_Q = self.target_model(next_states, goals)

        # maximum over the last two dimensions, one after the other
        next_Q_max_actions = torch.max(next_Q, -1).values
        next_Q_max_actions = torch.max(next_Q_max_actions, -1).values

        # terminal transitions (dones == 1) contribute the reward only
        expected_Q = rewards + (1 - dones) * self.gamma * next_Q_max_actions

        return F.mse_loss(curr_Q_prev_actions, expected_Q.detach())

    def update(self, replay_buffer, batch_size):
        """Run ``UPDATE_STEPS`` optimizer steps on freshly sampled batches."""
        for _ in range(UPDATE_STEPS):
            batch = replay_buffer.sample(batch_size)
            loss = self.compute_loss(batch)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def update_target_net(self):  # TODO: Check this function
        """Soft update: ``target <- tau * model + (1 - tau) * target``."""
        for target_param, param in zip(self.target_model.parameters(),
                                       self.model.parameters()):
            target_param.data.copy_(self.tau * param.data +
                                    (1 - self.tau) * target_param.data)

    def save_model(self):
        """Write networks, encoder, optimizer state and epsilon to ``MODEL_FILE``."""
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'target_model_state_dict': self.target_model.state_dict(),
                'encoder_state_dict': self.encoder.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epsilon': self.epsilon
            }, MODEL_FILE)
Пример #9
0
    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')
    if not os.path.exists('sqerrors'):
        os.makedirs('sqerrors')

    cwd = os.getcwd()
    dirs = os.listdir(os.path.join(cwd, 'data'))
    print('Loading double univariate normal time series test data...')

    for dsname in dirs:
        params = dsname.split('_')
        if params[2] in ('theta=-1'):
            # load saved parameters of encoder and decoder
            encoder.load_state_dict(
                torch.load(os.path.join(cwd, 'checkpoints',
                                        'encoder_' + dsname),
                           map_location=lambda storage, loc: storage))
            decoder.load_state_dict(
                torch.load(os.path.join(cwd, 'checkpoints',
                                        'decoder_' + dsname),
                           map_location=lambda storage, loc: storage))
            encoder = encoder.to(device=device)
            decoder = decoder.to(device=device)

            paired_mnist = DoubleMulNormal(dsname)
            loader = cycle(
                DataLoader(paired_mnist,
                           batch_size=FLAGS.batch_size,
                           shuffle=True,
                           num_workers=0,
                           drop_last=True))
Пример #10
0
    print(folder)
    checks = np.arange(0, FLAGS.end_epoch + 5, 5)
    checks[1:] -= 1
    monitor[0, 0] = -np.inf
    best_elbo = monitor[checks, 0].argmax()

    #DEBUG
    best_elbo = -1
    print(checks[best_elbo], monitor[checks[best_elbo], 0],
          monitor[checks, 0].max(), best_elbo)
    FLAGS.encoder_save = folder + '/encoder_e%d' % checks[best_elbo]
    FLAGS.decoder_save = folder + '/decoder_e%d' % checks[best_elbo]

    FLAGS.batch_size = 256
    encoder.load_state_dict(
        torch.load(FLAGS.encoder_save,
                   map_location=lambda storage, loc: storage))
    decoder.load_state_dict(
        torch.load(FLAGS.decoder_save,
                   map_location=lambda storage, loc: storage))

    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
Пример #11
0
def training_procedure(FLAGS):
    """Train the encoder/decoder on MNIST by maximizing the ELBO.

    The MNIST training set is split at random into a training part and a
    10000-sample validation part.  Every epoch runs one pass over the
    training part, back-propagating ``-elbo`` as returned by ``process``.
    Every fifth epoch and the last one, the validation metrics returned by
    ``eval`` are stored in row ``epoch`` of a monitor tensor, and the
    encoder, decoder and monitor are written to
    ``checkpoints_<batch_size>/`` with an ``_e<epoch>`` suffix.

    Args:
        FLAGS: parsed options.  Attributes read here: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``cuda``, ``initial_learning_rate``, ``beta_1``, ``beta_2``,
            ``batch_size``, ``log_file``, ``start_epoch`` and ``end_epoch``;
            ``FLAGS`` is also passed on to ``process`` and ``eval``.
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # The checkpoint directory must be known before the optional restore
    # below; it used to be assigned only after that restore, so running with
    # load_saved raised UnboundLocalError.
    savedir = 'checkpoints_%d' % (FLAGS.batch_size)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.decoder_save)))

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

    # optimizer definition
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    # training
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)

    # Training / validation split.  The per-class sampling weights that were
    # computed here were never handed to any sampler (both loaders use
    # SubsetRandomSampler), shadowed the builtin `max`, and indexed
    # `Subset.indices` with a tensor, so that dead code has been dropped.
    dataset_size = len(mnist)
    split = 10000
    np.random.seed(0)
    train_mnist, val_mnist = torch.utils.data.random_split(
        mnist, [dataset_size - split, split])

    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_mnist.indices)
    valid_sampler = SubsetRandomSampler(val_mnist.indices)
    kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {}
    loader = DataLoader(mnist,
                        batch_size=FLAGS.batch_size,
                        sampler=train_sampler,
                        **kwargs)
    valid_loader = DataLoader(mnist,
                              batch_size=FLAGS.batch_size,
                              sampler=valid_sampler,
                              **kwargs)

    # Rows are addressed by absolute epoch number, so the tensor needs
    # end_epoch rows; sizing it end_epoch - start_epoch overflowed as soon as
    # start_epoch was non-zero.
    monitor = torch.zeros(FLAGS.end_epoch, 4)

    # initialize summary writer (created as before; nothing is logged to it
    # in this function)
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )
        # epoch sums are Python floats so no autograd graph is kept alive
        elbo_epoch = 0.0
        term1_epoch = 0.0
        term2_epoch = 0.0
        term3_epoch = 0.0
        num_batches = 0
        for image_batch, labels_batch in loader:
            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            # only move the batch to the GPU when the models live there
            if FLAGS.cuda:
                image_batch = image_batch.cuda()
            X = image_batch.detach().clone()

            elbo, reconstruction_proba, style_kl_divergence_loss, class_kl_divergence_loss = process(
                FLAGS, X, labels_batch, encoder, decoder)
            (-elbo).backward()
            auto_encoder_optimizer.step()

            elbo_epoch += float(elbo)
            term1_epoch += float(reconstruction_proba)
            term2_epoch += float(style_kl_divergence_loss)
            term3_epoch += float(class_kl_divergence_loss)
            num_batches += 1

        # guard the averages against an empty loader
        denominator = max(num_batches, 1)
        print("Elbo epoch %.2f" % (elbo_epoch / denominator))
        print("Rec. Proba %.2f" % (term1_epoch / denominator))
        print("KL style %.2f" % (term2_epoch / denominator))
        print("KL content %.2f" % (term3_epoch / denominator))

        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            monitor[epoch, :] = eval(FLAGS, valid_loader, encoder, decoder)
            torch.save(
                encoder.state_dict(),
                os.path.join(savedir, FLAGS.encoder_save + '_e%d' % epoch))
            torch.save(
                decoder.state_dict(),
                os.path.join(savedir, FLAGS.decoder_save + '_e%d' % epoch))
            print("VAL elbo %.2f" % (monitor[epoch, 0]))
            print("VAL Rec. Proba %.2f" % (monitor[epoch, 1]))
            print("VAL KL style %.2f" % (monitor[epoch, 2]))
            print("VAL KL content %.2f" % (monitor[epoch, 3]))

            torch.save(monitor, os.path.join(savedir, 'monitor_e%d' % epoch))
Пример #12
0
    D_opt = optim.Adam(D.parameters(),
                       lr=args.lr_D,
                       betas=(args.beta1, args.beta2))

    # Load weights from a specific epoch
    start_ep = 0
    if args.load_epoch is not None:
        if args.load_from_experiment is None:
            load_checkpoint_path = checkpoint_path
        else:
            load_checkpoint_path = join('results', args.load_from_experiment,
                                        'checkpoint')
        load_ep = args.load_epoch
        start_ep = load_ep + 1
        E.load_state_dict(
            torch.load(
                join(load_checkpoint_path, '{:03}.E.pth'.format(load_ep))))
        G.load_state_dict(
            torch.load(
                join(load_checkpoint_path, '{:03}.G.pth'.format(load_ep))))
        D.load_state_dict(
            torch.load(
                join(load_checkpoint_path, '{:03}.D.pth'.format(load_ep))))
        G_opt.load_state_dict(
            torch.load(
                join(load_checkpoint_path, '{:03}.G_opt.pth'.format(load_ep))))
        D_opt.load_state_dict(
            torch.load(
                join(load_checkpoint_path, '{:03}.D_opt.pth'.format(load_ep))))

    # Criterion
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
        discriminator.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.discriminator_save)))

    """
    variable definition
    """
    real_domain_labels = 1
    fake_domain_labels = 0

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    domain_labels = torch.LongTensor(FLAGS.batch_size)
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        style_latent_space = style_latent_space.cuda()

    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\t')
            log.write('Generator_loss\tDiscriminator_loss\tDiscriminator_accuracy\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_mnist, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # initialise variables
    discriminator_accuracy = 0.

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_1 = encoder(Variable(X_1))
            style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            _, __, class_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(style_1, class_1)
            reconstructed_X_2 = decoder(style_1, class_2)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, Variable(X_1))
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2
            kl_divergence_error = kl_divergence_loss_1

            auto_encoder_optimizer.step()

            # B. run the generator
            for i in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, _, __ = next(loader)
                image_batch_3, _, __ = next(loader)

                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

                kl_divergence_loss_1 = - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
                kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
                kl_divergence_loss_1.backward(retain_graph=True)

                _, __, class_3 = encoder(Variable(X_3))
                reconstructed_X_1_3 = decoder(style_1, class_3)

                output_1 = discriminator(Variable(X_3), reconstructed_X_1_3)

                generator_error_1 = cross_entropy_loss(output_1, Variable(domain_labels))
                generator_error_1.backward(retain_graph=True)

                style_latent_space.normal_(0., 1.)
                reconstructed_X_latent_3 = decoder(Variable(style_latent_space), class_3)

                output_2 = discriminator(Variable(X_3), reconstructed_X_latent_3)

                generator_error_2 = cross_entropy_loss(output_2, Variable(domain_labels))
                generator_error_2.backward()

                generator_error = generator_error_1 + generator_error_2
                kl_divergence_error += kl_divergence_loss_1

                generator_optimizer.step()

            # C. run the discriminator
            for i in range(FLAGS.discriminator_times):

                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, _, __ = next(loader)
                image_batch_2, image_batch_3, _ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)
                X_3.copy_(image_batch_3)

                real_output = discriminator(Variable(X_2), Variable(X_3))

                discriminator_real_error = cross_entropy_loss(real_output, Variable(domain_labels))
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                style_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

                _, __, class_3 = encoder(Variable(X_3))
                reconstructed_X_1_3 = decoder(style_1, class_3)

                fake_output = discriminator(Variable(X_3), reconstructed_X_1_3)

                discriminator_fake_error = cross_entropy_loss(fake_output, Variable(domain_labels))
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                target_true_labels = torch.cat((torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)), dim=0)
                if FLAGS.cuda:
                    target_true_labels = target_true_labels.cuda()

                discriminator_predictions = torch.cat((real_output, fake_output), dim=0)
                _, discriminator_predictions = torch.max(discriminator_predictions, 1)

                discriminator_accuracy = (discriminator_predictions.data == target_true_labels.long()
                                          ).sum().item() / (FLAGS.batch_size * 2)

                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy:
                    discriminator_optimizer.step()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))

                print('')
                print('Generator loss: ' + str(generator_error.data.storage().tolist()[0]))
                print('Discriminator loss: ' + str(discriminator_error.data.storage().tolist()[0]))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                    epoch,
                    iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    kl_divergence_error.data.storage().tolist()[0],
                    generator_error.data.storage().tolist()[0],
                    discriminator_error.data.storage().tolist()[0],
                    discriminator_accuracy
                ))

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Generator loss', generator_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Discriminator loss', discriminator_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Discriminator accuracy', discriminator_accuracy * 100,
                              epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))
            torch.save(discriminator.state_dict(), os.path.join('checkpoints', FLAGS.discriminator_save))
def training_procedure(FLAGS):
    """Adversarially train the encoder/decoder pair on the paired MNIST set.

    Every iteration runs three phases, each on freshly drawn batches:

    A.    cross-reconstruction: the two image batches returned together by
          the loader are encoded, their second latents (``nc``) are swapped
          and the decodings are compared to the inputs with ``mse_loss``.
    B. a) discriminator: pairs ``(X_1, X_2)`` from the loader are labelled 1,
          pairs ``(X_1, decoder(nv of X_3, nc of X_1))`` are labelled 0.  The
          optimizer only steps while the batch accuracy is below
          ``FLAGS.discriminator_limiting_accuracy``.
    B. b) generator: encoder and decoder are updated so that such decoded
          pairs are classified as label 1.

    Losses are printed every 10 iterations, appended to ``FLAGS.log_file``
    and written to TensorBoard every iteration.  Every 5 epochs (and after
    the last one) the three state dicts are saved under ``checkpoints/`` and
    four image grids are dumped through ``imshow_grid``.

    Args:
        FLAGS: option namespace.  Attributes read here: ``nv_dim``,
            ``nc_dim``, ``load_saved``, ``encoder_save``, ``decoder_save``,
            ``discriminator_save``, ``batch_size``, ``num_channels``,
            ``image_size``, ``cuda``, ``initial_learning_rate``, ``beta_1``,
            ``beta_2``, ``log_file``, ``start_epoch``, ``end_epoch``,
            ``train_auto_encoder``, ``train_discriminator``,
            ``train_generator``, ``discriminator_times``,
            ``generator_times``, ``disc_coef``, ``gen_coef`` and
            ``discriminator_limiting_accuracy``.

    Returns:
        None.
    """
    # ---- model definition ----
    encoder = Encoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    encoder.apply(weights_init)

    decoder = Decoder(nv_dim=FLAGS.nv_dim, nc_dim=FLAGS.nc_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # (model, checkpoint file name) pairs shared by loading and saving
    checkpoints = (
        (encoder, FLAGS.encoder_save),
        (decoder, FLAGS.decoder_save),
        (discriminator, FLAGS.discriminator_save),
    )

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        for model, file_name in checkpoints:
            model.load_state_dict(
                torch.load(os.path.join('checkpoints', file_name)))

    # ---- variable definition ----
    real_domain_labels = 1
    fake_domain_labels = 0

    # Reusable input buffers; always overwritten with copy_() before use.
    image_shape = (FLAGS.batch_size, FLAGS.num_channels,
                   FLAGS.image_size, FLAGS.image_size)
    X_1 = torch.FloatTensor(*image_shape)
    X_2 = torch.FloatTensor(*image_shape)
    X_3 = torch.FloatTensor(*image_shape)

    domain_labels = torch.LongTensor(FLAGS.batch_size)

    # Expected discriminator classes for a (real batch, fake batch) stack;
    # constant for the whole run, so it is built once.
    target_true_labels = torch.cat(
        (torch.ones(FLAGS.batch_size), torch.zeros(FLAGS.batch_size)),
        dim=0).long()

    # ---- loss definitions ----
    cross_entropy_loss = nn.CrossEntropyLoss()

    # ---- add option to run on GPU ----
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        domain_labels = domain_labels.cuda()
        target_true_labels = target_true_labels.cuda()

    # ---- optimizer definition ----
    # The auto-encoder and generator optimizers cover the same parameters but
    # keep separate Adam statistics.
    adam_options = dict(lr=FLAGS.initial_learning_rate,
                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        **adam_options)

    discriminator_optimizer = optim.Adam(list(discriminator.parameters()),
                                         **adam_options)

    generator_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        **adam_options)

    # ---- training ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write('\t'.join((
                'Epoch', 'Iteration', 'Reconstruction_loss',
                'Generator_loss', 'Discriminator_loss',
                'Discriminator_accuracy')) + '\n')

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    iterations_per_epoch = len(paired_mnist) // FLAGS.batch_size

    # initialise variables
    discriminator_accuracy = 0.
    # Placeholders so that logging still works when discriminator_times or
    # generator_times is 0 and the corresponding loop body never runs.
    generator_error = torch.zeros(())
    discriminator_error = torch.zeros(())

    # initialize summary writer
    writer = SummaryWriter()

    def save_grid(batch, name):
        """Dump a (batch, channel, height, width) tensor as an image grid.

        The channel axis is moved last and repeated three times before the
        array is handed to ``imshow_grid``.
        """
        images = np.transpose(batch.detach().cpu().numpy(), (0, 2, 3, 1))
        images = np.concatenate((images, images, images), axis=3)
        imshow_grid(images, name=name, save=True)

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            nv_1, nc_1 = encoder(X_1)
            nv_2, nc_2 = encoder(X_2)

            # each image is rebuilt from its own nv and its partner's nc
            reconstructed_X_1 = decoder(nv_1, nc_2)
            reconstructed_X_2 = decoder(nv_2, nc_1)

            reconstruction_error_1 = mse_loss(reconstructed_X_1, X_1)
            # both losses share one graph, so keep it for the second backward
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = mse_loss(reconstructed_X_2, X_2)
            reconstruction_error_2.backward()

            reconstruction_error = reconstruction_error_1 + reconstruction_error_2

            if FLAGS.train_auto_encoder:
                auto_encoder_optimizer.step()

            # B. run the adversarial part of the architecture

            # B. a) run the discriminator
            for _ in range(FLAGS.discriminator_times):
                discriminator_optimizer.zero_grad()

                # train discriminator on real data
                domain_labels.fill_(real_domain_labels)

                image_batch_1, image_batch_2, __ = next(loader)

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                real_output = discriminator(X_1, X_2)

                discriminator_real_error = FLAGS.disc_coef * cross_entropy_loss(
                    real_output, domain_labels)
                discriminator_real_error.backward()

                # train discriminator on fake data
                domain_labels.fill_(fake_domain_labels)

                image_batch_3, __, ___ = next(loader)
                X_3.copy_(image_batch_3)

                # reconstruction is taking common factor from X_1 and varying
                # factor from X_3.  Only the discriminator is updated in this
                # phase, so no graph through encoder/decoder is needed.
                with torch.no_grad():
                    nv_3, __ = encoder(X_3)
                    reconstructed_X_3_1 = decoder(nv_3, encoder(X_1)[1])

                fake_output = discriminator(X_1, reconstructed_X_3_1)

                discriminator_fake_error = FLAGS.disc_coef * cross_entropy_loss(
                    fake_output, domain_labels)
                discriminator_fake_error.backward()

                # total discriminator error
                discriminator_error = discriminator_real_error + discriminator_fake_error

                # calculate discriminator accuracy for this step
                __, discriminator_predictions = torch.max(
                    torch.cat((real_output, fake_output), dim=0), 1)

                discriminator_accuracy = (
                    discriminator_predictions == target_true_labels
                ).sum().item() / (FLAGS.batch_size * 2)

                # a discriminator that is already accurate enough is not updated
                if discriminator_accuracy < FLAGS.discriminator_limiting_accuracy and FLAGS.train_discriminator:
                    discriminator_optimizer.step()

            # B. b) run the generator
            for _ in range(FLAGS.generator_times):

                generator_optimizer.zero_grad()

                image_batch_1, __, ___ = next(loader)
                image_batch_3, __, ___ = next(loader)

                # the generator wants its output to be classified as real
                domain_labels.fill_(real_domain_labels)
                X_1.copy_(image_batch_1)
                X_3.copy_(image_batch_3)

                nv_3, __ = encoder(X_3)

                # reconstruction is taking common factor from X_1 and varying factor from X_3
                reconstructed_X_3_1 = decoder(nv_3, encoder(X_1)[1])

                output = discriminator(X_1, reconstructed_X_3_1)

                generator_error = FLAGS.gen_coef * cross_entropy_loss(
                    output, domain_labels)
                generator_error.backward()

                if FLAGS.train_generator:
                    generator_optimizer.step()

            # scalar values shared by console, log file and tensorboard
            reconstruction_value = reconstruction_error.item()
            generator_value = generator_error.item()
            discriminator_value = discriminator_error.item()

            # print progress after 10 iterations
            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_value))
                print('Generator loss: ' + str(generator_value))

                print('')
                print('Discriminator loss: ' + str(discriminator_value))
                print('Discriminator accuracy: ' + str(discriminator_accuracy))

                print('..........')

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n'.format(
                    epoch, iteration, reconstruction_value, generator_value,
                    discriminator_value, discriminator_accuracy))

            # write to tensorboard
            global_step = epoch * (iterations_per_epoch + 1) + iteration
            writer.add_scalar('Reconstruction loss', reconstruction_value,
                              global_step)
            writer.add_scalar('Generator loss', generator_value, global_step)
            writer.add_scalar('Discriminator loss', discriminator_value,
                              global_step)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            for model, file_name in checkpoints:
                torch.save(model.state_dict(),
                           os.path.join('checkpoints', file_name))

            # save reconstructed images and style swapped image generations
            # to check progress
            image_batch_1, image_batch_2, _ = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            # preview only: nothing is back-propagated from these outputs
            with torch.no_grad():
                nv_1, _ = encoder(X_1)
                _, nc_2 = encoder(X_2)
                nv_3, _ = encoder(X_3)

                reconstructed_X_1 = decoder(nv_1, nc_2)
                reconstructed_X_3_2 = decoder(nv_3, nc_2)

            # save input image batch
            save_grid(X_1, str(epoch) + '_original')

            # save reconstructed batch
            save_grid(reconstructed_X_1, str(epoch) + '_target')

            # save cross reconstructed batch
            save_grid(X_3, str(epoch) + '_style')
            save_grid(reconstructed_X_3_2, str(epoch) + '_style_target')
 # NOTE(review): excerpt from the middle of an evaluation script; E, device,
 # n_classes, args, checkpoint_path and val_data are defined outside it.
 E.to(device)
 G = Generator(n_classes)
 G.to(device)
 
 if args.multi_gpu:  # If trained with multi-GPU, the model needs to be loaded with multi-GPU, too.
     E = nn.DataParallel(E)
     G = nn.DataParallel(G)
     # G = convert_model(G)
 
 # Load from checkpoints
 load_epoch = args.test_epoch
 if load_epoch is None:  # Use the latest model
     # Largest numeric prefix (text before the first '.') among the file
     # names in checkpoint_path; names with a non-numeric prefix are ignored.
     load_epoch = max(int(path.split('.')[0]) for path in listdir(checkpoint_path) if path.split('.')[0].isdigit())
 print('Loading generator from epoch {:03d}'.format(load_epoch))
 # map_location returns the storage unchanged, so the checkpoint is
 # deserialized without being moved to the device it was saved from.
 E.load_state_dict(torch.load(
     join(checkpoint_path, '{:03d}.E.pth'.format(load_epoch)),
     map_location=lambda storage, loc: storage
 ))
 G.load_state_dict(torch.load(
     join(checkpoint_path, '{:03d}.G.pth'.format(load_epoch)),
     map_location=lambda storage, loc: storage
 ))
 
 # Inference only: switch both networks to eval mode and disable autograd.
 E.eval()
 G.eval()
 with torch.no_grad():
     for batch_idx, (reals, annos) in enumerate(tqdm(val_data)):
         reals, annos = reals.to(device), annos.to(device)
         # presumably a per-pixel one-hot encoding of the annotation map,
         # cast to the dtype/device of the images - TODO confirm in onehot2d
         annos_onehot = onehot2d(annos, n_classes).type_as(reals)
         
         # Encode images and sample latents
         mu, logvar = E(reals)
Пример #16
0
def training_procedure(FLAGS):
    """Train the style/class variational auto-encoder on paired CIFAR.

    Every iteration consists of:

    A.   forward cycle: both image batches returned together by the loader
         are encoded, their class latents are swapped and the decodings are
         compared to the inputs (``mse_loss``) next to a KL term on the
         style posterior.
    A-1. optional (``FLAGS.forward_gan``) adversarial step on the forward
         cycle reconstructions.
    B.   reverse cycle: two unrelated batches are decoded with the same
         randomly drawn style code; the encoder is trained so that the style
         recovered from both decodings agrees (``l1_loss``).
    B-1. optional (``FLAGS.reverse_gan``) adversarial step on the reverse
         cycle decodings.

    Losses are printed every 10 iterations and written to
    ``FLAGS.log_file`` and TensorBoard every iteration.  Every 5 epochs (and
    after the last one) encoder and decoder weights are saved under
    ``checkpoints/`` and preview grids of a fixed sample are dumped through
    ``imshow_grid``.

    Args:
        FLAGS: option namespace.  Attributes read here: ``style_dim``,
            ``class_dim``, ``load_saved``, ``encoder_save``,
            ``decoder_save``, ``batch_size``, ``num_channels``,
            ``image_size``, ``cuda``, ``initial_learning_rate``, ``beta_1``,
            ``beta_2``, ``log_file``, ``start_epoch``, ``end_epoch``,
            ``forward_gan``, ``reverse_gan``, ``kl_divergence_coef``,
            ``reconstruction_coef`` and ``reverse_cycle_coef``.

    Returns:
        None.

    Raises:
        NotImplementedError: if ``FLAGS.load_saved`` is set; resuming is not
            supported by this procedure.
    """
    # ---- model definition ----
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # Resuming is not supported: only encoder and decoder weights are
    # checkpointed below, the discriminator is not.
    if FLAGS.load_saved:
        raise NotImplementedError('This is not implemented')

    # ---- variable definition ----
    # Reusable input buffers; always overwritten with copy_() before use.
    image_shape = (FLAGS.batch_size, FLAGS.num_channels,
                   FLAGS.image_size, FLAGS.image_size)
    X_1 = torch.FloatTensor(*image_shape)
    X_2 = torch.FloatTensor(*image_shape)
    X_3 = torch.FloatTensor(*image_shape)

    # Buffer for the style code drawn in the reverse cycle.
    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # ---- loss definitions ----
    adversarial_loss = nn.BCELoss()

    # ---- add option to run on GPU ----
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    # Adversarial ground truths; constant, so they are created once instead
    # of once per iteration.
    valid = torch.ones(FLAGS.batch_size, 1, dtype=torch.float32,
                       device=X_1.device)
    fake = torch.zeros(FLAGS.batch_size, 1, dtype=torch.float32,
                       device=X_1.device)

    # ---- optimizer and scheduler definition ----
    def make_adam(parameters):
        """Build an Adam optimizer with the run-wide hyper-parameters."""
        return optim.Adam(parameters,
                          lr=FLAGS.initial_learning_rate,
                          betas=(FLAGS.beta_1, FLAGS.beta_2))

    auto_encoder_optimizer = make_adam(
        list(encoder.parameters()) + list(decoder.parameters()))
    reverse_cycle_optimizer = make_adam(list(encoder.parameters()))
    generator_optimizer = make_adam(list(decoder.parameters()))
    discriminator_optimizer = make_adam(list(discriminator.parameters()))

    # divide the learning rate by a factor of 10 after 80 epochs
    schedulers = [
        optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.1)
        for optimizer in (auto_encoder_optimizer, reverse_cycle_optimizer,
                          generator_optimizer, discriminator_optimizer)
    ]

    # ---- training ----
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']

            if FLAGS.forward_gan:
                headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])

            if FLAGS.reverse_gan:
                headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])

            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    iterations_per_epoch = len(paired_cifar) // FLAGS.batch_size
    # the summed KL term is divided by the element count of one image batch
    kl_normalizer = FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    def kl_term(mu, logvar):
        """Return the weighted, normalised KL divergence of N(mu, exp(logvar)) from N(0, I)."""
        divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return FLAGS.kl_divergence_coef * divergence / kl_normalizer

    def save_grid(batch, name):
        """Dump a (batch, channel, height, width) tensor as an image grid.

        The channel axis is moved last; single-channel batches are repeated
        three times along it before being handed to ``imshow_grid``.
        """
        images = np.transpose(batch.detach().cpu().numpy(), (0, 2, 3, 1))
        if FLAGS.num_channels == 1:
            images = np.concatenate((images, images, images), axis=3)
        imshow_grid(images, name=name, save=True)

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        for iteration in range(iterations_per_epoch):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            kl_divergence_loss_1 = kl_term(style_mu_1, style_logvar_1)
            # the encoder graph is reused by the reconstruction losses below
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = kl_term(style_mu_2, style_logvar_2)
            kl_divergence_loss_2.backward(retain_graph=True)

            # each image is rebuilt from its own style and its partner's class code
            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            # reported values are the unweighted losses
            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
                # Training generator
                # NOTE(review): the reconstructions are re-wrapped in
                # Variable(), which (in PyTorch >= 0.4 - confirm for the
                # version in use) yields a tensor cut off from the decoder
                # graph, and that graph was already released by the
                # backward() calls above.  This loss therefore likely gives
                # the decoder no gradient.  Left as is, because fixing it
                # changes the training algorithm.
                generator_optimizer.zero_grad()

                g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
                g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

                gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
                gen_f_loss.backward()

                generator_optimizer.step()

                # Training discriminator
                discriminator_optimizer.zero_grad()

                real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
                real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
                fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
                fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

                dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
                dis_f_loss.backward()

                discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            # one random style code shared by both decodings
            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            # class codes are detached: the loss reaches the encoder only
            # through the re-encoding of the decoded images
            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            # undo the weighting for reporting
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
                # Training generator
                # NOTE(review): same concern as in the forward-cycle GAN step
                # above - the Variable() wrapping likely keeps any gradient
                # from reaching the decoder; confirm.
                generator_optimizer.zero_grad()

                g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
                g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

                gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
                gen_r_loss.backward()

                generator_optimizer.step()

                # Training discriminator
                discriminator_optimizer.zero_grad()

                real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
                real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
                fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
                fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

                dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
                dis_r_loss.backward()

                discriminator_optimizer.step()

            # (label, value) pairs shared by console, log file and tensorboard
            scalars = [
                ('Reconstruction loss', reconstruction_error.item()),
                ('KL-Divergence loss', kl_divergence_error.item()),
                ('Reverse cycle loss', reverse_cycle_loss.item()),
            ]

            if FLAGS.forward_gan:
                scalars.append(('Generator F loss', gen_f_loss.item()))
                scalars.append(('Discriminator F loss', dis_f_loss.item()))

            if FLAGS.reverse_gan:
                scalars.append(('Generator R loss', gen_r_loss.item()))
                scalars.append(('Discriminator R loss', dis_r_loss.item()))

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                for label, value in scalars:
                    print(label + ': ' + str(value))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = [epoch, iteration] + [value for _, value in scalars]
                log.write('\t'.join(str(x) for x in row) + '\n')

            # write to tensorboard
            global_step = epoch * (iterations_per_epoch + 1) + iteration
            for label, value in scalars:
                writer.add_scalar(label, value, global_step)

        # update the learning rate schedulers once the epoch's optimizer
        # updates are done; stepping them before the first optimizer.step()
        # skips the first value of the schedule on PyTorch >= 1.1
        for scheduler in schedulers:
            scheduler.step()

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            # save reconstructed images and style swapped image generations
            # to check progress (always on the sample drawn before training)
            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            # preview only: nothing is back-propagated from these outputs
            with torch.no_grad():
                style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
                _, __, class_latent_space_2 = encoder(Variable(X_2))
                style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

                style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
                style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

                reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
                reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            save_grid(X_1, str(epoch) + '_original')

            # save reconstructed batch
            save_grid(reconstructed_X_1_2, str(epoch) + '_target')

            save_grid(X_3, str(epoch) + '_style')

            # save style swapped reconstructed batch
            save_grid(reconstructed_X_3_2, str(epoch) + '_style_target')
Пример #17
0
def _scaled_kl_divergence(mu, logvar, coef, normalizer):
    """
    closed-form KL term -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
    multiplied by `coef` and then divided by `normalizer`
    """
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return coef * kl_divergence / normalizer


def _save_image_grid(images, name, num_channels):
    """
    save a (batch, channels, height, width) tensor through imshow_grid

    the batch is transposed to channels-last; single-channel batches are
    replicated to 3 channels, any other channel count is passed unchanged
    """
    grid = np.transpose(images.detach().cpu().numpy(), (0, 2, 3, 1))
    if num_channels == 1:
        grid = np.concatenate((grid, grid, grid), axis=3)
    imshow_grid(grid, name=name, save=True)


def _save_progress_images(encoder, decoder, loader, X_1, X_2, X_3, epoch,
                          num_channels):
    """
    save reconstructed images and style swapped image generations to check progress

    draws two batches from `loader`, decodes the class latent of X_2 with the
    style of X_1 and with the style of X_3, and writes four grids named
    '<epoch>_original', '<epoch>_target', '<epoch>_style', '<epoch>_style_target'
    """
    image_batch_1, image_batch_2, _ = next(loader)
    image_batch_3, _, __ = next(loader)

    X_1.copy_(image_batch_1)
    X_2.copy_(image_batch_2)
    X_3.copy_(image_batch_3)

    # nothing is back-propagated here, so skip building the autograd graph
    with torch.no_grad():
        style_mu_1, style_logvar_1, _ = encoder(X_1)
        _, __, class_latent_space_2 = encoder(X_2)
        style_mu_3, style_logvar_3, _ = encoder(X_3)

        style_latent_space_1 = reparameterize(training=False,
                                              mu=style_mu_1,
                                              logvar=style_logvar_1)
        style_latent_space_3 = reparameterize(training=False,
                                              mu=style_mu_3,
                                              logvar=style_logvar_3)

        reconstructed_X_1_2 = decoder(style_latent_space_1,
                                      class_latent_space_2)
        reconstructed_X_3_2 = decoder(style_latent_space_3,
                                      class_latent_space_2)

    # save input image batch
    _save_image_grid(X_1, str(epoch) + '_original', num_channels)
    # save reconstructed batch
    _save_image_grid(reconstructed_X_1_2, str(epoch) + '_target',
                     num_channels)
    # save style source batch
    _save_image_grid(X_3, str(epoch) + '_style', num_channels)
    # save style swapped reconstructed batch
    _save_image_grid(reconstructed_X_3_2, str(epoch) + '_style_target',
                     num_channels)


def training_procedure(FLAGS):
    """
    train the encoder / decoder pair on the paired MNIST data set

    every iteration runs two updates:
      A. auto-encoder reconstruction with swapped class latents
         (KL + reconstruction losses, updates encoder and decoder)
      B. reverse cycle on sampled style latents (L1 loss between the two
         re-encoded styles, updates the encoder only)

    losses are printed every 10 iterations, appended to FLAGS.log_file and
    sent to tensorboard every iteration; checkpoints and progress images are
    written every 5 epochs and after the final epoch

    FLAGS attributes read: style_dim, class_dim, load_saved, encoder_save,
    decoder_save, batch_size, num_channels, image_size, cuda,
    initial_learning_rate, beta_1, beta_2, log_file, start_epoch, end_epoch,
    kl_divergence_coef, reconstruction_coef, reverse_cycle_coef
    """
    # model definition
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    # map_location='cpu' lets checkpoints written on a GPU load on any host;
    # the models are moved to the GPU below when FLAGS.cuda is set
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save),
                       map_location='cpu'))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save),
                       map_location='cpu'))

    # variable definition: reusable input buffers filled with copy_()
    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    # add option to run on GPU
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    # optimizer and scheduler definition
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    reverse_cycle_optimizer = optim.Adam(list(encoder.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    # divide the learning rate by a factor of 10 after 80 epochs
    # NOTE(review): the schedulers always start from step 0, even when
    # FLAGS.start_epoch > 0 -- confirm this is intended when resuming
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer,
                                                       step_size=80,
                                                       gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(
        reverse_cycle_optimizer, step_size=80, gamma=0.1)

    # training
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('reconstructed_images', exist_ok=True)

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\tReverse_cycle_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    # NOTE(review): if `cycle` is itertools.cycle, batches from the first pass
    # are cached and replayed, so shuffle only affects the first pass -- confirm
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    steps_per_epoch = len(paired_mnist) // FLAGS.batch_size
    kl_normalizer = (FLAGS.batch_size * FLAGS.num_channels *
                     FLAGS.image_size * FLAGS.image_size)

    # initialize summary writer
    writer = SummaryWriter()

    try:
        for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
            print('')
            print(
                'Epoch #' + str(epoch) +
                '..........................................................................'
            )

            for iteration in range(steps_per_epoch):
                # A. run the auto-encoder reconstruction
                image_batch_1, image_batch_2, _ = next(loader)

                auto_encoder_optimizer.zero_grad()

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                style_mu_1, style_logvar_1, class_latent_space_1 = encoder(
                    X_1)
                style_latent_space_1 = reparameterize(training=True,
                                                      mu=style_mu_1,
                                                      logvar=style_logvar_1)

                kl_divergence_loss_1 = _scaled_kl_divergence(
                    style_mu_1, style_logvar_1, FLAGS.kl_divergence_coef,
                    kl_normalizer)
                # the encoder graph is reused by the reconstruction losses
                kl_divergence_loss_1.backward(retain_graph=True)

                style_mu_2, style_logvar_2, class_latent_space_2 = encoder(
                    X_2)
                style_latent_space_2 = reparameterize(training=True,
                                                      mu=style_mu_2,
                                                      logvar=style_logvar_2)

                kl_divergence_loss_2 = _scaled_kl_divergence(
                    style_mu_2, style_logvar_2, FLAGS.kl_divergence_coef,
                    kl_normalizer)
                kl_divergence_loss_2.backward(retain_graph=True)

                # class latents are swapped between the two paired images
                reconstructed_X_1 = decoder(style_latent_space_1,
                                            class_latent_space_2)
                reconstructed_X_2 = decoder(style_latent_space_2,
                                            class_latent_space_1)

                reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_X_1, X_1)
                reconstruction_error_1.backward(retain_graph=True)

                reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(
                    reconstructed_X_2, X_2)
                reconstruction_error_2.backward()

                # un-weighted values, used for reporting only
                reconstruction_value = (
                    (reconstruction_error_1.detach() +
                     reconstruction_error_2.detach()) /
                    FLAGS.reconstruction_coef).item()
                kl_divergence_value = (
                    (kl_divergence_loss_1.detach() +
                     kl_divergence_loss_2.detach()) /
                    FLAGS.kl_divergence_coef).item()

                auto_encoder_optimizer.step()

                # B. reverse cycle
                image_batch_1, _, __ = next(loader)
                image_batch_2, _, __ = next(loader)

                reverse_cycle_optimizer.zero_grad()

                X_1.copy_(image_batch_1)
                X_2.copy_(image_batch_2)

                style_latent_space.normal_(0., 1.)

                _, __, class_latent_space_1 = encoder(X_1)
                _, __, class_latent_space_2 = encoder(X_2)

                # the same sampled style is decoded with both class latents
                reconstructed_X_1 = decoder(style_latent_space,
                                            class_latent_space_1.detach())
                reconstructed_X_2 = decoder(style_latent_space,
                                            class_latent_space_2.detach())

                style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
                style_latent_space_1 = reparameterize(training=False,
                                                      mu=style_mu_1,
                                                      logvar=style_logvar_1)

                style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
                style_latent_space_2 = reparameterize(training=False,
                                                      mu=style_mu_2,
                                                      logvar=style_logvar_2)

                reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(
                    style_latent_space_1, style_latent_space_2)
                reverse_cycle_loss.backward()
                reverse_cycle_value = (reverse_cycle_loss.detach() /
                                       FLAGS.reverse_cycle_coef).item()

                reverse_cycle_optimizer.step()

                if (iteration + 1) % 10 == 0:
                    print('')
                    print('Epoch #' + str(epoch))
                    print('Iteration #' + str(iteration))

                    print('')
                    print('Reconstruction loss: ' + str(reconstruction_value))
                    print('KL-Divergence loss: ' + str(kl_divergence_value))
                    print('Reverse cycle loss: ' + str(reverse_cycle_value))

                # write to log
                with open(FLAGS.log_file, 'a') as log:
                    log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        epoch, iteration, reconstruction_value,
                        kl_divergence_value, reverse_cycle_value))

                # write to tensorboard
                global_step = epoch * (steps_per_epoch + 1) + iteration
                writer.add_scalar('Reconstruction loss', reconstruction_value,
                                  global_step)
                writer.add_scalar('KL-Divergence loss', kl_divergence_value,
                                  global_step)
                writer.add_scalar('Reverse cycle loss', reverse_cycle_value,
                                  global_step)

            # update the learning rate schedulers once the epoch's optimizer
            # steps are done; stepping them first skips the initial schedule
            # value on PyTorch >= 1.1 and decays one epoch too early
            auto_encoder_scheduler.step()
            reverse_cycle_scheduler.step()

            # save model after every 5 epochs
            if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
                torch.save(encoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.encoder_save))
                torch.save(decoder.state_dict(),
                           os.path.join('checkpoints', FLAGS.decoder_save))

                _save_progress_images(encoder, decoder, loader, X_1, X_2,
                                      X_3, epoch, FLAGS.num_channels)
    finally:
        # flush pending tensorboard events even if training is interrupted
        writer.close()
Пример #18
0
            print('Pre-Trained GloVe Model')
        else:
            print('Pre-Trained Baseline Model')
    else:
        encoder_checkpoint = torch.load(f'./checkpoints/encoder_{model_tag}',
                                        map_location='cpu')
        decoder_checkpoint = torch.load(f'./checkpoints/decoder_{model_tag}',
                                        map_location='cpu')
        if bert_model:
            print('Pre-Trained BERT Model')
        elif glove_model:
            print('Pre-Trained GloVe Model')
        else:
            print('Pre-Trained Baseline Model')

    encoder.load_state_dict(encoder_checkpoint['model_state_dict'])
    decoder_optimizer = torch.optim.Adam(params=decoder.parameters(),
                                         lr=decoder_lr)
    decoder.load_state_dict(decoder_checkpoint['model_state_dict'])
    decoder_optimizer.load_state_dict(
        decoder_checkpoint['optimizer_state_dict'])
else:
    encoder = Encoder().to(device)
    decoder = Decoder(vocab_size=len(vocab),
                      use_glove=glove_model,
                      use_bert=bert_model,
                      device=device,
                      tokenizer=tokenizer,
                      vocab=vocab,
                      bert_model=BertModel,
                      glove_vectors=glove_vectors).to(device)