Example no. 1
    def train_gan(self, epochs, batch_size, sample_interval, train_data):

        # Create labels for real and fake data
        real = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # Get batch of real data
            real_seqs = get_batch(train_data, batch_size)
            # Generate batch of fake data using random noise
            noise = np.random.normal(0, 1, (batch_size, self.model.latent_dim))
            gen_seqs = self.model.generator.predict(noise)

            # Train the discriminator to accept real data and reject fake data
            d_loss_real = self.model.discriminator.train_on_batch(
                real_seqs, real)
            d_loss_fake = self.model.discriminator.train_on_batch(
                gen_seqs, fake)

            # Train the generator such that when it takes random noise as an
            # input, it will produce fake data which the discriminator accepts
            # as real

            noise = np.random.normal(0, 1, (batch_size, self.model.latent_dim))
            g_loss = self.model.gan.train_on_batch(noise, real)

            if epoch % sample_interval == 0:
                print("""%d [DiscLoss/Acc Real: (%10f, %10f)] 
                       [DiscLoss/Acc Fake: (%10f, %10f)] 
                       [DiscAcc %10f][GenLoss = %10f]""" %
                      (epoch, d_loss_real[0], d_loss_real[1], d_loss_fake[0],
                       d_loss_fake[1], 0.5 *
                       (d_loss_real[1] + d_loss_fake[1]), g_loss))

                self.disc_loss_r.append(d_loss_real)
                self.disc_loss_f.append(d_loss_fake)

                self.gen_loss.append(g_loss)
                sample_image(self.model, epoch, real_seqs, self.path)
            if (epoch % 1000 == 0):
                self.save_models(self.path, epoch, self.model.generator,
                                 self.model.discriminator)

        self.savedata(self.path, train_data)
        self.showLoss(self.path, save=True)
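
Example no. 1 relies on a get_batch helper that is not shown in the snippet. A minimal sketch of such a helper, assuming train_data is a NumPy array with samples along the first axis (both the body and that assumption are mine, not the original repository's), might be:

import numpy as np

def get_batch(train_data, batch_size):
    # Assumed helper: draw a random mini-batch of real samples
    idx = np.random.randint(0, train_data.shape[0], batch_size)
    return train_data[idx]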
Example no. 2
def train():
    os.makedirs("images", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    cuda = True if torch.cuda.is_available() else False
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

    # get configs and dataloader
    opt = parse_args()
    data_loader = mnist_loader(opt)

    # Initialize generator and discriminator
    generator = Generator(opt)
    discriminator = Discriminator(opt)

    # Loss function
    adversarial_loss = torch.nn.MSELoss()

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    for epoch in range(opt.epochs):
        for i, (imgs, labels) in enumerate(data_loader):

            # Adversarial ground truths
            valid = Variable(FloatTensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
            fake = Variable(FloatTensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

            # Configure input
            z = Variable(FloatTensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
            gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, imgs.shape[0])))
            labels = Variable(labels.type(LongTensor))

            real_imgs = Variable(imgs.type(FloatTensor))
            gen_imgs = generator(z, gen_labels)

            # ------------------
            # Train Discriminator
            # ------------------

            optimizer_D.zero_grad()

            real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # ------------------
            # Train Generator
            # ------------------

            if i % opt.n_critic == 0:
                optimizer_G.zero_grad()

                # Loss for generator
                g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)

                # Update parameters
                g_loss.backward()
                optimizer_G.step()

            # ------------------
            # Log Information
            # ------------------

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                  (epoch, opt.epochs, i, len(data_loader), d_loss.item(), g_loss.item()))

            batches_done = epoch * len(data_loader) + i
            if batches_done % opt.sample_interval == 0:
                sample_image(opt, 10, batches_done, generator, FloatTensor, LongTensor)

            if batches_done % opt.checkpoint_interval == 0:
                torch.save(generator.state_dict(), "checkpoints/generator_%d.pth" % epoch)
                # torch.save(discriminator.state_dict(), "checkpoints/discriminator_%d.pth" % epoch)

    torch.save(generator.state_dict(), "checkpoints/generator_done.pth")
    print("Training Process has been Done!")
Example no. 3
def train(opt, source_model=None):
    opt.use_dynamics = not opt.no_dynamics
    train_data_s, train_data_t, test_data_s, test_data_t = joblib.load(
        os.path.join(opt.data_path, 'all_data.pkl'))
    model, losses, optimizer_G, optimizer_D, metrics_dict = init_training(opt)
    if opt.domain == 'source':
        source_model = model

    train_data = train_data_s if opt.domain == 'source' else train_data_t
    test_data = test_data_s if opt.domain == 'source' else test_data_t
    for epoch in range(opt.n_epochs):
        gen = grouper(np.random.permutation(len(train_data.obs)),
                      opt.batch_size)
        num_batches = int(np.ceil(len(train_data.obs) / opt.batch_size))

        for batch_idx, data_idxs in enumerate(gen):
            # Drop the None padding added by grouper; testing "is not None"
            # keeps index 0, which filter(None, ...) would silently discard
            data_idxs = [i for i in data_idxs if i is not None]
            obs = train_data.obs[data_idxs]
            acs = train_data.acs[data_idxs]
            next_obs = train_data.obs_[data_idxs]
            if opt.cuda:
                obs, acs, next_obs = obs.cuda(), acs.cuda(), next_obs.cuda()

            # Adversarial ground truths
            valid = Tensor(obs.shape[0], 1).fill_(1.0).detach()
            fake = Tensor(obs.shape[0], 1).fill_(0.0).detach()

            # Configure input
            real_imgs = obs.type(Tensor)
            real_next_imgs = next_obs.type(Tensor)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            encoded_imgs = model['enc'](real_imgs)
            if opt.domain == 'source':
                # detached, i.e. policy training does not affect encoder
                pred_acs = source_model['pol'](encoded_imgs.detach())
            decoded_imgs = model['dec'](encoded_imgs)
            if opt.use_dynamics:
                encoded_next_imgs = source_model['dyn'](encoded_imgs, acs)
                decoded_next_imgs = model['dec_next'](encoded_next_imgs)
            if opt.domain == 'target':
                with torch.no_grad():
                    pred_acs = source_model['pol'](encoded_imgs.detach())
                    ac_loss = losses['action'](pred_acs, acs)

            # Loss measures generator's ability to fool the discriminator
            adv_loss = losses['adversarial'](model['discr'](encoded_imgs),
                                             valid)
            pix_loss = losses['pixelwise'](decoded_imgs, real_imgs)

            if opt.domain == 'source':
                ac_loss = losses['action'](pred_acs, acs)
            if opt.use_dynamics:
                pix_next_loss = losses['pixelwise'](decoded_next_imgs,
                                                    real_next_imgs)
                g_loss = opt.adv_coef * adv_loss + (
                    1 - opt.adv_coef) / 2 * pix_loss + (
                        1 - opt.adv_coef) / 2 * pix_next_loss
            else:
                g_loss = opt.adv_coef * adv_loss + (1 -
                                                    opt.adv_coef) * pix_loss

            if opt.domain == 'source':
                g_loss = 0.5 * g_loss + 0.5 * ac_loss

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Sample noise as discriminator ground truth
            if opt.domain == 'source':
                z = Tensor(
                    np.random.normal(0, 1, (obs.shape[0], opt.latent_dim)))
            elif opt.domain == 'target':
                obs_s = train_data_s.obs[data_idxs].cuda(
                ) if opt.cuda else train_data_s.obs[data_idxs]
                z = source_model['enc'](obs_s.type(Tensor))

            # Measure discriminator's ability to classify real from generated samples
            real_loss = losses['adversarial'](model['discr'](z), valid)
            fake_loss = losses['adversarial'](model['discr'](
                encoded_imgs.detach()), fake)
            d_loss = 0.5 * (real_loss + fake_loss)

            d_loss.backward()
            optimizer_D.step()

            if batch_idx % opt.log_interval == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [adv loss: %f] [pix loss: %f] [ac loss: %f] [G loss: %f]"
                    % (epoch, opt.n_epochs, batch_idx, num_batches,
                       d_loss.item(), adv_loss.item(), pix_loss.item(),
                       ac_loss.item(), g_loss.item()))

            batches_done = epoch * num_batches + batch_idx
            if batches_done % opt.sample_interval == 0:
                sample_image(model['dec'], n_row=10, batches_done=batches_done)

            metrics_dict['g_losses'].append(g_loss.item())
            metrics_dict['pix_losses'].append(pix_loss.item())
            metrics_dict['adv_losses'].append(adv_loss.item())
            metrics_dict['d_losses'].append(d_loss.item())
            metrics_dict['ac_losses'].append(ac_loss.item())
            if opt.use_dynamics:
                metrics_dict['pix_next_losses'].append(pix_next_loss.item())

        with torch.no_grad():
            # careful, all test data may be too large for a gpu
            test_obs = test_data.obs.cuda() if opt.cuda else test_data.obs
            if opt.domain == 'source':
                rf_acc, _ = z_separation_accuracy(model['enc'], test_obs)
            elif opt.domain == 'target':
                test_obs_s = test_data_s.obs.cuda(
                ) if opt.cuda else test_data_s.obs
                rf_acc, _ = z_separation_accuracy(model['enc'], test_obs,
                                                  source_model['enc'],
                                                  test_obs_s)
            pred_acs = source_model['pol'](model['enc'](test_obs.type(Tensor)))
            metrics_dict['rf_z_sep_accs'].append(rf_acc)
            pol_acc = (torch.max(pred_acs.cpu(),
                                 1)[1] == test_data.acs).float().mean().item()
            metrics_dict['pol_accs'].append(pol_acc)

    return model, metrics_dict
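
Example no. 3 batches shuffled indices with a grouper helper whose padding is stripped inside the loop. A standard itertools-based sketch (an assumption, not the repository's own definition) is:

from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    # Collect data into fixed-length chunks; the last chunk is padded with
    # fillvalue, which the training loop removes before indexing
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)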
Example no. 4
        gen_imgs = generator(z, label_input, code_input)
        _, pred_label, pred_code = discriminator(gen_imgs)

        info_loss = params.lambda_cat * categorical_loss(
            pred_label, gt_labels) + params.lambda_con * continuous_loss(
                pred_code, code_input)

        info_loss.backward()
        optimizer_info.step()

        # --------------
        # Log Progress
        # --------------
        if i == len(dataloader) - 1:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [info loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(),
                   g_loss.item(), info_loss.item()))
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            utils.sample_image(generator=generator,
                               n_row=10,
                               batches_done=batches_done)

        torch.save(
            {
                'generator': generator.state_dict(),
                'discriminator': discriminator.state_dict(),
                'parameters': opt
            }, './trained_models/model_final_{}'.format(opt.n_epochs))
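
The utils.sample_image helper in Example no. 4 is likewise not shown. For an InfoGAN-style generator that takes noise, a one-hot label, and a continuous code, a hedged sketch could look like the following; latent_dim, n_classes, and code_dim are assumed values, and only the (generator, n_row, batches_done) arguments come from the example:

import torch
from torchvision.utils import save_image

def sample_image(generator, n_row, batches_done,
                 latent_dim=62, n_classes=10, code_dim=2):
    # Assumed helper: one row of the grid per categorical label,
    # with noise and continuous code held fixed at zero-mean defaults
    device = next(generator.parameters()).device
    z = torch.randn(n_row * n_classes, latent_dim, device=device)
    labels = torch.eye(n_classes, device=device).repeat_interleave(n_row, dim=0)
    code = torch.zeros(n_row * n_classes, code_dim, device=device)
    save_image(generator(z, labels, code).data, "images/%d.png" % batches_done,
               nrow=n_row, normalize=True)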
Example no. 5
    netG.cuda()

encoder = BERTEncoder()

if opt.sample == 'noshuffle':
    print('sampling images based on fixed sequence of categories ...')

    sample_batch_size = 50
    sample_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=sample_batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )

    sample_image(netG, encoder, sample_batch_size, 5, 0, sample_dataloader, opt)
    #sample_final_image(netG, encoder, 100, sample_batch_size, sample_dataloader, opt)
    exit(0)
elif opt.sample == 'shuffle':
    print('sampling images based on shuffled testsets ...')

    eval_dataset = None
    if opt.dataset == 'imagenet':
        eval_dataset = val_dataset
    elif opt.dataset == 'cifar10':
        eval_dataset = train_dataset
    elif opt.dataset == 'coco':
        eval_dataset = train_dataset

    sample_batch_size = opt.batchSize
    sample_dataloader = None