testSample, device, 'test')

        #Generate fake sequences (using only the training data)
        generator.load_state_dict(torch.load(pretrained_gen))
        agent.load_state_dict(torch.load(pretrained_agent))
        trainSample, validSample, testSample = sampleSplit(
            trainindex, validindex, testindex, Seqlist, numlabel,
            recom_length - 1, 'gen')
        print('Generate sample : {0}'.format(trainSample.length()))

        shutil.copy('click_gen_real.txt', write_item)
        shutil.copy('tar_gen_real.txt', write_target)
        shutil.copy('reward_gen_real.txt', write_reward)
        shutil.copy('action_gen_real.txt', write_action)
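        # The real sequences are copied into the working files first; gen_fake then
        # presumably appends generated sequences so the files hold a real/fake mix
        # for later discriminator training.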
        _ = gen_fake(generator, agent, trainSample, bsize, embed_dim, device,
                     write_item, write_target, write_reward, write_action,
                     numlabel, max_length, recom_length - 1)  #No EOS

        #Pretrain discriminator
        print('\n--------------------------------------------')
        print("Pretrain the Discriminator")
        print('--------------------------------------------')
        dis_clicklist, _ = ReadSeq(write_item, write_reward, write_action,
                                   write_target)
        trainindex_dis, validindex_dis, testindex_dis = split_index(
            0.8, 0.1, len(dis_clicklist), True)
        trainSample, validSample, testSample = sampleSplit(
            trainindex_dis, validindex_dis, testindex_dis, dis_clicklist, 2,
            recom_length - 1, 'dis')
        print('Train sample : {0}'.format(trainSample.length()))
        print('Valid sample : {0}'.format(validSample.length()))
def pgtrain(optims_gen,
            optims_dis,
            generator,
            agent,
            discriminator,
            bsize,
            embed_dim,
            trainSample,
            validSample,
            testSample,
            val_acc_best,
            val_preck_best,
            val_loss_best,
            action_num,
            max_length,
            recom_length,
            gen_ratio=0.1,
            n_epochs=5,
            write_item='click_gen.txt',
            write_target='tar_gen.txt',
            write_reward='reward_gen.txt',
            write_action='action_gen.txt',
            plot_fig=True,
            pretrain=False):
    outputdir = "model_output"
    outputmodelname = "simu.model.pth"
    lrshrink = 5
    minlr = 1e-5

    #Evaluation loss functions (both use the default mean reduction)
    loss_fn_target = nn.CrossEntropyLoss()
    loss_fn_reward = nn.BCEWithLogitsLoss()
    loss_fn_target.to(device)
    loss_fn_reward.to(device)
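    # Expected interfaces (standard PyTorch; the shapes are an assumption about how
    # the evaluation code feeds these losses):
    #   loss_fn_target(logits, targets): logits (batch, n_items), long targets (batch,)
    #   loss_fn_reward(logits, labels):  raw logits with float 0/1 labels of the same shape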

    inner_val_preck_best = val_preck_best
    inner_val_acc_best = val_acc_best
    inner_loss_best = val_loss_best
    epoch = 1
    eval_type = 'valid'
    g_step = 1
    d_step = 1
    evalacc_all = [val_acc_best]
    evalpreck_all = [val_preck_best]
    #Define the optimizer
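    # get_optimizer presumably maps a spec string (an assumed format such as
    # 'adam,lr=0.001') to an optimizer class plus its keyword arguments; the agent
    # and the user model each get their own optimizer built from the generator spec.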
    optim_fn_gen, optim_params_gen = get_optimizer(optims_gen)
    optim_fn_dis, optim_params_dis = get_optimizer(optims_dis)
    optimizer_dis = optim_fn_dis(
        filter(lambda p: p.requires_grad, discriminator.parameters()),
        **optim_params_dis)
    params_agent = list(agent.parameters())
    params_usr = list(generator.parameters())
    optimizer_agent = optim_fn_gen(
        filter(lambda p: p.requires_grad, params_agent), **optim_params_gen)
    optimizer_usr = optim_fn_gen(filter(lambda p: p.requires_grad, params_usr),
                                 **optim_params_gen)
    while epoch <= n_epochs:
        print('\nAdversarial Policy Gradient Training!')
        # Select subset of trainSample
        subnum = 8000
        for i in range(g_step):
            print('G-step')
            if pretrain:
                print('For Pretraining')
                _ = train_gen_pg_each(generator, agent,
                                      discriminator, epoch, trainSample,
                                      trainSample.length(), optimizer_agent,
                                      optimizer_usr, bsize, embed_dim,
                                      recom_length, max_length, action_num,
                                      device, 0, pretrain)
            else:
                print('For Policy Gradient Update')
                #shuffle_index=np.random.permutation(origin.length())
                _ = train_gen_pg_each(generator, agent, discriminator, epoch,
                                      trainSample, subnum, optimizer_agent,
                                      optimizer_usr, bsize, embed_dim,
                                      recom_length, max_length, action_num,
                                      device, gen_ratio, pretrain)

        # Evaluate without an EOS token in the input (models are saved below)
        print("Agent evaluation!")
        eval_acc, eval_preck = evaluate_agent(agent,
                                              epoch,
                                              bsize,
                                              recom_length,
                                              validSample,
                                              testSample,
                                              device,
                                              eval_type='valid')
        print("User model evaluation!")
        _ = evaluate_user(generator, epoch, bsize, recom_length, validSample,
                          testSample, loss_fn_target, loss_fn_reward, device,
                          eval_type)
        print("Interaction evaluation!")
        _ = evaluate_interaction(
            (generator, agent), epoch, bsize, recom_length, validSample,
            testSample, loss_fn_target, loss_fn_reward, device, eval_type)

        evalacc_all.append(eval_acc)
        evalpreck_all.append(eval_preck)
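        # Checkpoint the agent and generator after each validation pass; the files
        # are overwritten every epoch and the "best" metrics are simply set to the
        # current epoch's values (no improvement check).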
        if eval_type == 'valid' and epoch <= n_epochs:
            print('saving model at epoch {0}'.format(epoch))
            if not os.path.exists(outputdir):
                os.makedirs(outputdir)
            torch.save(
                agent.state_dict(),
                os.path.join(outputdir, 'irecGan_agent3.' + outputmodelname))
            torch.save(
                generator.state_dict(),
                os.path.join(outputdir, 'irecGan_gen3.' + outputmodelname))

            inner_val_acc_best = eval_acc
            inner_val_preck_best = eval_preck

        if not pretrain:
            '''
            #Adjust the reward prediction
            print('Reward Adjust')
            trainSample_rewd, validSample_rewd, testSample_rewd=sampleSplit(trainindex, validindex, testindex, Seqlist, numlabel, recom_length)
            _ = train_user_pred(optims_dis, generator, bsize, embed_dim, recom_length + 1, trainSample_rewd, validSample_rewd, testSample_rewd, 'generator with rec', None, None, None, None, only_rewards = True, n_epochs=1)
            #Enable full model training
            for name, param in generator.named_parameters():
                if 'embedding' in name or 'encoder' in name or 'enc2out' in name:
                    param.requires_grad = True
            '''
            print('\nD-step')
            #Discriminator training
            for i in range(d_step):
                shutil.copy('click_gen_real.txt', write_item)
                shutil.copy('reward_gen_real.txt', write_reward)
                shutil.copy('tar_gen_real.txt', write_target)
                shutil.copy('action_gen_real.txt', write_action)
                _, _, _, _ = gen_fake(generator, agent, trainSample, bsize,
                                      embed_dim, device, write_item,
                                      write_target, write_reward, write_action,
                                      action_num, max_length, recom_length)
                clicklist, _ = ReadSeq(write_item, write_reward, write_action,
                                       write_target)
                trainindex_dis, validindex_dis, testindex_dis = split_index(
                    0.7, 0.1, len(clicklist), True)  #Shuffle the index
                trainSample_dis, validSample_dis, testSample_dis = sampleSplit(
                    trainindex_dis, validindex_dis, testindex_dis, clicklist,
                    2, recom_length, 'dis')

                discriminator, _, _ = train_dis(optims_dis, discriminator,
                                                bsize, embed_dim, recom_length,
                                                trainSample_dis,
                                                validSample_dis,
                                                testSample_dis)
        epoch += 1

    if plot_fig:
        save_plot(n_epochs, 1, evalacc_all, 'pg_accuracy6.png')
        save_plot(n_epochs, 1, evalpreck_all, 'pg_map6.png')
    return inner_val_acc_best, inner_val_preck_best
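
# A minimal usage sketch, assuming the models, Sample splits, and hyperparameters
# were already built as in the pretraining code above, and that get_optimizer
# accepts spec strings such as 'adam,lr=0.001' (an assumed format); every value
# shown is illustrative only.
#
# best_acc, best_preck = pgtrain(
#     'adam,lr=0.001', 'adam,lr=0.001',
#     generator, agent, discriminator,
#     bsize, embed_dim,
#     trainSample, validSample, testSample,
#     val_acc_best, val_preck_best, val_loss_best,
#     numlabel, max_length, recom_length,
#     n_epochs=5, pretrain=False)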
Code Example #3
import random
import re

# setup stuff
# (h, stem, update_label and flatten are assumed to be helpers defined or
#  imported elsewhere in this script)
n = 0
l = []
outfile = "../data/fake_data.csv"

# basic string
for i in range(0, 10000):
    # string to list
    words = h.words()

    # create a default "word" label for each token
    label = ["word"] * len(words)

    # generate a fake
    y = h.gen_fake()
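    # y appears to be a dict carrying at least 'value' (the fake text) and
    # 'label' (its entity type), judging from how it is used below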
    # cleaning punct
    # may look into adding a new label for punct
    y['value'] = stem(y['value'])

    y_val = re.split("\\s+", y['value'])
    newlabel = [y['label']] * len(y_val)

    newlabel = update_label(y_val, newlabel)

    # insert the fake tokens into the word list at a random position
    idx = random.choice(range(0, len(words)))
    words[idx] = y_val
    label[idx] = newlabel
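    # words[idx] is now itself a list of tokens, hence the flatten below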

    words = flatten(words)