Example No. 1
 def train_epochs(self, train_dataset, dev_dataset):
     for epoch in range(self.args.epochs):  #0,...,epochs-1
         if not self.stop:
             print('\nTraining: Epoch ' + str(epoch + 1))
             self.train(train_dataset)
             #training epoch completes, now do validation
             print('\nValidating: Epoch ' + str(epoch + 1))
             metrics_sum = self.validate(dev_dataset)
             self.dev_metrics.append(metrics_sum)
             #save model if metrics sum goes up
             if self.best_dev_metrics < metrics_sum:
                 self.best_dev_metrics = metrics_sum
                 #save model
                 model_save_path = self.args.save_path + self.args.dataset + '/model/'
                 if not os.path.exists(model_save_path):
                     os.makedirs(model_save_path)
                 self.model.saver.save(self.sess,
                                       model_save_path + 'lostnet.ckpt')
                 self.times_no_improvement = 0
             else:
                 self.times_no_improvement += 1
                 #no improvement in validation metrics for last n iterations, so stop training
                 if self.times_no_improvement == self.args.early_stop:
                     self.stop = True
             #save the train loss plot
             helper.save_plot(self.train_losses,
                              self.args.save_path + self.args.dataset + '/',
                              'train_loss', epoch + 1)
             #save the dev metrics plot
             helper.save_plot(self.dev_metrics,
                              self.args.save_path + self.args.dataset + '/',
                              'dev_metrics', epoch + 1)
         else:
             break
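The early stopping in this example is just a counter that resets whenever the summed validation metrics improve and halts training after `early_stop` consecutive epochs without improvement. A minimal standalone sketch of that pattern; the `EarlyStopper` class and its names are hypothetical, not part of the original code:

class EarlyStopper:
    def __init__(self, patience=5):
        # patience: consecutive epochs without improvement to tolerate
        self.patience = patience
        self.best = float('-inf')
        self.times_no_improvement = 0

    def step(self, metric):
        # Returns True when training should stop.
        if metric > self.best:
            self.best = metric
            self.times_no_improvement = 0
        else:
            self.times_no_improvement += 1
        return self.times_no_improvement >= self.patience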
Example No. 2
 def train_epochs(self, train_corpus, dev_corpus, start_epoch, n_epochs):
     """Trains model for n_epochs epochs"""
     for epoch in range(start_epoch, start_epoch + n_epochs):
         if not self.stop:
             print('\nTRAINING : Epoch ' + str((epoch + 1)))
             self.train(train_corpus)
             # training epoch completes, now do validation
             print('\nVALIDATING : Epoch ' + str((epoch + 1)))
             dev_loss = self.validate(dev_corpus)
             self.dev_losses.append(dev_loss)
             print('validation loss = %.4f' % dev_loss)
             # save model if dev loss goes down
             if self.best_dev_loss == -1 or self.best_dev_loss > dev_loss:
                 self.best_dev_loss = dev_loss
                 helper.save_checkpoint(
                     {
                         'epoch': (epoch + 1),
                         'state_dict': self.model.state_dict(),
                         'best_loss': self.best_dev_loss,
                         'optimizer': self.optimizer.state_dict(),
                     }, self.config.save_path + 'model_best.pth.tar')
                 self.times_no_improvement = 0
             else:
                 self.times_no_improvement += 1
                 # no improvement in validation loss for last n iterations, so stop training
                 if self.times_no_improvement == 5:
                     self.stop = True
             # save the train and development loss plot
             helper.save_plot(self.train_losses, self.config.save_path,
                              'training', epoch + 1)
             helper.save_plot(self.dev_losses, self.config.save_path, 'dev',
                              epoch + 1)
         else:
             break
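`helper.save_checkpoint` is not shown in this listing; from the way it is called, it is presumably a thin wrapper around `torch.save` that writes the whole state dictionary in one call. A hedged sketch of such a wrapper (the body is an assumption, not the project's actual helper):

import torch

def save_checkpoint(state, filename):
    # Assumed behavior: serialize the dict of epoch, model weights,
    # best loss and optimizer state to a single file.
    torch.save(state, filename)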
Example No. 3
 def train_epochs(self, train_batches, dev_batches, start_epoch, n_epochs):
     """Trains model for n_epochs epochs"""
     for epoch in range(start_epoch, start_epoch + n_epochs):
         if not self.stop:
             self.train(train_batches, dev_batches, (epoch + 1))
             helper.save_plot(self.train_losses, self.config.save_path, 'training', epoch + 1)
             helper.save_plot(self.dev_losses, self.config.save_path, 'dev', epoch + 1)
         else:
             break
Example No. 4
 def train_epochs(self, train_corpus, dev_corpus, test_corpus, start_epoch,
                  n_epochs):
     """Trains model for n_epochs epochs"""
     for epoch in range(start_epoch, start_epoch + n_epochs):
         if not self.stop:
             print('\nTRAINING : Epoch ' + str((epoch + 1)))
             self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0]['lr'] * self.config.lr_decay \
                 if epoch > start_epoch and 'sgd' in self.config.optimizer else self.optimizer.param_groups[0]['lr']
             if 'sgd' in self.config.optimizer:
                 print('Learning rate : {0}'.format(
                     self.optimizer.param_groups[0]['lr']))
             self.train(train_corpus)
             # training epoch completes, now do validation
             print('\nVALIDATING : Epoch ' + str((epoch + 1)))
             dev_acc = self.validate(dev_corpus)
             self.dev_accuracies.append(dev_acc)
             print('validation acc = %.2f%%' % dev_acc)
             test_acc = self.validate(test_corpus)
             print('test acc = %.2f%%' % test_acc)
             # save model if dev accuracy goes up
             if self.best_dev_acc < dev_acc:
                 self.best_dev_acc = dev_acc
                 helper.save_checkpoint(
                     {
                         'epoch': (epoch + 1),
                         'state_dict': self.model.state_dict(),
                         'best_acc': self.best_dev_acc,
                         'optimizer': self.optimizer.state_dict(),
                     }, self.config.save_path + 'model_best.pth.tar')
                 self.times_no_improvement = 0
             else:
                 if 'sgd' in self.config.optimizer:
                     self.optimizer.param_groups[0][
                         'lr'] = self.optimizer.param_groups[0][
                             'lr'] / self.config.lrshrink
                     print('Shrinking lr by : {0}. New lr = {1}'.format(
                         self.config.lrshrink,
                         self.optimizer.param_groups[0]['lr']))
                     if self.optimizer.param_groups[0][
                             'lr'] < self.config.minlr:
                         self.stop = True
                 if 'adam' in self.config.optimizer:
                     self.times_no_improvement += 1
                     # early stopping (at 'n'th decrease in accuracy)
                     if self.times_no_improvement == self.config.early_stop:
                         self.stop = True
             # save the train loss and development accuracy plot
             helper.save_plot(self.train_accuracies, self.config.save_path,
                              'training_acc_plot_', epoch + 1)
             helper.save_plot(self.dev_accuracies, self.config.save_path,
                              'dev_acc_plot_', epoch + 1)
         else:
             break
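The manual shrink-on-plateau logic used for SGD above roughly mirrors what `torch.optim.lr_scheduler.ReduceLROnPlateau` provides. A hedged sketch of the built-in equivalent, assuming an existing `optimizer` and with illustrative values standing in for `config.lrshrink` and `config.minlr`:

import torch

# mode='max' because the monitored quantity is a validation accuracy.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=1.0 / 5, patience=0, min_lr=1e-5)

# Then, inside the epoch loop, after computing dev_acc:
#     scheduler.step(dev_acc)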
Example No. 5
def pgtrain(optims_gen,
            optims_dis,
            generator,
            agent,
            discriminator,
            bsize,
            embed_dim,
            trainSample,
            validSample,
            testSample,
            val_acc_best,
            val_preck_best,
            val_loss_best,
            action_num,
            max_length,
            recom_length,
            gen_ratio=0.1,
            n_epochs=5,
            write_item='click_gen.txt',
            write_target='tar_gen.txt',
            write_reward='reward_gen.txt',
            write_action='action_gen.txt',
            plot_fig=True,
            pretrain=False):
    outputdir = "model_output"
    outputmodelname = "simu.model.pth"
    lrshrink = 5
    minlr = 1e-5

    #Evaluation loss functions
    loss_fn_target = nn.CrossEntropyLoss()
    loss_fn_reward = nn.BCEWithLogitsLoss()
    loss_fn_target.size_average = True
    loss_fn_target.to(device)
    loss_fn_reward.size_average = True
    loss_fn_reward.to(device)

    inner_val_preck_best = val_preck_best
    inner_val_acc_best = val_acc_best
    inner_loss_best = val_loss_best
    epoch = 1
    eval_type = 'valid'
    g_step = 1
    d_step = 1
    evalacc_all = [val_acc_best]
    evalpreck_all = [val_preck_best]
    #Define the optimizer
    optim_fn_gen, optim_params_gen = get_optimizer(optims_gen)
    optim_fn_dis, optim_params_dis = get_optimizer(optims_dis)
    optimizer_dis = optim_fn_dis(
        filter(lambda p: p.requires_grad, discriminator.parameters()),
        **optim_params_dis)
    params_agent = list(agent.parameters())
    params_usr = list(generator.parameters())
    optimizer_agent = optim_fn_gen(
        filter(lambda p: p.requires_grad, params_agent), **optim_params_gen)
    optimizer_usr = optim_fn_gen(filter(lambda p: p.requires_grad, params_usr),
                                 **optim_params_gen)
    while epoch <= n_epochs:
        print('\nAdversarial Policy Gradient Training!')
        # Select subset of trainSample
        subnum = 8000
        for i in range(g_step):
            print('G-step')
            if pretrain:
                print('For Pretraining')
                _ = train_gen_pg_each(generator, agent,
                                      discriminator, epoch, trainSample,
                                      trainSample.length(), optimizer_agent,
                                      optimizer_usr, bsize, embed_dim,
                                      recom_length, max_length, action_num,
                                      device, 0, pretrain)
            else:
                print('For Policy Gradient Update')
                #shuffle_index=np.random.permutation(origin.length())
                _ = train_gen_pg_each(generator, agent, discriminator, epoch,
                                      trainSample, subnum, optimizer_agent,
                                      optimizer_usr, bsize, embed_dim,
                                      recom_length, max_length, action_num,
                                      device, 0.1, pretrain)

        # save model
        # Evaluate without eos, no eos input
        print("Agent evaluation!")
        eval_acc, eval_preck = evaluate_agent(agent,
                                              epoch,
                                              bsize,
                                              recom_length,
                                              validSample,
                                              testSample,
                                              device,
                                              eval_type='valid')
        print("User model evaluation!")
        _ = evaluate_user(generator, epoch, bsize, recom_length, validSample,
                          testSample, loss_fn_target, loss_fn_reward, device,
                          eval_type)
        print("Interaction evaluation!")
        _ = evaluate_interaction(
            (generator, agent), epoch, bsize, recom_length, validSample,
            testSample, loss_fn_target, loss_fn_reward, device, eval_type)

        evalacc_all.append(eval_acc)
        evalpreck_all.append(eval_preck)
        if eval_type == 'valid' and epoch <= n_epochs:
            print('saving model at epoch {0}'.format(epoch))
            if not os.path.exists(outputdir):
                os.makedirs(outputdir)
            torch.save(
                agent.state_dict(),
                os.path.join(outputdir, 'irecGan_agent3.' + outputmodelname))
            torch.save(
                generator.state_dict(),
                os.path.join(outputdir, 'irecGan_gen3.' + outputmodelname))

            inner_val_acc_best = eval_acc
            inner_val_preck_best = eval_preck

        if not pretrain:
            '''
            #Adjust the reward prediction
            print('Reward Adjust')
            trainSample_rewd, validSample_rewd, testSample_rewd=sampleSplit(trainindex, validindex, testindex, Seqlist, numlabel, recom_length)
            _ = train_user_pred(optims_dis, generator, bsize, embed_dim, recom_length + 1, trainSample_rewd, validSample_rewd, testSample_rewd, 'generator with rec', None, None, None, None, only_rewards = True, n_epochs=1)
            #Enable full model training
            for name, param in generator.named_parameters():
                if 'embedding' in name or 'encoder' or 'enc2out' in name:
                    param.requires_grad = True
            '''
            print('\nD-step')
            # Discriminator training
            for i in range(d_step):
                shutil.copy('click_gen_real.txt', write_item)
                shutil.copy('reward_gen_real.txt', write_reward)
                shutil.copy('tar_gen_real.txt', write_target)
                shutil.copy('action_gen_real.txt', write_action)
                _, _, _, _ = gen_fake(generator, agent, trainSample, bsize,
                                      embed_dim, device, write_item,
                                      write_target, write_reward, write_action,
                                      action_num, max_length, recom_length)
                clicklist, _ = ReadSeq(write_item, write_reward, write_action,
                                       write_target)
                trainindex_dis, validindex_dis, testindex_dis = split_index(
                    0.7, 0.1, len(clicklist), True)  #Shuffle the index
                trainSample_dis, validSample_dis, testSample_dis = sampleSplit(
                    trainindex_dis, validindex_dis, testindex_dis, clicklist,
                    2, recom_length, 'dis')

                discriminator, _, _ = train_dis(optims_dis, discriminator,
                                                bsize, embed_dim, recom_length,
                                                trainSample_dis,
                                                validSample_dis,
                                                testSample_dis)
        epoch += 1

    if plot_fig:
        save_plot(n_epochs, 1, evalacc_all, 'pg_accuracy6.png')
        save_plot(n_epochs, 1, evalpreck_all, 'pg_map6.png')
    return inner_val_acc_best, inner_val_preck_best
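`save_plot`, like `evaluate_agent` and `gen_fake`, is imported from elsewhere in the project and not shown here. A hedged, matplotlib-based sketch that would be consistent with how `save_plot(n_epochs, 1, evalacc_all, 'pg_accuracy6.png')` is called (the body is an assumption):

import matplotlib
matplotlib.use('Agg')  # assumes headless plotting during training
import matplotlib.pyplot as plt

def save_plot(n_epochs, step, points, filename):
    # One value per epoch (plus the initial value), written straight to disk.
    fig, ax = plt.subplots()
    ax.plot(range(0, len(points) * step, step), points)
    ax.set_xlabel('epoch')
    fig.savefig(filename)
    plt.close(fig)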
        print("Testing")
        agent.load_state_dict(torch.load(pretrained_agent))
        print("Agent evaluation!")
        _ = evaluate_agent(agent, 101, bsize, recom_length - 1, validSample,
                           testSample, device, 'test')
        print("User model evaluation!")
        #generator.load_state_dict(torch.load(pretrained_gen))
        #Evaluate without EOS
        generator.load_state_dict(torch.load(pretrained_gen))
        eval_acc, eval_preck, eval_rewd, eval_loss = evaluate_user(
            generator, 101, bsize, recom_length - 1, validSample, testSample,
            loss_fn_target, loss_fn_reward, device, 'test')
        #Save the whole policy model
        torch.save(agent, 'model_output/agent.pickle')

        if interact:
            #Generate new samples from the environment
            reward_orig, reward_optim = Eval('model_output/agent.pickle')
            if e == 0:
                rewards.append(reward_orig)
            rewards.append(reward_optim)
            '''  
            #Load the best model
            generator.load_state_dict(torch.load(pretrained_gen))
            discriminator.load_state_dict(torch.load(pretrained_dis))
            agent.load_state_dict(torch.load(pretrained_agent))
            '''
            #Generate new data
            subprocess.call(subprocess_cmd, shell=False)
    save_plot(Epochs, 1, rewards, 'all_rewards.png')
Example No. 7
    def train_epochs(self, train_corpus, dev_corpus, test_corpus, start_epoch,
                     n_epochs):
        """Trains model for n_epochs epochs"""
        for epoch in range(start_epoch, start_epoch + n_epochs):
            if not self.stop:
                print('\nTRAINING : Epoch ' + str((epoch + 1)))
                self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0]['lr'] * self.config.lr_decay \
                    if epoch > start_epoch and 'sgd' in self.config.optimizer else self.optimizer.param_groups[0]['lr']
                if 'sgd' in self.config.optimizer:
                    print('Learning rate : {0}'.format(
                        self.optimizer.param_groups[0]['lr']))
                try:
                    self.train(train_corpus, epoch + 1)
                except KeyboardInterrupt:
                    print('-' * 89)
                    print('Exiting from training early')
                # training epoch completes, now do validation
                print('\nVALIDATING : Epoch ' + str((epoch + 1)))
                dev_acc = -1
                try:
                    dev_acc = self.validate(dev_corpus)
                    self.dev_accuracies.append(dev_acc)
                    print('validation acc = %.2f%%' % dev_acc)
                except KeyboardInterrupt:
                    print('-' * 89)
                    print('Exiting from dev early')

                try:
                    test_acc = self.validate(test_corpus)
                    print('test acc = %.2f%%' % test_acc)
                except KeyboardInterrupt:
                    print('-' * 89)
                    print('Exiting from testing early')

                # save model if dev accuracy goes up
                if self.best_dev_acc < dev_acc and dev_acc != -1:
                    self.best_dev_acc = dev_acc
                    file_path = self.config.output_base_path + self.config.task + '/' + self.config.model_file_name
                    if not file_path.endswith('.pth.tar'):
                        file_path += 'model_best.pth.tar'

                    helper.save_checkpoint(
                        {
                            'epoch': (epoch + 1),
                            'state_dict': self.model.state_dict(),
                            'best_acc': self.best_dev_acc,
                            'optimizer': self.optimizer.state_dict()
                        }, file_path)
                    print('model saved as: ', file_path)
                    self.times_no_improvement = 0
                else:
                    if 'sgd' in self.config.optimizer:
                        self.optimizer.param_groups[0][
                            'lr'] = self.optimizer.param_groups[0][
                                'lr'] / self.config.lrshrink
                        print('Shrinking lr by : {0}. New lr = {1}'.format(
                            self.config.lrshrink,
                            self.optimizer.param_groups[0]['lr']))
                        if self.optimizer.param_groups[0][
                                'lr'] < self.config.minlr:
                            self.stop = True
                    if 'adam' in self.config.optimizer:
                        self.times_no_improvement += 1
                        # early stopping (at 'n'th decrease in accuracy)
                        if self.times_no_improvement == self.config.early_stop:
                            self.stop = True
                # save the train loss and development accuracy plot
                helper.save_plot(self.train_accuracies,
                                 self.config.output_base_path,
                                 'training_acc_plot_', epoch + 1)
                helper.save_plot(self.dev_accuracies,
                                 self.config.output_base_path, 'dev_acc_plot_',
                                 epoch + 1)
            else:
                break
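Because `train_epochs` accepts a `start_epoch`, the checkpoint written above can be used to resume training. A hedged sketch of the loading side; the `load_checkpoint` helper is hypothetical, and only the dictionary keys are taken from the checkpoint written in this example:

import torch

def load_checkpoint(model, optimizer, file_path, device='cpu'):
    # Reads back the fields written by helper.save_checkpoint above.
    checkpoint = torch.load(file_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], checkpoint['best_acc']

# start_epoch, best_acc = load_checkpoint(model, optimizer, file_path)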
Example No. 8
 def train_epochs(self, start_epoch, n_epochs):
     """Trains model for n_epochs epochs"""
     for epoch in range(start_epoch, start_epoch + n_epochs):
         if not self.stop:
             print('\nTRAINING : Epoch ' + str((epoch + 1)))
             self.optimizerG.param_groups[0]['lr'] = self.optimizerG.param_groups[0]['lr'] * self.config.lr_decay \
                 if epoch > start_epoch and 'sgd' in self.config.optimizer else self.optimizerG.param_groups[0]['lr']
             if self.config.adversarial:
                 self.optimizerD.param_groups[0]['lr'] = self.optimizerD.param_groups[0]['lr'] * self.config.lr_decay \
                     if (epoch + 1) > 1 and 'sgd' in self.config.optimizer else self.optimizerD.param_groups[0]['lr']
             if 'sgd' in self.config.optimizer:
                 print('Learning rate : {0}'.format(
                     self.optimizerG.param_groups[0]['lr']))
             self.train()
             # training epoch completes, now do validation
             print('\nVALIDATING : Epoch ' + str((epoch + 1)))
             dev_acc = self.validate()
             self.dev_accuracies.append(dev_acc)
             print('validation acc = %.2f' % dev_acc)
             # save model if dev loss goes down
             if self.best_dev_acc < dev_acc:
                 self.best_dev_acc = dev_acc
                 check_point = dict()
                 check_point['epoch'] = (epoch + 1)
                 check_point['state_dict_G'] = self.generator.state_dict()
                 check_point['best_acc'] = self.best_dev_acc
                 check_point['optimizerG'] = self.optimizerG.state_dict()
                 if self.config.adversarial:
                     check_point[
                         'state_dict_D'] = self.discriminator.state_dict()
                     check_point['optimizerD'] = self.optimizerD.state_dict(
                     )
                 helper.save_checkpoint(
                     check_point,
                     self.config.save_path + 'model_best.pth.tar')
                 self.times_no_improvement = 0
             else:
                 if 'sgd' in self.config.optimizer:
                     self.optimizerG.param_groups[0][
                         'lr'] = self.optimizerG.param_groups[0][
                             'lr'] / self.config.lrshrink
                     if self.config.adversarial:
                         self.optimizerD.param_groups[0]['lr'] = self.optimizerG.param_groups[0]['lr'] / \
                                                                 self.config.lrshrink
                     print('Shrinking lr by : {0}. New lr = {1}'.format(
                         self.config.lrshrink,
                         self.optimizerG.param_groups[0]['lr']))
                     if self.optimizerG.param_groups[0][
                             'lr'] < self.config.minlr:
                         self.stop = True
                 if 'adam' in self.config.optimizer:
                     self.times_no_improvement += 1
                     # early stopping (at 3rd decrease in accuracy)
                     if self.times_no_improvement == 3:
                         self.stop = True
             # save the train and development loss plot
             helper.save_plot(self.train_accuracies, self.config.save_path,
                              'training_acc_plot_', epoch + 1)
             helper.save_plot(self.dev_accuracies, self.config.save_path,
                              'dev_acc_plot_', epoch + 1)
         else:
             break
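The checkpoint in this example stores discriminator entries only when `config.adversarial` is set. A hedged sketch of restoring such a checkpoint; the `load_adversarial_checkpoint` helper is hypothetical, and only the dictionary keys are taken from the code above:

import torch

def load_adversarial_checkpoint(generator, optimizerG, file_path,
                                discriminator=None, optimizerD=None):
    check_point = torch.load(file_path, map_location='cpu')
    generator.load_state_dict(check_point['state_dict_G'])
    optimizerG.load_state_dict(check_point['optimizerG'])
    # Discriminator entries exist only for adversarial training runs.
    if discriminator is not None and 'state_dict_D' in check_point:
        discriminator.load_state_dict(check_point['state_dict_D'])
        optimizerD.load_state_dict(check_point['optimizerD'])
    return check_point['epoch'], check_point['best_acc']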
Example No. 9
def train(model, epoch_count, batch_size, z_dim, star_learning_rate, beta1, beta2, get_batches, data_shape,
          image_mode):
    input_real, input_z, lrate, k_t = model.model_inputs(
        *(data_shape[1:]), z_dim)

    d_loss, g_loss, d_real, d_fake = model.model_loss(
        input_real, input_z, data_shape[3], z_dim, k_t)

    d_opt, g_opt = model.model_opt(d_loss, g_loss, lrate, beta1, beta2)

    losses = []
    learning_rate = 0
    iter = 0

    epoch_drop = 3

    lam = 1e-3
    gamma = 0.5
    k_curr = 0.0

    test_z = np.random.uniform(-1, 1, size=(16, z_dim))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch_i in range(epoch_count):

            learning_rate = star_learning_rate * \
                math.pow(0.2, math.floor((epoch_i + 1) / epoch_drop))

            for batch_images in get_batches(batch_size):
                iter += 1

                batch_images *= 2

                batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

                _, d_real_curr = sess.run([d_opt, d_real], feed_dict={
                                          input_z: batch_z, input_real: batch_images, lrate: learning_rate, k_t: k_curr})

                _, d_fake_curr = sess.run([g_opt, d_fake], feed_dict={
                                          input_z: batch_z, input_real: batch_images, lrate: learning_rate, k_t: k_curr})

                k_curr = k_curr + lam * (gamma * d_real_curr - d_fake_curr)

                # save convergence measure
                if iter % 100 == 0:
                    measure = d_real_curr + \
                        np.abs(gamma * d_real_curr - d_fake_curr)
                    losses.append(measure)

                    print("Epoch {}/{}...".format(epoch_i + 1, epoch_count),
                          'Convergence measure: {:.4}'.format(measure))

                # save test and batch images
                if iter % 700 == 0:
                    helper.show_generator_output(
                        sess, model.generator, input_z, batch_z, data_shape[3], image_mode, 'batch-' + str(iter))

                    helper.show_generator_output(
                        sess, model.generator, input_z, test_z, data_shape[3], image_mode, 'test-' + str(iter))

        print('Training steps: ', iter)

        losses = np.array(losses)

        helper.save_plot([losses, helper.smooth(losses)],
                         'convergence_measure.png')
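`helper.smooth` is not defined in this listing; a simple moving average would be consistent with how it is used here to plot the convergence measure next to its smoothed version (the window size and the body are assumptions):

import numpy as np

def smooth(values, window=11):
    # Hypothetical stand-in for helper.smooth: centered moving average.
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='same')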