Example 1
    def __init__(self,
                 encoder,
                 decoder,
                 attr_encoder,
                 attr_decoder,
                 classifier,
                 train_loader,
                 test_loader_unseen,
                 test_loader_seen,
                 criterion,
                 SIGMA=2,
                 lr=1e-3,
                 all_attrs=None,
                 epoch=10000,
                 save_path="/data/xingyu/wae_lle/experiments/",
                 save_every=1,
                 iftest=False,
                 ifsample=False):
        self.encoder = encoder
        self.decoder = decoder
        self.attr_encoder = attr_encoder
        self.attr_decoder = attr_decoder
        self.classifier = classifier
        self.train_loader = train_loader
        self.test_loader_unseen = test_loader_unseen
        self.test_loader_seen = test_loader_seen

        self.criterion = criterion
        # NLLLoss consumes log-probabilities, so this acts as cross-entropy
        # assuming the classifier ends in log_softmax.
        self.crossEntropy_Loss = nn.NLLLoss()

        self.all_attrs = all_attrs
        self.lr = lr
        self.epoch = epoch
        self.SIGMA = SIGMA
        self.save_path = save_path
        self.save_every = save_every
        self.ifsample = ifsample

        # Class ids in the data are 1-based; shift to 0-based indices.
        self.unseen_labels = np.array(
            [30, 9, 41, 50, 31, 7, 34, 24, 23, 47]) - 1
        self.seen_labels = np.array([
            22, 49, 14, 45, 25, 39, 18, 6, 44, 29, 19, 16, 1, 32, 33, 26, 21,
            37, 43, 36, 8, 38, 2, 15, 27, 42, 35, 13, 40, 5, 20, 28, 10, 4, 3,
            46, 48, 17, 11, 12
        ]) - 1
        if iftest:
            log_dir = '{}/log'.format(self.save_path)
            general.logger_setup(log_dir, 'results__')
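The five snippets read as methods of closely related trainer classes and lean on a shared set of module-level imports. A sketch of what that header plausibly looks like, where `general` and `mmd_utils` are project-local helper modules whose names are inferred from the calls in these examples, not from any shown source:

import os
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR

import general    # project-local logging helpers (assumed)
import mmd_utils  # project-local MMD loss helpers (assumed)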
Example 2
    def training(self, checkpoint=-1):
        log_dir = '{}/log'.format(self.save_path)
        general.logger_setup(log_dir)
        if checkpoint >= 0:
            file_classifier = 'Checkpoint_{}_classifier.pth.tar'.format(
                checkpoint)
            classifier_path = os.path.join(self.save_path, file_classifier)
            classifier_checkpoint = torch.load(classifier_path)
            self.classifier.load_state_dict(
                classifier_checkpoint['state_dict'])
        self.classifier.train()
        classifier_optim = optim.Adam(self.classifier.parameters(), lr=self.lr)
        classifier_scheduler = StepLR(classifier_optim,
                                      step_size=10000,
                                      gamma=0.5)
        if torch.cuda.is_available():
            self.classifier = self.classifier.cuda()
        for epoch in range(checkpoint + 1, self.epoch):
            step = 0

            train_data_iter = iter(self.train_loader)
            for i_batch in range(len(train_data_iter)):
                sample_batched = next(train_data_iter)
                input_data = sample_batched['feature']
                input_label = sample_batched['label']
                batch_size = input_data.size()[0]
                if torch.cuda.is_available():
                    input_data = input_data.float().cuda()
                    input_label = input_label.cuda().long().view(-1)
                self.classifier.zero_grad()
                pred_label = self.classifier(input_data)
                cls_loss = self.criterion(pred_label, input_label)
                cls_loss.backward()
                classifier_optim.step()
                step += 1
                if (step + 1) % 500 == 0:
                    logging.info(
                        "Epoch: [%d/%d], Step: [%d/%d], Cls Loss: %.4f" %
                        (epoch, self.epoch, step, len(self.train_loader),
                         cls_loss.data.item()))

            if epoch % self.save_every == 0:
                file_classifier = 'Checkpoint_{}_classifier.pth.tar'.format(
                    epoch)
                file_name_classifier = os.path.join(self.save_path,
                                                    file_classifier)
                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.classifier.state_dict(),
                        'optimizer': classifier_optim.state_dict()
                    }, file_name_classifier)
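Example 2 calls `self.save_checkpoint`, which is not part of the snippet. Given how it is used here and reloaded via `torch.load` above, a minimal sketch is a thin `torch.save` wrapper (an assumption, not the shown implementation):

    def save_checkpoint(self, state, file_name):
        # Persist {'epoch', 'state_dict', 'optimizer'} to disk; the
        # training methods reload it with torch.load on restart.
        torch.save(state, file_name)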
Example 3
    def training(self):
        log_dir = '{}/log'.format(self.save_path)
        general.logger_setup(log_dir)

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()
        self.attr_encoder.train()
        self.attr_decoder.train()

        enc_optim = optim.Adam(self.encoder.parameters(), lr=self.lr)
        attr_enc_optim = optim.Adam(self.attr_encoder.parameters(), lr=self.lr)
        attr_dec_optim = optim.Adam(self.attr_decoder.parameters(), lr=self.lr)
        dec_optim = optim.Adam(self.decoder.parameters(), lr=self.lr)
        dis_optim = optim.Adam(self.discriminator.parameters(),
                               lr=0.5 * self.lr)

        attr_enc_scheduler = StepLR(attr_enc_optim, step_size=30, gamma=0.5)
        attr_dec_scheduler = StepLR(attr_dec_optim, step_size=30, gamma=0.5)
        enc_scheduler = StepLR(enc_optim, step_size=30, gamma=0.5)
        dec_scheduler = StepLR(dec_optim, step_size=30, gamma=0.5)
        dis_scheduler = StepLR(dis_optim, step_size=30, gamma=0.5)

        if torch.cuda.is_available():
            self.attr_encoder = self.attr_encoder.cuda()
            self.attr_decoder = self.attr_decoder.cuda()
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.discriminator = self.discriminator.cuda()

        # 0-dim gradient seeds for backward(); mone flips the sign so a
        # scalar term can be maximized instead of minimized.
        one = torch.tensor(1.0)
        mone = one * -1

        if torch.cuda.is_available():
            one = one.cuda()
            mone = mone.cuda()

        for epoch in range(self.epoch):
            step = 0

            for i_batch, sample_batched in enumerate(self.train_loader):

                input_data = sample_batched['feature']
                input_attr = sample_batched['attr'].squeeze()
                input_label = sample_batched['label']
                if torch.cuda.is_available():
                    input_data = input_data.float().cuda()
                    input_attr = input_attr.float().cuda()
                    input_label = input_label.cuda()

                self.attr_encoder.zero_grad()
                self.attr_decoder.zero_grad()
                self.encoder.zero_grad()
                self.decoder.zero_grad()
                self.discriminator.zero_grad()

                # ======== Train Discriminator ======== #

                frozen_params(self.attr_encoder)
                frozen_params(self.attr_decoder)
                frozen_params(self.decoder)
                frozen_params(self.encoder)
                free_params(self.discriminator)

                mu, sigma = self.attr_encoder(input_attr)

                # Single discriminator update per batch; raise the bound
                # for more critic steps.
                for j in range(1):
                    #z_fake = torch.randn(input_data.size()[0], mu.size()[1]) * sigma.cpu().data + mu.cpu().data
                    z_fake = self.reparametrize(mu, sigma)

                    if torch.cuda.is_available():
                        z_fake = z_fake.cuda()

                    z_fake_input = torch.cat((z_fake, input_attr), 1)

                    d_fake = self.discriminator(z_fake_input)

                    z_real = self.encoder(input_data)

                    z_real_input = torch.cat((z_real, input_attr), 1)

                    d_real = self.discriminator(z_real_input)

                    # Discriminator objective: ascend
                    # log D(z_fake) + log(1 - D(z_real)); backward(mone)
                    # flips the sign so the optimizer step maximizes it.
                    torch.log(d_fake).mean().backward(mone)
                    torch.log(1 - d_real).mean().backward(mone)

                    dis_optim.step()

                # ======== Train Generator ======== #

                free_params(self.decoder)
                free_params(self.encoder)
                free_params(self.attr_encoder)
                free_params(self.attr_decoder)
                frozen_params(self.discriminator)

                batch_size = input_data.size()[0]

                z_real = self.encoder(input_data)
                z_real_input = torch.cat((z_real, input_attr), 1)
                x_recon = self.decoder(z_real_input)

                mu, sigma = self.attr_encoder(input_attr)
                z_fake = self.reparametrize(mu, sigma)

                attr_recon = self.attr_decoder(z_fake)

                d_real = self.discriminator(
                    torch.cat(
                        (self.encoder(Variable(input_data).data), input_attr),
                        1))

                recon_loss = self.criterion(x_recon, input_data)
                attr_recon_loss = self.criterion(attr_recon, input_attr)

                # Generator side: descend both reconstruction losses and
                # ascend log D(z_real) so the encoder's codes pass as
                # prior samples to the discriminator.
                d_loss = self.LAMBDA * (torch.log(d_real)).mean()

                recon_loss.backward(one, retain_graph=True)
                d_loss.backward(mone, retain_graph=True)
                attr_recon_loss.backward(one, retain_graph=True)

                enc_optim.step()
                dec_optim.step()
                attr_enc_optim.step()
                attr_dec_optim.step()

                step += 1

                if (step + 1) % 50 == 0:
                    logging.info(
                        "Epoch: [%d/%d], Step: [%d/%d], Reconstruction Loss: %.6f, attr_recon Loss :%.6f, D Loss: %.6f"
                        % (epoch + 1, self.epoch, step + 1,
                           len(self.train_loader), recon_loss.data.item(),
                           attr_recon_loss.data.item(), d_loss.data.item()))
                    print("d_real = {}, d_fake = {} ".format(
                        d_real.mean(), d_fake.mean()))

            if epoch % self.save_every == 0:

                file_encoder = 'Checkpoint_{}_Enc.pth.tar'.format(epoch)
                file_decoder = 'Checkpoint_{}_Dec.pth.tar'.format(epoch)
                file_attr_encoder = 'Checkpoint_{}_AttrEnc.pth.tar'.format(
                    epoch)
                file_discriminator = 'Checkpoint_{}_Disc.pth.tar'.format(epoch)

                file_name_enc = os.path.join(self.save_path, file_encoder)
                file_name_attr_enc = os.path.join(self.save_path,
                                                  file_attr_encoder)
                file_name_dec = os.path.join(self.save_path, file_decoder)
                file_name_disc = os.path.join(self.save_path,
                                              file_discriminator)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.encoder.state_dict(),
                        'optimizer': enc_optim.state_dict()
                    }, file_name_enc)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.attr_encoder.state_dict(),
                        'optimizer': attr_enc_optim.state_dict()
                    }, file_name_attr_enc)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.decoder.state_dict(),
                        'optimizer': dec_optim.state_dict()
                    }, file_name_dec)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.discriminator.state_dict(),
                        'optimizer': dis_optim.state_dict()
                    }, file_name_disc)

                self.testing(epoch)
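Example 3 depends on helpers defined elsewhere: `free_params`/`frozen_params` to toggle which sub-network receives gradients, and `self.reparametrize` to sample from the attribute encoder's output. Minimal sketches under the usual conventions (assumptions; e.g. if the encoder emits log-variance rather than a standard deviation, the sampling line changes accordingly):

def free_params(module):
    # Re-enable gradient updates for every parameter of a sub-network.
    for p in module.parameters():
        p.requires_grad = True


def frozen_params(module):
    # Exclude a sub-network's parameters from the next backward pass.
    for p in module.parameters():
        p.requires_grad = False

and, on the trainer class:

    def reparametrize(self, mu, sigma):
        # VAE reparameterization trick: z = mu + sigma * eps with
        # eps ~ N(0, I), keeping the sampling step differentiable.
        eps = torch.randn_like(sigma)
        return mu + sigma * eps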
Example 4
    def training(self, checkpoint=-1):
        log_dir = '{}/log'.format(self.save_path)
        general.logger_setup(log_dir)

        if checkpoint >= 0:
            file_encoder = 'Checkpoint_{}_Enc.pth.tar'.format(checkpoint)
            file_decoder = 'Checkpoint_{}_Dec.pth.tar'.format(checkpoint)
            file_attr_encoder = 'Checkpoint_{}_attr_Enc.pth.tar'.format(
                checkpoint)
            file_classifier = 'Checkpoint_{}_classifier.pth.tar'.format(
                checkpoint)

            enc_path = os.path.join(self.save_path, file_encoder)
            dec_path = os.path.join(self.save_path, file_decoder)
            attr_enc_path = os.path.join(self.save_path, file_attr_encoder)
            classifier_path = os.path.join(self.save_path, file_classifier)

            enc_checkpoint = torch.load(enc_path)
            self.encoder.load_state_dict(enc_checkpoint['state_dict'])

            dec_checkpoint = torch.load(dec_path)
            self.decoder.load_state_dict(dec_checkpoint['state_dict'])

            attr_enc_checkpoint = torch.load(attr_enc_path)
            self.attr_encoder.load_state_dict(
                attr_enc_checkpoint['state_dict'])

            classifier_checkpoint = torch.load(classifier_path)
            self.classifier.load_state_dict(
                classifier_checkpoint['state_dict'])

        self.encoder.train()
        self.decoder.train()
        self.attr_encoder.train()
        self.classifier.train()

        enc_optim = optim.Adam(self.encoder.parameters(), lr=self.lr)
        dec_optim = optim.Adam(self.decoder.parameters(), lr=self.lr)
        attr_enc_optim = optim.Adam(self.attr_encoder.parameters(), lr=self.lr)
        classifier_optim = optim.Adam(self.classifier.parameters(), lr=self.lr)

        enc_scheduler = StepLR(enc_optim, step_size=10000, gamma=0.5)
        dec_scheduler = StepLR(dec_optim, step_size=10000, gamma=0.5)
        attr_enc_scheduler = StepLR(attr_enc_optim, step_size=10000, gamma=0.5)
        classifier_scheduler = StepLR(classifier_optim,
                                      step_size=10000,
                                      gamma=0.5)

        if torch.cuda.is_available():

            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.attr_encoder = self.attr_encoder.cuda()
            self.classifier = self.classifier.cuda()

        for epoch in range(checkpoint + 1, self.epoch):
            step = 0

            train_data_iter = iter(self.train_loader)
            #test_data_iter = itertools.cycle(self.test_loader)

            for i_batch in range(len(train_data_iter)):
                # equivalent to: for i_batch, sample_batched in
                #     enumerate(self.train_loader)
                sample_batched = next(train_data_iter)
                #test_sample_batched = next(test_data_iter)

                input_data = sample_batched['feature']
                input_label = sample_batched['label']
                input_attr = sample_batched['attr']

                #test_input_label = test_sample_batched['label']
                #test_input_attr = test_sample_batched['attr']

                batch_size = input_data.size()[0]
                if torch.cuda.is_available():
                    input_data = input_data.float().cuda()
                    input_label = input_label.cuda()
                    input_attr = input_attr.float().cuda()

                    #test_input_label = test_input_label.cuda()
                    #test_input_attr = test_input_attr.float().cuda()

                self.encoder.zero_grad()
                self.decoder.zero_grad()
                self.attr_encoder.zero_grad()
                self.classifier.zero_grad()
                # ======== Train Generator ======== #

                z = self.encoder(input_data)
                x_recon = self.decoder(z)
                recon_loss = self.criterion(x_recon, input_data)

                mu, sigma = self.attr_encoder(input_attr)
                z_fake = self.reparametrize(mu, sigma)

                #mu_test, sigma_test = self.attr_encoder(test_input_attr)
                #z_fake_test = self.reparametrize(mu_test, sigma_test)

                # ======== MMD Kernel Loss ======== #

                if torch.cuda.is_available():
                    z_fake = z_fake.cuda()
                    #z_fake_test = z_fake_test.cuda()

                # Per-sample MMD between encoder codes and the
                # attribute-conditioned samples (bandwidth 2).
                mmd_loss = mmd_utils.mmd(z[0], z_fake[0], 2)
                for i in range(1, batch_size):
                    mmd_loss += mmd_utils.mmd(z[i], z_fake[i], 2)

                # Only consumed by the margin term commented out below.
                margin_loss = mmd_utils.mmd(z_fake[0], z_fake[0], 2)
                zzz = torch.cat((z, z_fake), 0).view(-1, z.shape[2])
                zzz_label = torch.cat(
                    (input_label, input_label)).view(-1).long()
                pred_zzz_label = self.classifier(zzz)
                cls_loss = self.crossEntropy_Loss(pred_zzz_label, zzz_label)
                '''
                for i in range(0, batch_size):
                    for j in range(i, batch_size):
                        if input_label[i,0] != input_label[j,0]:
                            margin_loss += mmd_utils.mmd(z[i], z_fake[j], 2)
                '''
                #total_loss =  recon_loss + mmd_loss *5.0 - 0.01*torch.log(margin_loss)
                total_loss = recon_loss + mmd_loss * 0.5 + cls_loss
                total_loss.backward()

                enc_optim.step()
                dec_optim.step()
                attr_enc_optim.step()
                classifier_optim.step()
                step += 1

                if (step + 1) % 50 == 0:
                    logging.info(
                        "Epoch: [%d/%d], Step: [%d/%d], Reconstruction Loss: %.4f MMD Loss: %.4f, Cls Loss: %.4f"
                        % (epoch, self.epoch, step, len(
                            self.train_loader), recon_loss.data.item(),
                           mmd_loss.data.item(), cls_loss.data.item()))
            if epoch % self.save_every == 0 and epoch > 0:

                file_encoder = 'Checkpoint_{}_Enc.pth.tar'.format(epoch)
                file_decoder = 'Checkpoint_{}_Dec.pth.tar'.format(epoch)
                file_attr_encoder = 'Checkpoint_{}_attr_Enc.pth.tar'.format(
                    epoch)
                file_classifier = 'Checkpoint_{}_classifier.pth.tar'.format(
                    epoch)

                file_name_enc = os.path.join(self.save_path, file_encoder)
                file_name_dec = os.path.join(self.save_path, file_decoder)
                file_name_attr_enc = os.path.join(self.save_path,
                                                  file_attr_encoder)
                file_name_classifier = os.path.join(self.save_path,
                                                    file_classifier)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.encoder.state_dict(),
                        'optimizer': enc_optim.state_dict()
                    }, file_name_enc)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.decoder.state_dict(),
                        'optimizer': dec_optim.state_dict()
                    }, file_name_dec)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.attr_encoder.state_dict(),
                        'optimizer': attr_enc_optim.state_dict()
                    }, file_name_attr_enc)
                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.classifier.state_dict(),
                        'optimizer': classifier_optim.state_dict()
                    }, file_name_classifier)
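`mmd_utils.mmd(a, b, sigma)` is project-local; since `z` is indexed per sample before the call, each argument is a set of latent vectors. A common biased Gaussian-kernel MMD estimate matching that call shape (an assumed stand-in, not the repository's implementation):

def mmd(x, y, sigma=2.0):
    # Biased estimate of squared MMD between sample sets x and y
    # (each [n, d]) under an RBF kernel with bandwidth sigma.
    def rbf(a, b):
        sq_dist = torch.cdist(a, b) ** 2
        return torch.exp(-sq_dist / (2.0 * sigma ** 2))
    return rbf(x, x).mean() + rbf(y, y).mean() - 2.0 * rbf(x, y).mean()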
Example 5
    def training_cls(self, checkpoint=-1):
        log_dir = '{}/log'.format(self.save_path)
        general.logger_setup(log_dir)

        if checkpoint >= 0:
            file_encoder = 'Checkpoint_{}_Enc.pth.tar'.format(checkpoint)
            file_decoder = 'Checkpoint_{}_Dec.pth.tar'.format(checkpoint)
            file_attr_encoder = 'Checkpoint_{}_attr_Enc.pth.tar'.format(
                checkpoint)
            file_attr_decoder = 'Checkpoint_{}_attr_Dec.pth.tar'.format(
                checkpoint)

            enc_path = os.path.join(self.save_path, file_encoder)
            dec_path = os.path.join(self.save_path, file_decoder)
            attr_enc_path = os.path.join(self.save_path, file_attr_encoder)
            attr_dec_path = os.path.join(self.save_path, file_attr_decoder)

            enc_checkpoint = torch.load(enc_path)
            self.encoder.load_state_dict(enc_checkpoint['state_dict'])

            dec_checkpoint = torch.load(dec_path)
            self.decoder.load_state_dict(dec_checkpoint['state_dict'])

            attr_enc_checkpoint = torch.load(attr_enc_path)
            self.attr_encoder.load_state_dict(
                attr_enc_checkpoint['state_dict'])

            attr_dec_checkpoint = torch.load(attr_dec_path)
            self.attr_decoder.load_state_dict(
                attr_dec_checkpoint['state_dict'])

        self.encoder.train()
        self.decoder.train()
        self.attr_encoder.train()
        self.attr_decoder.train()

        enc_optim = optim.Adam(self.encoder.parameters(), lr=self.lr)
        dec_optim = optim.Adam(self.decoder.parameters(), lr=self.lr)
        attr_enc_optim = optim.Adam(self.attr_encoder.parameters(), lr=self.lr)
        attr_dec_optim = optim.Adam(self.attr_decoder.parameters(), lr=self.lr)

        enc_scheduler = StepLR(enc_optim, step_size=10000, gamma=0.5)
        dec_scheduler = StepLR(dec_optim, step_size=10000, gamma=0.5)
        attr_enc_scheduler = StepLR(attr_enc_optim, step_size=10000, gamma=0.5)
        attr_dec_scheduler = StepLR(attr_dec_optim, step_size=10000, gamma=0.5)

        if torch.cuda.is_available():
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.attr_encoder = self.attr_encoder.cuda()
            self.attr_decoder = self.attr_decoder.cuda()

        for epoch in range(checkpoint + 1, self.epoch):
            step = 0
            for i_batch, sample_batched in enumerate(self.train_loader):
                input_data = sample_batched['feature']
                input_label = sample_batched['label']
                input_attr = sample_batched['attr']

                batch_size = input_data.size()[0]
                if torch.cuda.is_available():
                    input_data = input_data.float().cuda()
                    input_label = input_label.cuda()
                    input_attr = input_attr.float().cuda()

                self.encoder.zero_grad()
                self.decoder.zero_grad()
                self.attr_encoder.zero_grad()
                # ======== Train Generator ======== #

                if self.ifsample:
                    m, s = self.encoder(input_data)
                    z = self.reparametrize(m, s)
                else:
                    z = self.encoder(input_data)

                x_recon = self.decoder(z)
                recon_loss = self.criterion(x_recon, input_data)

                mu, sigma = self.attr_encoder(input_attr)
                z_fake = self.reparametrize(mu, sigma)
                if torch.cuda.is_available():
                    z_fake = z_fake.cuda()

                z_input = torch.cat((z_fake, z), 0)
                attr_fake = self.attr_decoder(z_input)

                # ======== MMD Kernel Loss ========
                '''
                x_fake = self.decoder(z_fake)
                mmd_loss_x = mmd_utils.mmd(x_fake[0], input_data[0], 2)
                for i in range(1, batch_size):
                    mmd_loss_x +=  mmd_utils.mmd(x_fake[i], input_data[i], 2)
                '''
                mmd_loss = mmd_utils.mmd(z[0], z_fake[0], 2)
                for i in range(1, batch_size):
                    mmd_loss += mmd_utils.mmd(z[i], z_fake[i], 2)

                attr_loss = self.criterion(
                    attr_fake, torch.cat((input_attr, input_attr), 0))

                total_loss = recon_loss * 10.0 + mmd_loss * 1.0 + attr_loss * 10.0
                total_loss.backward()

                enc_optim.step()
                dec_optim.step()
                attr_enc_optim.step()
                attr_dec_optim.step()
                step += 1

                if (step + 1) % 50 == 0:
                    logging.info(
                        "Epoch: [%d/%d], Step: [%d/%d], Reconstruction Loss: %.4f MMDz Loss: %.4f, attr_Recon Loss: %.4f"
                        % (epoch, self.epoch, step, len(
                            self.train_loader), recon_loss.data.item(),
                           mmd_loss.data.item(), attr_loss.data.item()))
            if epoch % self.save_every == 0:

                file_encoder = 'Checkpoint_{}_Enc.pth.tar'.format(epoch)
                file_decoder = 'Checkpoint_{}_Dec.pth.tar'.format(epoch)
                file_attr_encoder = 'Checkpoint_{}_attr_Enc.pth.tar'.format(
                    epoch)
                file_attr_decoder = 'Checkpoint_{}_attr_Dec.pth.tar'.format(
                    epoch)

                file_name_enc = os.path.join(self.save_path, file_encoder)
                file_name_dec = os.path.join(self.save_path, file_decoder)
                file_name_attr_enc = os.path.join(self.save_path,
                                                  file_attr_encoder)
                file_name_attr_dec = os.path.join(self.save_path,
                                                  file_attr_decoder)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.encoder.state_dict(),
                        'optimizer': enc_optim.state_dict()
                    }, file_name_enc)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.decoder.state_dict(),
                        'optimizer': dec_optim.state_dict()
                    }, file_name_dec)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.attr_encoder.state_dict(),
                        'optimizer': attr_enc_optim.state_dict()
                    }, file_name_attr_enc)

                self.save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': self.attr_decoder.state_dict(),
                        'optimizer': attr_dec_optim.state_dict()
                    }, file_name_attr_dec)
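Finally, every method starts with `general.logger_setup(log_dir)`, and Example 1 passes a 'results__' prefix. A minimal sketch consistent with that usage, routing the `logging.info` output to a file under the run directory (assumed behavior, not the shown module):

def logger_setup(log_dir, prefix='train__'):
    # Create the log directory and point the root logger at a file
    # inside it so the logging.info calls above land on disk.
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, prefix + 'log.txt')
    logging.basicConfig(filename=log_file,
                        level=logging.INFO,
                        format='%(asctime)s %(message)s')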