Example #1
# NOTE: the imports below are assumed for this snippet; TargetLSTM,
# GenDataIter, DisDataIter, Rollout, GANLoss, train_epoch, eval_epoch,
# generate_samples, compute_gradient_penalty, and sample_image are helpers
# expected from the surrounding project and are not imported here.
import numpy as np
import torch
from torch.autograd import Variable
from torchvision.utils import save_image


def train_GAN(conf_data):
    """Training Process for GAN.
    
    Parameters
    ----------
    conf_data: dict
        Dictionary containing all parameters and objects.       

    Returns
    -------
    conf_data: dict
        Dictionary containing all parameters and objects.       

    """
    seq = conf_data['GAN_model']['seq']
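    # Flow: for seq == 1 (SeqGAN-style sequence model), pretrain the generator
    # with MLE and the discriminator on real vs. generated data files, then run
    # the adversarial loop; for seq == 0 (image GAN), run a standard
    # discriminator/generator loop with optional Wasserstein loss, weight
    # clipping, and gradient-penalty variants selected by the flags below.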
    if seq == 1:
        pre_epoch_num = conf_data['generator']['pre_epoch_num']
        GENERATED_NUM = 10000
        EVAL_FILE = 'eval.data'
        POSITIVE_FILE = 'real.data'
        NEGATIVE_FILE = 'gene.data'
    temp = 1  # TODO: number of discriminator updates per batch; take this as a config value
    epochs = int(conf_data['GAN_model']['epochs'])
    if seq == 0:
        dataloader = conf_data['data_learn']
    mini_batch_size = int(conf_data['GAN_model']['mini_batch_size'])
    data_label = int(conf_data['GAN_model']['data_label'])
    cuda = conf_data['cuda']
    g_latent_dim = int(conf_data['generator']['latent_dim'])
    classes = int(conf_data['GAN_model']['classes'])

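    # w_loss == 1 selects the Wasserstein critic objective; 0 selects the standard GAN loss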
    w_loss = int(conf_data['GAN_model']['w_loss'])

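    # clip_value > 0 enables WGAN-style weight clipping on the discriminator;
    # n_critic > 0 updates the generator only every n_critic batches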
    clip_value = float(conf_data['GAN_model']['clip_value'])
    n_critic = int(conf_data['GAN_model']['n_critic'])

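    # lambda_gp > 0 adds the WGAN-GP gradient penalty to the discriminator loss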
    lambda_gp = int(conf_data['GAN_model']['lambda_gp'])

    log_file = open(conf_data['performance_log'] + "/log.txt", "w+")
    # Convert these to parameters of the config data
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    conf_data['Tensor'] = Tensor
    LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
    conf_data['LongTensor'] = LongTensor
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    conf_data['FloatTensor'] = FloatTensor

    conf_data['epochs'] = epochs

    #print ("Just before training")
    if seq == 1:  #TODO: Change back to 1
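        # TargetLSTM is the oracle model used to score generated sequences
        # (reported as "True Loss" during pretraining and evaluation)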
        target_lstm = TargetLSTM(conf_data['GAN_model']['vocab_size'],
                                 conf_data['generator']['embedding_dim'],
                                 conf_data['generator']['hidden_dim'],
                                 conf_data['cuda'])
        if cuda:
            target_lstm = target_lstm.cuda()
        conf_data['target_lstm'] = target_lstm
        gen_data_iter = GenDataIter(POSITIVE_FILE, mini_batch_size)
        generator = conf_data['generator_model']
        discriminator = conf_data['discriminator_model']
        g_loss_func = conf_data['generator_loss']
        d_loss_func = conf_data['discriminator_loss']
        optimizer_D = conf_data['discriminator_optimizer']
        optimizer_G = conf_data['generator_optimizer']
        # Pretrain the generator with MLE
        for epoch in range(pre_epoch_num):
            loss = train_epoch(generator, gen_data_iter, g_loss_func,
                               optimizer_G, conf_data, 'g')
            print('Epoch [%d] Model Loss: %f' % (epoch, loss))
            generate_samples(generator, mini_batch_size, GENERATED_NUM,
                             EVAL_FILE, conf_data)
            eval_iter = GenDataIter(EVAL_FILE, mini_batch_size)
            loss = eval_epoch(target_lstm, eval_iter, g_loss_func, conf_data)
            print('Epoch [%d] True Loss: %f' % (epoch, loss))

        dis_criterion = d_loss_func
        dis_optimizer = optimizer_D
        # TODO: the original code also moved dis_criterion to CUDA here
        # (dis_criterion = dis_criterion.cuda()); verify whether custom loss
        # objects need this.

        # Pretrain the discriminator
        for epoch in range(5):  # TODO: expose the pretraining epoch count as a config value
            generate_samples(generator, mini_batch_size, GENERATED_NUM,
                             NEGATIVE_FILE, conf_data)
            dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                        mini_batch_size)
            for _ in range(3):  # TODO: expose the inner discriminator epoch count as a config value
                loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                                   dis_optimizer, conf_data, 'd')
                print('Epoch [%d], loss: %f' % (epoch, loss))
        conf_data['generator_model'] = generator
        conf_data['discriminator_model'] = discriminator
        torch.save(conf_data['generator_model'].state_dict(),
                   conf_data['save_model_path'] + '/Seq/' + 'pre_generator.pt')
        torch.save(
            conf_data['discriminator_model'].state_dict(),
            conf_data['save_model_path'] + '/Seq/' + 'pre_discriminator.pt')

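        # Rollout performs Monte Carlo completions of partial sequences for
        # reward estimation; 0.8 is its parameter-update rate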
        conf_data['rollout'] = Rollout(generator, 0.8)

    for epoch in range(epochs):
        conf_data['epoch'] = epoch
        if seq == 0:
            to_iter = dataloader
        elif seq == 1:
            to_iter = [1]

        for i, iterator in enumerate(to_iter):
            optimizer_D = conf_data['discriminator_optimizer']
            optimizer_G = conf_data['generator_optimizer']

            generator = conf_data['generator_model']
            discriminator = conf_data['discriminator_model']

            g_loss_func = conf_data['generator_loss']
            d_loss_func = conf_data['discriminator_loss']

            conf_data['iterator'] = i
            if seq == 0:

                if data_label == 1:
                    imgs, labels = iterator
                else:
                    imgs = iterator
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0),
                                 requires_grad=False)
                conf_data['valid'] = valid
                fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0),
                                requires_grad=False)
                conf_data['fake'] = fake
                # Configure input
                real_imgs = Variable(imgs.type(Tensor))

                if data_label == 1:
                    labels = Variable(labels.type(LongTensor))
                # Sample noise as generator input
                z = Variable(
                    Tensor(
                        np.random.normal(0, 1, (imgs.shape[0], g_latent_dim))))
                if classes > 0:
                    gen_labels = Variable(
                        LongTensor(np.random.randint(0, classes,
                                                     imgs.shape[0])))
                    conf_data['gen_labels'] = gen_labels

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            if seq == 1:  #TODO change this back to 1
                dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE,
                                            mini_batch_size)
            # TODO: make `temp` (the number of discriminator updates) a config
            # value, and check whether re-reading the stored models each
            # iteration is still needed
            for _ in range(temp):
                optimizer_D = conf_data['discriminator_optimizer']
                optimizer_G = conf_data['generator_optimizer']

                generator = conf_data['generator_model']
                discriminator = conf_data['discriminator_model']

                g_loss_func = conf_data['generator_loss']
                d_loss_func = conf_data['discriminator_loss']
                if classes <= 0:
                    #print ("Reached here 2 --------------> ")
                    if seq == 0:
                        gen_imgs = generator(z)
                        # Measure the discriminator's ability to separate real from generated samples
                        real_validity = discriminator(real_imgs)
                        fake_validity = discriminator(gen_imgs.detach())
                    if seq == 1:
                        generate_samples(generator, mini_batch_size,
                                         GENERATED_NUM, NEGATIVE_FILE,
                                         conf_data)
                        dis_data_iter = DisDataIter(POSITIVE_FILE,
                                                    NEGATIVE_FILE,
                                                    mini_batch_size)
                        loss = train_epoch(discriminator, dis_data_iter,
                                           d_loss_func, optimizer_D, conf_data,
                                           'd')
                        conf_data['d_loss'] = loss

                else:
                    if seq == 0:
                        gen_imgs = generator(z, gen_labels)
                        real_validity = discriminator(real_imgs, labels)
                        fake_validity = discriminator(gen_imgs.detach(),
                                                      labels)

                if seq == 0:
                    conf_data['gen_imgs'] = gen_imgs
                    if w_loss == 0:
                        real_loss = d_loss_func.loss(real_validity, valid)
                        fake_loss = d_loss_func.loss(fake_validity, fake)
                        d_loss = (real_loss + fake_loss) / 2
                    elif w_loss == 1:
                        d_loss = (-d_loss_func.loss(real_validity, valid)
                                  + d_loss_func.loss(fake_validity, fake))
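                        # Gradient penalty (WGAN-GP) computed from the stored
                        # real and generated samples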
                        if lambda_gp > 0:
                            conf_data['real_data_sample'] = real_imgs.data
                            conf_data['fake_data_sample'] = gen_imgs.data
                            conf_data = compute_gradient_penalty(conf_data)
                            gradient_penalty = conf_data['gradient_penalty']
                            d_loss = d_loss + lambda_gp * gradient_penalty
                    conf_data['d_loss'] = d_loss
                    d_loss.backward()
                    optimizer_D.step()

                if clip_value > 0:
                    # Clip weights of discriminator
                    for p in discriminator.parameters():
                        p.data.clamp_(-clip_value, clip_value)

            # -----------------
            #  Train Generator
            # -----------------
            conf_data['generator_model'] = generator
            conf_data['discriminator_model'] = discriminator

            # TODO: the next four assignments were recently added; verify they are still needed
            conf_data['optimizer_G'] = optimizer_G
            conf_data['optimizer_D'] = optimizer_D
            conf_data['generator_loss'] = g_loss_func
            conf_data['discriminator_loss'] = d_loss_func
            if seq == 0:
                conf_data['noise'] = z

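            # Generator update: every batch when n_critic <= 0, otherwise only
            # every n_critic batches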
            if n_critic <= 0:
                conf_data = training_fucntion_generator(conf_data)
            elif n_critic > 0:
                # Train the generator every n_critic iterations
                if i % n_critic == 0:
                    conf_data = training_fucntion_generator(conf_data)

            if seq == 0:
                batches_done = epoch * len(dataloader) + i
                if batches_done % int(conf_data['sample_interval']) == 0:
                    if classes <= 0:
                        # print ("Here")
                        # print (type(gen_imgs.data[:25]))
                        # print (gen_imgs.data[:25].shape)
                        save_image(gen_imgs.data[:25],
                                   conf_data['result_path'] +
                                   '/%d.png' % batches_done,
                                   nrow=5,
                                   normalize=True)
                    elif classes > 0:
                        sample_image(10, batches_done, conf_data)
        if seq == 0:
            log_file.write("[Epoch %d/%d] [D loss: %f] [G loss: %f] \n" %
                           (epoch, epochs, conf_data['d_loss'].item(),
                            conf_data['g_loss'].item()))
        elif seq == 1:
            # print ("Done")
            log_file.write(
                "[Epoch %d/%d] [D loss: %f] [G loss: %f] \n" %
                (epoch, epochs, conf_data['d_loss'], conf_data['g_loss']))
    conf_data['generator_model'] = generator
    conf_data['discriminator_model'] = discriminator
    conf_data['log_file'] = log_file
    return conf_data


def training_fucntion_generator(conf_data):
    """Training Process for generator network.
    
    Parameters
    ----------
    conf_data: dict
        Dictionary containing all parameters and objects.       

    Returns
    -------
    conf_data: dict
        Dictionary containing all parameters and objects.       

    """
    seq = conf_data['GAN_model']['seq']
    GENERATED_NUM = 10000
    EVAL_FILE = 'eval.data'
    
    classes = int(conf_data['GAN_model']['classes'])
    w_loss = int(conf_data['GAN_model']['w_loss'])
    g_loss_func = conf_data['generator_loss']
    
    epoch = conf_data['epoch']
    epochs = conf_data['epochs']

    generator = conf_data['generator_model']
    discriminator = conf_data['discriminator_model']
    optimizer_G = conf_data['generator_optimizer']
    mini_batch_size = int(conf_data['GAN_model']['mini_batch_size'])

    optimizer_G.zero_grad()

    # Generate a batch of images
    if seq == 0:
        valid = conf_data['valid']
        gen_imgs = conf_data['gen_imgs']
        z = conf_data['noise']
        if classes <= 0:
            # Loss measures the generator's ability to fool the discriminator
            validity = discriminator(gen_imgs)
        elif classes > 0:
            gen_labels = conf_data['gen_labels']
            validity = discriminator(gen_imgs, gen_labels)

        if w_loss == 1:
            g_loss = -g_loss_func.loss(validity, valid)
        elif w_loss == 0:
            g_loss = g_loss_func.loss(validity, valid)
        conf_data['g_loss'] = g_loss
        g_loss.backward()
        optimizer_G.step()
    elif seq == 1:
        #print ("Reached Here 3 ---------> ")
        gen_gan_loss = GANLoss()
        rollout = conf_data['rollout']
        target_lstm = conf_data['target_lstm']
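        # Policy-gradient step: per-token rewards from the rollout weight the
        # generator's loss on its own sampled sequences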
        for it in range(1):
            samples = generator.sample(mini_batch_size, conf_data['generator']['sequece_length'])
            # Construct the input to the generator: prepend a zero start token
            # and drop the last column
            zeros = torch.zeros((mini_batch_size, 1)).type(torch.LongTensor)
            if samples.is_cuda:
                zeros = zeros.cuda()
            inputs = Variable(torch.cat([zeros, samples.data], dim=1)[:, :-1].contiguous())
            targets = Variable(samples.data).contiguous().view((-1,))
            # calculate the reward
            rewards = rollout.get_reward(samples, 16, discriminator)
            rewards = Variable(torch.Tensor(rewards))
            if conf_data['cuda']:
                rewards = rewards.cuda()
            # Exponentiate and flatten the rewards on both CPU and CUDA paths
            rewards = torch.exp(rewards).contiguous().view((-1,))
            prob = generator.forward(inputs)
            loss = gen_gan_loss(prob, targets, rewards)
            optimizer_G.zero_grad()
            loss.backward()
            optimizer_G.step()
        # TODO: restore the original conditional evaluation:
        # if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
        generate_samples(generator, mini_batch_size, GENERATED_NUM, EVAL_FILE, conf_data)
        eval_iter = GenDataIter(EVAL_FILE, mini_batch_size)
        loss = eval_epoch(target_lstm, eval_iter, g_loss_func, conf_data)
        conf_data['g_loss'] = loss
        rollout.update_params()

    if seq == 0:
        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
              (epoch, epochs, conf_data['iterator'], len(conf_data['data_learn']),
               conf_data['d_loss'].item(), g_loss.item()))
    elif seq == 1:
        print("[Epoch %d/%d] [Batch %d] [D loss: %f] [G loss: %f]" %
              (epoch, epochs, conf_data['iterator'],
               conf_data['d_loss'], conf_data['g_loss']))

    conf_data['generator_model'] = generator
    conf_data['generator_optimizer'] = optimizer_G

    conf_data['discriminator_model'] = discriminator
    conf_data['generator_loss'] = g_loss_func
    if seq == 1:
        conf_data['rollout'] = rollout
    return conf_data
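

# Minimal usage sketch (hypothetical): the real conf_data is assembled by the
# host project and also carries the models, optimizers, loss objects, paths,
# and data loader, keyed exactly as read above.
# conf_data = {
#     'GAN_model': {'seq': 0, 'epochs': 200, 'mini_batch_size': 64,
#                   'data_label': 0, 'classes': 0, 'w_loss': 0,
#                   'clip_value': 0.0, 'n_critic': 0, 'lambda_gp': 0},
#     'generator': {'latent_dim': 100},
#     'cuda': torch.cuda.is_available(),
#     'data_learn': dataloader,            # torch DataLoader over the dataset
#     'generator_model': generator,        # nn.Module
#     'discriminator_model': discriminator,
#     'generator_optimizer': optimizer_G,
#     'discriminator_optimizer': optimizer_D,
#     'generator_loss': g_loss,            # object exposing .loss(pred, target)
#     'discriminator_loss': d_loss,
#     'performance_log': 'logs', 'save_model_path': 'models',
#     'result_path': 'results', 'sample_interval': 400,
# }
# conf_data = train_GAN(conf_data)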