Example 1
import argparse
import os
import pickle
import random

import torch
import torch.nn as nn

from pn import PolicyNetwork
# NOTE: cuda_() and train() are project-local helpers assumed to be importable alongside this script.


def main():
    parser = argparse.ArgumentParser(description="Pretrain Policy Network")
    parser.add_argument('-inputdim', type=int, dest='inputdim', help='input dimension', default=29)
    parser.add_argument('-hiddendim', type=int, dest='hiddendim', help='hidden dimension', default=64)
    parser.add_argument('-outputdim', type=int, dest='outputdim', help='output dimension', default=12)
    parser.add_argument('-bs', type=int, dest='bs', help='batch size', default=16)
    parser.add_argument('-optim', type=str, dest='optim', help='optimizer choice', default='Adam')
    parser.add_argument('-lr', type=float, dest='lr', help='learning rate', default=0.001)
    parser.add_argument('-decay', type=float, dest='decay', help='weight decay', default=0)
    parser.add_argument('-mod', type=str, dest='mod', help='mod', default='ours')  # options: ours, ear, crm

    A = parser.parse_args()
    print('Arguments loaded!')

    PN = PolicyNetwork(input_dim=A.inputdim, dim1=A.hiddendim, output_dim=A.outputdim)

    cuda_(PN)
    print('Model on GPU')
    data_list = list()

    data_dir = '../data/pretrain-numpy-data-{}'.format(A.mod)

    files = os.listdir(data_dir)
    file_paths = [data_dir + '/' + f for f in files]

    loaded_files = 0
    for fp in file_paths:
        with open(fp, 'rb') as f:
            try:
                data_list += pickle.load(f)
                loaded_files += 1
            except Exception:
                # skip files that cannot be unpickled
                pass
    print('total files: {}'.format(loaded_files))
    # no-op slice (keeps the full data list); lower the bound here to subsample
    data_list = data_list[: int(len(data_list))]
    print('length of data list is: {}'.format(len(data_list)))

    random.shuffle(data_list)

    train_list = data_list[: int(len(data_list) * 0.7)]
    valid_list = data_list[int(len(data_list) * 0.7): int(len(data_list) * 0.9)]
    test_list = data_list[int(len(data_list) * 0.9):]
    print('train length: {}, valid length: {}, test length: {}'.format(len(train_list), len(valid_list), len(test_list)))

    if A.optim == 'Ada':
        optimizer = torch.optim.Adagrad(PN.parameters(), lr=A.lr, weight_decay=A.decay)
    elif A.optim == 'Adam':
        optimizer = torch.optim.Adam(PN.parameters(), lr=A.lr, weight_decay=A.decay)
    else:
        raise ValueError('unsupported optimizer choice: {}'.format(A.optim))
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):
        random.shuffle(train_list)
        model_name = '../data/PN-model-{}/pretrain-model.pt'.format(A.mod)
        train(A.bs, train_list, valid_list, test_list, optimizer, PN, criterion, epoch, model_name)
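
Example 1 relies on a `cuda_` helper (and a `train` routine) that are not shown on this page. A minimal sketch of what the `cuda_` helper could look like, assuming it simply moves a module or tensor onto the GPU when one is available; the project's real helper may behave differently:

import torch


def cuda_(obj):
    # Assumed behaviour: move a tensor or nn.Module to the GPU if CUDA is available.
    return obj.cuda() if torch.cuda.is_available() else obj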
Example 2
    def __init__(self, input_dim, output_dim, dim1, actor_lr, critic_lr,
                 discount_rate, actor_w_decay, critic_w_decay):
        '''
        params:
            input_dim: size of the state vector fed to the networks
            output_dim: number of actions
            dim1: hidden-layer size of each PolicyNetwork
            actor_lr / critic_lr: learning rates for the actor and critic optimizers
            discount_rate: discount factor used in the SAC losses
            actor_w_decay / critic_w_decay: weight decay for the actor and critic optimizers
        '''
        super(SAC_Net, self).__init__()
        self.actor_network = PolicyNetwork(input_dim, dim1, output_dim)
        # Create 2 critic networks to debias the value estimation
        self.critic_network = PolicyNetwork(input_dim, dim1, output_dim)
        self.critic2_network = PolicyNetwork(input_dim, dim1, output_dim)
        self.actor_optimizer = torch.optim.Adam(
            self.actor_network.parameters(),
            lr=actor_lr,
            weight_decay=actor_w_decay)
        self.critic_optimizer = torch.optim.Adam(
            self.critic_network.parameters(),
            lr=critic_lr,
            weight_decay=critic_w_decay)
        self.critic2_optimizer = torch.optim.Adam(
            self.critic2_network.parameters(),
            lr=critic_lr,
            weight_decay=critic_w_decay)
        # Create 2 target networks to stabilize training
        self.critic1_target = PolicyNetwork(input_dim, dim1, output_dim)
        self.critic2_target = PolicyNetwork(input_dim, dim1, output_dim)
        copy_model(self.critic_network, self.critic1_target)
        copy_model(self.critic2_network, self.critic2_target)
        # Define discount_rate
        self.discount_rate = discount_rate
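
The constructor above calls a `copy_model` helper that is not defined in this snippet. A minimal sketch under the assumption that it just hard-copies the source network's parameters into the target network (the project's own helper may differ):

def copy_model(source, target):
    # Assumed behaviour: hard-copy parameters from `source` into `target`.
    target.load_state_dict(source.state_dict())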
Example 3
# -*- coding: utf-8 -*-
import numpy as np
import torch

from pn import PolicyNetwork


def weights_init_normal(m):
    """Initialize Linear layers with N(0, 1/sqrt(fan_in)) weights and zero biases."""
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        y = m.in_features
        m.weight.data.normal_(0.0, 1 / np.sqrt(y))
        m.bias.data.fill_(0)


# Build a randomly initialized policy network and save the same state dict twice:
# once as the initial model file and once as the pretrain checkpoint.
PN_model = PolicyNetwork(input_dim=29, dim1=64, output_dim=12)
PN_model.apply(weights_init_normal)
torch.save(PN_model.state_dict(), '../PN-model-ours/PN-model-ours.txt')
torch.save(PN_model.state_dict(), '../PN-model-ours/pretrain-model.pt')
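
`PolicyNetwork` is imported from the project's `pn` module, which is not shown on this page. Judging from the constructor arguments (`input_dim`, `dim1`, `output_dim`) and the `weights_init_normal` hook above, a plausible minimal sketch is a small two-layer MLP; the actual `pn.PolicyNetwork` may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F


class PolicyNetwork(nn.Module):
    # Hypothetical stand-in matching the constructor used throughout these examples.
    def __init__(self, input_dim, dim1, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, dim1)
        self.fc2 = nn.Linear(dim1, output_dim)

    def forward(self, x):
        # Return raw scores; callers apply softmax or CrossEntropyLoss as needed.
        return self.fc2(F.relu(self.fc1(x)))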
Example 4
def main(epoch):
    parser = argparse.ArgumentParser(
        description="Run conversational recommendation.")
    parser.add_argument('-mt', type=int, dest='mt', help='MAX_TURN', default=5)
    parser.add_argument('-playby',
                        type=str,
                        dest='playby',
                        help='playby',
                        default='policy')
    # options include:
    # AO: (Ask Only and recommend by probability)
    # RO: (Recommend Only)
    # policy: (action decided by our policy network)
    parser.add_argument('-fmCommand',
                        type=str,
                        dest='fmCommand',
                        help='fmCommand',
                        default='8')  # default as a string to match type=str
    # the command used for FM, check out /EAR/lastfm/FM/
    parser.add_argument('-optim',
                        type=str,
                        dest='optim',
                        help='optimizer',
                        default='SGD')
    # the optimizer for policy network
    parser.add_argument('-lr', type=float, dest='lr', help='lr', default=0.001)
    # learning rate of policy network
    parser.add_argument('-decay',
                        type=float,
                        dest='decay',
                        help='decay',
                        default=0)
    # weight decay
    parser.add_argument('-TopKTaxo',
                        type=int,
                        dest='TopKTaxo',
                        help='TopKTaxo',
                        default=3)
    # how many second-layer features represent a big feature. Only the Yelp dataset uses this param; it has no effect on lastFM.
    parser.add_argument('-gamma',
                        type=float,
                        dest='gamma',
                        help='gamma',
                        default=0.7)
    # gamma of training policy network
    parser.add_argument('-trick',
                        type=int,
                        dest='trick',
                        help='trick',
                        default=0)
    # whether to use normalization when training the policy network
    parser.add_argument('-startFrom',
                        type=int,
                        dest='startFrom',
                        help='startFrom',
                        default=0)  # 85817
    # startFrom which user-item interaction pair
    parser.add_argument('-endAt',
                        type=int,
                        dest='endAt',
                        help='endAt',
                        default=20000)
    # endAt which user-item interaction pair
    parser.add_argument('-strategy',
                        type=str,
                        dest='strategy',
                        help='strategy',
                        default='maxent')
    # strategy used to choose which question to ask (only takes effect in certain settings)
    parser.add_argument('-eval', type=int, dest='eval', help='eval', default=0)
    # whether current run is for evaluation
    parser.add_argument('-mini', type=int, dest='mini', help='mini', default=0)
    # whether to `mini`-batch update the FM
    parser.add_argument('-alwaysupdate',
                        type=int,
                        dest='alwaysupdate',
                        help='alwaysupdate',
                        default=0)
    # whether to always mini-batch update the FM; the alternative is to update only once per session.
    # we leave this exploration to followers of our work.
    parser.add_argument('-initeval',
                        type=int,
                        dest='initeval',
                        help='initeval',
                        default=0)
    # whether to evaluate the `init`ial version of the policy network (directly after pre-training)
    parser.add_argument('-upoptim',
                        type=str,
                        dest='upoptim',
                        help='upoptim',
                        default='SGD')
    # optimizer for the reflection stage
    parser.add_argument('-upcount',
                        type=int,
                        dest='upcount',
                        help='upcount',
                        default=0)
    # how many times to do reflection
    parser.add_argument('-upreg',
                        type=float,
                        dest='upreg',
                        help='upreg',
                        default=0.001)
    # regularization weight used in the reflection (update) stage
    parser.add_argument('-code',
                        type=str,
                        dest='code',
                        help='code',
                        default='stable')
    # We use it to give each run a unique identifier.
    parser.add_argument('-purpose',
                        type=str,
                        dest='purpose',
                        help='purpose',
                        default='train')
    # options: pretrain, others
    parser.add_argument('-mod',
                        type=str,
                        dest='mod',
                        help='mod',
                        default='ear')
    # options: CRM, EAR
    parser.add_argument('-mask', type=int, dest='mask', help='mask', default=0)
    # used for the ablation study; 1, 2, 3, 4 represent our four segments {ent, sim, his, len}

    A = parser.parse_args()

    cfg.change_param(playby=A.playby,
                     eval=A.eval,
                     update_count=A.upcount,
                     update_reg=A.upreg,
                     purpose=A.purpose,
                     mod=A.mod,
                     mask=A.mask)

    random.seed(1)

    # we randomly shuffle and split the combined valid and test sets, used for Action-Stage training and evaluation respectively, to avoid bias in the dataset.
    all_list = cfg.valid_list + cfg.test_list
    print('The length of all list is: {}'.format(len(all_list)))
    random.shuffle(all_list)
    the_valid_list = all_list[:int(len(all_list) / 2.0)]
    the_test_list = all_list[int(len(all_list) / 2.0):]
    # the_valid_list = cfg.valid_list
    # the_test_list = cfg.test_list

    gamma = A.gamma
    FM_model = cfg.FM_model

    if A.mod == 'ear':
        # fp = '../../data/PN-model-crm/PN-model-crm.txt'
        if epoch == 0 and cfg.eval == 0:
            fp = '../../data/PN-model-ear/pretrain-model.pt'
        else:
            fp = '../../data/PN-model-ear/model-epoch0'
    INPUT_DIM = 0
    if A.mod == 'ear':
        INPUT_DIM = len(cfg.tag_map) * 2 + cfg.MAX_TURN + 8
    if A.mod == 'crm':
        INPUT_DIM = 4382
    PN_model = PolicyNetwork(input_dim=INPUT_DIM,
                             dim1=1500,
                             output_dim=len(cfg.tag_map) + 1)
    start = time.time()

    try:
        print('fp is: {}'.format(fp))
        PN_model.load_state_dict(torch.load(fp))
        print('load PN model success.')
    except Exception:
        print('Cannot load the model!!!!!!!!!\nfp is: {}'.format(fp))
        # if A.playby == 'policy':
        #     sys.exit()

    if A.optim == 'Adam':
        optimizer = torch.optim.Adam(PN_model.parameters(),
                                     lr=A.lr,
                                     weight_decay=A.decay)
    if A.optim == 'SGD':
        optimizer = torch.optim.SGD(PN_model.parameters(),
                                    lr=A.lr,
                                    weight_decay=A.decay)
    if A.optim == 'RMS':
        optimizer = torch.optim.RMSprop(PN_model.parameters(),
                                        lr=A.lr,
                                        weight_decay=A.decay)

    numpy_list = list()
    NUMPY_COUNT = 0

    sample_dict = defaultdict(list)
    conversation_length_list = list()
    # endAt = len(the_valid_list) if cfg.eval == 0 else len(the_test_list)
    endAt = len(the_valid_list)
    print(f'endAt: {endAt}')
    print('-' * 10)
    print('Train mode' if cfg.eval == 0 else 'Test mode')
    print('-' * 10)
    for epi_count in range(A.startFrom, endAt):
        if epi_count % 1 == 0:  # always true: log progress every episode
            print('-----\nEpoch: {}\tIt has processed {}/{} episodes'.format(
                epoch, epi_count, endAt))
        start = time.time()
        cfg.actionProb = epi_count / endAt

        # if A.test == 1 or A.eval == 1:
        if A.eval == 1:
            # u, item = the_test_list[epi_count]
            u, item = the_valid_list[epi_count]
        else:
            u, item = the_valid_list[epi_count]

        if A.purpose == 'fmdata':
            u, item = 0, epi_count

        if A.purpose == 'pretrain':
            u, item = cfg.train_list[epi_count]

        current_FM_model = copy.deepcopy(FM_model)
        param1, param2 = list(), list()
        param3 = list()
        param4 = list()
        i = 0
        for name, param in current_FM_model.named_parameters():
            param4.append(param)
            # print(name, param)
            if i == 0:
                param1.append(param)
            else:
                param2.append(param)
            if i == 2:
                param3.append(param)
            i += 1
        optimizer1_fm = torch.optim.Adagrad(param1,
                                            lr=0.01,
                                            weight_decay=A.decay)
        optimizer2_fm = torch.optim.SGD(param4, lr=0.001, weight_decay=A.decay)

        user_id = int(u)
        item_id = int(item)

        write_fp = '../../data/interaction-log/{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
            A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma,
            A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval, A.initeval,
            A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask)

        choose_pool = cfg.item_dict[str(item_id)]['categories']

        if A.purpose not in ['pretrain', 'fmdata']:
            # when not collecting data for pretraining or FM training,
            # we only randomly choose one start attribute to ask about
            choose_pool = [random.choice(choose_pool)]
        # for item_id in cfg.item_dict:
        #     choose_pool = [k for k in cfg.item_dict_rel[str(item_id)] if len(cfg.item_dict_rel[str(item_id)][k]) != 0]
        #     if choose_pool == None:
        #         print(item_id)
        print(f'user id: {user_id}\titem id: {item_id}')
        # choose_pool = [k for k in cfg.item_dict_rel[str(item_id)] if len(cfg.item_dict_rel[str(item_id)][k]) != 0]
        # choose_pool = random.choice(choose_pool)
        for c in choose_pool:
            with open(write_fp, 'a+') as f:
                f.write(
                    'Starting new\nuser ID: {}, item ID: {} episode count: {}, feature: {}\n'
                    .format(user_id, item_id, epi_count,
                            cfg.item_dict[str(item_id)]['categories']))
            start_facet = c
            if A.purpose != 'pretrain':
                log_prob_list, rewards, hl = run_one_episode(
                    current_FM_model, user_id, item_id, A.mt, False, write_fp,
                    A.strategy, A.TopKTaxo, PN_model, gamma, A.trick, A.mini,
                    optimizer1_fm, optimizer2_fm, A.alwaysupdate, start_facet,
                    A.mask, sample_dict)
                if cfg.eval == 0:
                    with open(f'../../data/train_rec{epoch}.txt', 'a+') as f:
                        # f.writelines(str(rewards.tolist()))
                        f.writelines(str(hl))
                        f.writelines('\n')
                else:
                    with open('../../data/test_rec.txt', 'a+') as f:
                        # f.writelines(str(rewards.tolist()))
                        f.writelines(str(hl))
                        f.writelines('\n')
            else:
                current_np = run_one_episode(
                    current_FM_model, user_id, item_id, A.mt, False, write_fp,
                    A.strategy, A.TopKTaxo, PN_model, gamma, A.trick, A.mini,
                    optimizer1_fm, optimizer2_fm, A.alwaysupdate, start_facet,
                    A.mask, sample_dict)
                numpy_list += current_np

            # update PN model
            if A.playby == 'policy' and A.eval != 1:
                update_PN_model(PN_model, log_prob_list, rewards, optimizer)
                print('updated PN model')
                current_length = len(log_prob_list)
                conversation_length_list.append(current_length)
            # end update

            if A.purpose != 'pretrain':
                with open(write_fp, 'a') as f:
                    f.write('Big features are: {}\n'.format(choose_pool))
                    if rewards is not None:
                        f.write('reward is: {}\n'.format(
                            rewards.data.numpy().tolist()))
                    f.write('WHOLE PROCESS TAKES: {} SECONDS\n'.format(
                        time.time() - start))

        # Write to pretrain numpy.
        if A.purpose == 'pretrain':
            if len(numpy_list) > 5000:
                with open(
                        '../../data/pretrain-numpy-data-{}/segment-{}-start-{}-end-{}.pk'
                        .format(A.mod, NUMPY_COUNT, A.startFrom,
                                A.endAt), 'wb') as f:
                    pickle.dump(numpy_list, f)
                    print('Have written 5000 numpy arrays!')
                NUMPY_COUNT += 1
                numpy_list = list()
        # numpy_list is a list of lists.
        # e.g. numpy_list[0][0]: an int indicating the action.
        # numpy_list[0][1]: a 1-D array of length 89 for EAR and 33 for CRM.
        # end write

        # Write sample dict:
        if A.purpose == 'fmdata' and A.playby != 'AOO_valid':
            if epi_count % 100 == 1:
                with open(
                        '../../data/sample-dict/start-{}-end-{}.json'.format(
                            A.startFrom, A.endAt), 'w') as f:
                    json.dump(sample_dict, f, indent=4)
        # end write
        if A.purpose == 'fmdata' and A.playby == 'AOO_valid':
            if epi_count % 100 == 1:
                with open(
                        '../../data/sample-dict/valid-start-{}-end-{}.json'.
                        format(A.startFrom, A.endAt), 'w') as f:
                    json.dump(sample_dict, f, indent=4)

        check_span = 500
        if epi_count % check_span == 0 and epi_count >= 3 * check_span and cfg.eval != 1 and A.purpose != 'pretrain':
            # We use AT (average turn of conversation) as our stopping criterion
            # in training mode, save RL model periodically
            # save model first
            # PATH = '../../data/PN-model-{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
            #     A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma, A.playby, A.strategy, A.TopKTaxo, A.trick,
            #     A.eval, A.initeval,
            #     A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask, epi_count)
            PATH = f'../../data/PN-model-{A.mod}/model-epoch{epoch}'
            torch.save(PN_model.state_dict(), PATH)
            print('Model saved at {}'.format(PATH))

            # a0 = conversation_length_list[epi_count - 4 * check_span: epi_count - 3 * check_span]
            a1 = conversation_length_list[epi_count -
                                          3 * check_span:epi_count -
                                          2 * check_span]
            a2 = conversation_length_list[epi_count -
                                          2 * check_span:epi_count -
                                          1 * check_span]
            a3 = conversation_length_list[epi_count - 1 * check_span:]
            a1 = np.mean(np.array(a1))
            a2 = np.mean(np.array(a2))
            a3 = np.mean(np.array(a3))

            with open(write_fp, 'a') as f:
                f.write('$$$current turn: {}, a3: {}, a2: {}, a1: {}\n'.format(
                    epi_count, a3, a2, a1))
            print('current turn: {}, a3: {}, a2: {}, a1: {}'.format(
                epi_count, a3, a2, a1))

            num_interval = int(epi_count / check_span)
            for i in range(num_interval):
                ave = np.mean(
                    np.array(conversation_length_list[i * check_span:(i + 1) *
                                                      check_span]))
                print('start: {}, end: {}, average: {}'.format(
                    i * check_span, (i + 1) * check_span, ave))
                # PATH = '../../data/PN-model-{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
                #     A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma, A.playby, A.strategy, A.TopKTaxo,
                #     A.trick,
                #     A.eval, A.initeval,
                #     A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask, (i + 1) * check_span)
                PATH = f'../../data/PN-model-{A.mod}/model-epoch{epoch}'
                print('Model saved at: {}'.format(PATH))

            if a3 > a1 and a3 > a2:
                print('Early stop of RL!')
                if cfg.eval == 1:
                    exit()
                else:
                    return
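
`update_PN_model` is called with the episode's log-probabilities, the rewards and the optimizer, but its body is not shown here. A minimal REINFORCE-style sketch under the assumption that `rewards` is a tensor of (possibly discounted) returns aligned with `log_prob_list`; the project's actual implementation may differ, for instance by normalizing rewards when the `trick` flag is set:

import torch


def update_PN_model(model, log_prob_list, rewards, optimizer):
    # Policy-gradient step: minimize -sum_t log pi(a_t|s_t) * R_t (assumed objective).
    # `model` is kept in the signature to match the call site; gradients flow through log_prob_list.
    if len(log_prob_list) == 0:
        return
    loss = -torch.sum(torch.stack(list(log_prob_list)) * rewards.detach())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()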
Example 5
class SAC_Net(nn.Module):
    def __init__(self, input_dim, output_dim, dim1, actor_lr, critic_lr,
                 discount_rate, actor_w_decay, critic_w_decay):
        '''
        params:
            input_dim: size of the state vector fed to the networks
            output_dim: number of actions
            dim1: hidden-layer size of each PolicyNetwork
            actor_lr / critic_lr: learning rates for the actor and critic optimizers
            discount_rate: discount factor used in the SAC losses
            actor_w_decay / critic_w_decay: weight decay for the actor and critic optimizers
        '''
        super(SAC_Net, self).__init__()
        self.actor_network = PolicyNetwork(input_dim, dim1, output_dim)
        # Create 2 critic networks to debias the value estimation
        self.critic_network = PolicyNetwork(input_dim, dim1, output_dim)
        self.critic2_network = PolicyNetwork(input_dim, dim1, output_dim)
        self.actor_optimizer = torch.optim.Adam(
            self.actor_network.parameters(),
            lr=actor_lr,
            weight_decay=actor_w_decay)
        self.critic_optimizer = torch.optim.Adam(
            self.critic_network.parameters(),
            lr=critic_lr,
            weight_decay=critic_w_decay)
        self.critic2_optimizer = torch.optim.Adam(
            self.critic2_network.parameters(),
            lr=critic_lr,
            weight_decay=critic_w_decay)
        # Create 2 target networks to stabilize training
        self.critic1_target = PolicyNetwork(input_dim, dim1, output_dim)
        self.critic2_target = PolicyNetwork(input_dim, dim1, output_dim)
        copy_model(self.critic_network, self.critic1_target)
        copy_model(self.critic2_network, self.critic2_target)
        # Define discount_rate
        self.discount_rate = discount_rate

    def produce_action_info(self, state):
        action_probs = F.softmax(self.actor_network(state), dim=-1)
        greedy_action = torch.argmax(action_probs, dim=-1)
        action_distribution = torch.distributions.Categorical(action_probs)
        action = action_distribution.sample().cpu()
        # Guard against log(0) when some action probabilities are exactly 0
        z = action_probs == 0.0
        z = z.float() * 1e-8
        log_action_probs = torch.log(action_probs + z)
        return action, action_probs, log_action_probs, greedy_action

    def calc_acttor_loss(self, batch_states, batch_action):
        criterion = nn.CrossEntropyLoss()
        _, action_probs, log_action_probs, _ = self.produce_action_info(
            batch_states)
        Vpi_1 = self.critic_network(batch_states)
        Vpi_2 = self.critic2_network(batch_states)
        # Target is set to the minimum of value functions to reduce bias
        min_V = torch.min(Vpi_1, Vpi_2)
        policy_loss = (action_probs *
                       (self.discount_rate * log_action_probs - min_V)).sum(
                           dim=1).mean()
        batch_action = cuda_(torch.from_numpy(batch_action).long())
        policy_loss += 0.5 * criterion(action_probs, batch_action)
        log_action_probs = torch.sum(log_action_probs * action_probs, dim=1)
        return policy_loss, log_action_probs

    def calc_critic_loss(self, batch_states, batch_next_states, batch_action,
                         batch_rewards):
        batch_action = cuda_(torch.LongTensor(batch_action)).reshape(-1, 1)
        with torch.no_grad():
            next_state_action, action_probs, log_action_probs, _ = self.produce_action_info(
                batch_next_states)
            target1 = self.critic1_target(batch_next_states)
            target2 = self.critic2_target(batch_next_states)
            min_next_target = action_probs * (torch.min(
                target1, target2) - self.discount_rate * log_action_probs)
            min_next_target = min_next_target.sum(dim=1).unsqueeze(-1)
            next_q = batch_rewards + self.discount_rate * min_next_target
        qf1 = self.critic_network(batch_states).gather(1, batch_action)
        qf2 = self.critic2_network(batch_states).gather(1, batch_action)
        next_q = next_q.max(1)[0].unsqueeze(-1)
        qf1_loss = F.mse_loss(qf1, next_q)
        qf2_loss = F.mse_loss(qf2, next_q)
        return qf1_loss, qf2_loss
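
SAC_Net hard-copies its two target critics at construction time via `copy_model`, but this snippet does not show how the targets are refreshed during training. A common choice is a Polyak (soft) update after each critic step; a sketch under that assumption (the project may instead re-run `copy_model` periodically):

def soft_update(source, target, tau=0.005):
    # Blend target parameters toward the online network:
    # theta_target <- tau * theta_online + (1 - tau) * theta_target
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

# e.g. after each critic update:
# soft_update(sac_net.critic_network, sac_net.critic1_target)
# soft_update(sac_net.critic2_network, sac_net.critic2_target)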
Example 6
def main():
    parser = argparse.ArgumentParser(
        description="Run conversational recommendation.")
    parser.add_argument('-mt',
                        type=int,
                        dest='mt',
                        help='MAX_TURN',
                        default=10)
    parser.add_argument('-playby',
                        type=str,
                        dest='playby',
                        help='playby',
                        default='policy')
    parser.add_argument('-optim',
                        type=str,
                        dest='optim',
                        help='optimizer',
                        default='SGD')
    parser.add_argument('-lr', type=float, dest='lr', help='lr', default=0.001)
    parser.add_argument('-decay',
                        type=float,
                        dest='decay',
                        help='decay',
                        default=0)
    parser.add_argument('-TopKTaxo',
                        type=int,
                        dest='TopKTaxo',
                        help='TopKTaxo',
                        default=3)
    parser.add_argument('-gamma',
                        type=float,
                        dest='gamma',
                        help='gamma',
                        default=0)
    parser.add_argument('-trick',
                        type=int,
                        dest='trick',
                        help='trick',
                        default=0)
    parser.add_argument('-startFrom',
                        type=int,
                        dest='startFrom',
                        help='startFrom',
                        default=0)
    parser.add_argument('-endAt',
                        type=int,
                        dest='endAt',
                        help='endAt',
                        default=171)  #test 171
    parser.add_argument('-strategy',
                        type=str,
                        dest='strategy',
                        help='strategy',
                        default='maxsim')
    parser.add_argument('-eval', type=int, dest='eval', help='eval', default=1)
    parser.add_argument('-mini', type=int, dest='mini', help='mini', default=1)
    parser.add_argument('-alwaysupdate',
                        type=int,
                        dest='alwaysupdate',
                        help='alwaysupdate',
                        default=1)
    parser.add_argument('-initeval',
                        type=int,
                        dest='initeval',
                        help='initeval',
                        default=0)
    parser.add_argument('-upoptim',
                        type=str,
                        dest='upoptim',
                        help='upoptim',
                        default='SGD')
    parser.add_argument('-upcount',
                        type=int,
                        dest='upcount',
                        help='upcount',
                        default=4)
    parser.add_argument('-upreg',
                        type=float,
                        dest='upreg',
                        help='upreg',
                        default=0.001)
    parser.add_argument('-code',
                        type=str,
                        dest='code',
                        help='code',
                        default='stable')
    parser.add_argument('-purpose',
                        type=str,
                        dest='purpose',
                        help='purpose',
                        default='train')
    parser.add_argument('-mod',
                        type=str,
                        dest='mod',
                        help='mod',
                        default='ours')
    parser.add_argument('-mask', type=int, dest='mask', help='mask', default=0)
    # used for the ablation study

    A = parser.parse_args()

    cfg.change_param(playby=A.playby,
                     eval=A.eval,
                     update_count=A.upcount,
                     update_reg=A.upreg,
                     purpose=A.purpose,
                     mod=A.mod,
                     mask=A.mask)
    device = torch.device('cuda')
    random.seed(1)
    #    random.shuffle(cfg.valid_list)
    #    random.shuffle(cfg.test_list)
    the_valid_list_item = copy.copy(cfg.valid_list_item)
    the_valid_list_features = copy.copy(cfg.valid_list_features)
    the_valid_list_location = copy.copy(cfg.valid_list_location)

    the_test_list_item = copy.copy(cfg.test_list_item)
    the_test_list_features = copy.copy(cfg.test_list_features)
    the_test_list_location = copy.copy(cfg.test_list_location)

    the_train_list_item = copy.copy(cfg.train_list_item)
    the_train_list_features = copy.copy(cfg.train_list_features)
    the_train_list_location = copy.copy(cfg.train_list_location)
    #    random.shuffle(the_valid_list)
    #    random.shuffle(the_test_list)

    gamma = A.gamma
    transE_model = cfg.transE_model
    if A.eval == 1:

        if A.mod == 'ours':
            fp = '../data/PN-model-ours/PN-model-ours.txt'
        if A.initeval == 1:

            if A.mod == 'ours':
                fp = '../data/PN-model-ours/pretrain-model.pt'
    else:
        # means training
        if A.mod == 'ours':
            fp = '../data/PN-model-ours/pretrain-model.pt'

    INPUT_DIM = 0

    if A.mod == 'ours':
        INPUT_DIM = 29  #11+10+8
    PN_model = PolicyNetwork(input_dim=INPUT_DIM, dim1=64, output_dim=12)
    start = time.time()

    try:
        PN_model.load_state_dict(torch.load(fp))
        print('Now Load PN pretrain from {}, takes {} seconds.'.format(
            fp,
            time.time() - start))
    except Exception:
        print('Cannot load the model!!!!!!!!!\n fp is: {}'.format(fp))
        if cfg.play_by == 'policy':
            sys.exit()

    if A.optim == 'Adam':
        optimizer = torch.optim.Adam(PN_model.parameters(),
                                     lr=A.lr,
                                     weight_decay=A.decay)
    if A.optim == 'SGD':
        optimizer = torch.optim.SGD(PN_model.parameters(),
                                    lr=A.lr,
                                    weight_decay=A.decay)
    if A.optim == 'RMS':
        optimizer = torch.optim.RMSprop(PN_model.parameters(),
                                        lr=A.lr,
                                        weight_decay=A.decay)

    numpy_list = list()
    NUMPY_COUNT = 0

    sample_dict = defaultdict(list)
    conversation_length_list = list()

    combined_num = 0
    total_turn = 0
    # start episode
    for epi_count in range(A.startFrom, A.endAt):
        if epi_count % 1 == 0:  # always true: log progress every episode
            print('-----\nIt has processed {} episodes'.format(epi_count))

        start = time.time()

        current_transE_model = copy.deepcopy(transE_model)
        current_transE_model.to(device)

        param1, param2 = list(), list()
        i = 0
        for name, param in current_transE_model.named_parameters():
            if i == 0 or i == 1:
                param1.append(param)
                # param1: head, tail
            else:
                param2.append(param)
                # param2: time, category, cluster, type
            i += 1
        '''change to transE embedding'''
        optimizer1_transE, optimizer2_transE = None, None
        if A.purpose != 'fmdata':
            optimizer1_transE = torch.optim.Adagrad(param1,
                                                    lr=0.01,
                                                    weight_decay=A.decay)
            if A.upoptim == 'Ada':
                optimizer2_transE = torch.optim.Adagrad(param2,
                                                        lr=0.01,
                                                        weight_decay=A.decay)
            if A.upoptim == 'SGD':
                optimizer2_transE = torch.optim.SGD(param2,
                                                    lr=0.001,
                                                    weight_decay=A.decay)

        if A.purpose != 'pretrain':
            items = the_valid_list_item[epi_count]  #0 18 10 3
            features = the_valid_list_features[
                epi_count]  #3,21,2,1    21,12,2,1   22,7,2,1
            location = the_valid_list_location[epi_count]
            item_list = items.strip().split(' ')
            u = item_list[0]
            item = item_list[-1]
            if A.eval == 1:
                #                u, item, l = the_test_list_item[epi_count]
                items = the_test_list_item[epi_count]  #0 18 10 3
                features = the_test_list_features[
                    epi_count]  #3,21,2,1    21,12,2,1   22,7,2,1
                location = the_test_list_location[epi_count]
                item_list = items.strip().split(' ')
                u = item_list[0]
                item = item_list[-1]

            user_id = int(u)
            item_id = int(item)
            location_id = int(location)
        else:
            user_id = 0
            item_id = epi_count

        if A.purpose == 'pretrain':
            items = the_train_list_item[epi_count]  #0 18 10 3
            features = the_train_list_features[
                epi_count]  #3,21,2,1    21,12,2,1   22,7,2,1
            location = the_train_list_location[epi_count]
            item_list = items.strip().split(' ')
            u = item_list[0]
            item = item_list[-1]
            user_id = int(u)
            item_id = int(item)
            location_id = int(location)
        print("----target item: ", item_id)
        big_feature_list = list()
        '''update L2.json'''
        for k, v in cfg.taxo_dict.items():
            #            print (k,v)
            if len(
                    set(v).intersection(
                        set(cfg.item_dict[str(item_id)]
                            ['L2_Category_name']))) > 0:
                #                print(user_id, item_id) #433,122
                #                print (k)
                big_feature_list.append(k)

        write_fp = '../data/interaction-log/{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
            A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma,
            A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval, A.initeval,
            A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask)
        '''care the sequence of facet pool items'''
        if cfg.item_dict[str(item_id)]['POI_Type'] is not None:
            choose_pool = ['clusters', 'POI_Type'] + big_feature_list

        choose_pool_original = choose_pool

        if A.purpose not in ['pretrain', 'fmdata']:
            choose_pool = [random.choice(choose_pool)]

        # run the episode
        for c in choose_pool:
            start_facet = c
            with open(write_fp, 'a') as f:
                f.write(
                    'Starting new\nuser ID: {}, item ID: {} episode count: {}\n'
                    .format(user_id, item_id, epi_count))
            if A.purpose != 'pretrain':
                log_prob_list, rewards, success, turn_count, known_feature_category = run_one_episode(
                    current_transE_model, user_id, item_id, A.mt, False,
                    write_fp, A.strategy, A.TopKTaxo, PN_model, gamma, A.trick,
                    A.mini, optimizer1_transE, optimizer2_transE,
                    A.alwaysupdate, start_facet, A.mask, sample_dict,
                    choose_pool_original, features, items)
            else:
                current_np = run_one_episode(
                    current_transE_model, user_id, item_id, A.mt, False,
                    write_fp, A.strategy, A.TopKTaxo, PN_model, gamma, A.trick,
                    A.mini, optimizer1_transE, optimizer2_transE,
                    A.alwaysupdate, start_facet, A.mask, sample_dict,
                    choose_pool_original, features, items)
                numpy_list += current_np
        # end run

#         check POI type, recommend location id by star, check the success.
        if A.purpose != 'pretrain':
            total_turn += turn_count

        if A.purpose != 'pretrain':
            if success:
                if cfg.poi_dict[str(item_id)]['POI'] == 'Combined':
                    combined_num += 1
                    L2_Category_name_list = cfg.poi_dict[str(
                        item_id)]['L2_Category_name']
                    category_list = known_feature_category
                    #                    print (category_list)
                    possible_L2_list = []
                    for i in range(len(L2_Category_name_list)):
                        for j in range(len(category_list)):
                            if L2_Category_name_list[i] == category_list[j]:
                                temp_location = cfg.poi_dict[str(
                                    item_id)]['Location_id'][i]
                                if temp_location not in possible_L2_list:
                                    possible_L2_list.append(temp_location)
                    location_list = []
                    location_list = possible_L2_list
                    random_location_list = location_list.copy()
                    star_location_list = location_list.copy()
                    if location_id not in location_list:
                        #                        random_location_list += random.sample(cfg.poi_dict[str(item_id)]['Location_id'], int(len(cfg.poi_dict[str(item_id)]['Location_id'])/2))
                        #Can test random selection.
                        random_location_list = np.random.choice(
                            category_list,
                            int(len(category_list) / 2),
                            replace=False)
                        star_list = cfg.poi_dict[str(item_id)]['stars']
                        location_list = cfg.poi_dict[str(
                            item_id)]['Location_id']
                        prob_star_list = [
                            float(i) / sum(star_list) for i in star_list
                        ]
                        if len(location_list) > 5:
                            random_location_list = np.random.choice(
                                location_list,
                                len(location_list) // 2,  # size must be an int for np.random.choice
                                p=prob_star_list,
                                replace=False)
                        else:
                            random_location_list = location_list
                        star_list = cfg.poi_dict[str(item_id)]['stars']

                        star_list = [
                            b[0] for b in sorted(enumerate(star_list),
                                                 key=lambda i: i[1])
                        ]
                        print(star_list)
                        location_stars = star_list[:3]
                        for index in location_stars:
                            l = cfg.poi_dict[str(
                                item_id)]['Location_id'][index]
                            star_location_list.append(l)
                    print('star_location_list: ', star_location_list)
                    print('random_location_list', random_location_list)
                    if location_id in random_location_list:
                        print('Random Combined Rec Success! in episode: {}.'.
                              format(epi_count))
                        success_at_turn_list_location_random[turn_count] += 1
                    else:
                        print('Random Combined Rec failed! in episode: {}.'.
                              format(epi_count))

                    if location_id in star_location_list:
                        print('Star Combined Rec Success! in episode: {}.'.
                              format(epi_count))
                        success_at_turn_list_location_rate[turn_count] += 1
                    else:
                        print('Star Combined Rec failed! in episode: {}.'.
                              format(epi_count))

                else:
                    print('Independent Rec Success! in episode: {}.'.format(
                        epi_count))
                    success_at_turn_list_location_random[turn_count] += 1
                    success_at_turn_list_location_rate[turn_count] += 1

        if A.purpose != 'pretrain':
            if success:
                print('Rec Success! in episode: {}.'.format(epi_count))
                success_at_turn_list_item[turn_count] += 1


        print('total combined: ', combined_num)
        # update PN model
        if A.playby == 'policy' and A.eval != 1 and A.purpose != 'pretrain':
            update_PN_model(PN_model, log_prob_list, rewards, optimizer)
            print('updated PN model')
            current_length = len(log_prob_list)
            conversation_length_list.append(current_length)
        # end update

        check_span = 10
        if epi_count % check_span == 0 and epi_count >= 3 * check_span and cfg.eval != 1 and A.purpose != 'pretrain':
            PATH = '../data/PN-model-{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}-epi-{}.txt'.format(
                A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma,
                A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval, A.initeval,
                A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask, epi_count)
            torch.save(PN_model.state_dict(), PATH)
            print('Model saved at {}'.format(PATH))

            a1 = conversation_length_list[epi_count -
                                          3 * check_span:epi_count -
                                          2 * check_span]
            a2 = conversation_length_list[epi_count -
                                          2 * check_span:epi_count -
                                          1 * check_span]
            a3 = conversation_length_list[epi_count - 1 * check_span:]
            a1 = np.mean(np.array(a1))
            a2 = np.mean(np.array(a2))
            a3 = np.mean(np.array(a3))

            with open(write_fp, 'a') as f:
                f.write('$$$current turn: {}, a3: {}, a2: {}, a1: {}\n'.format(
                    epi_count, a3, a2, a1))
            print('current turn: {}, a3: {}, a2: {}, a1: {}'.format(
                epi_count, a3, a2, a1))

            num_interval = int(epi_count / check_span)
            for i in range(num_interval):
                ave = np.mean(
                    np.array(conversation_length_list[i * check_span:(i + 1) *
                                                      check_span]))
                print('start: {}, end: {}, average: {}'.format(
                    i * check_span, (i + 1) * check_span, ave))
                PATH = '../data/PN-model-{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}-epi-{}.txt'.format(
                    A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma,
                    A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval,
                    A.initeval, A.mini, A.alwaysupdate, A.upcount, A.upreg,
                    A.mask, (i + 1) * check_span)
                print('Model saved at: {}'.format(PATH))

            if a3 > a1 and a3 > a2:
                print('Early stop of RL!')
                exit()

        # write control information
        if A.purpose != 'pretrain':
            with open(write_fp, 'a') as f:
                f.write('Big features are: {}\n'.format(choose_pool))
                if rewards is not None:
                    f.write('reward is: {}\n'.format(
                        rewards.data.numpy().tolist()))
                f.write(
                    'WHOLE PROCESS TAKES: {} SECONDS\n'.format(time.time() -
                                                               start))
        # end write

        # Write to pretrain numpy which is the pretrain data.
        if A.purpose == 'pretrain':
            if len(numpy_list) > 5000:
                with open(
                        '../data/pretrain-numpy-data-{}/segment-{}-start-{}-end-{}.pk'
                        .format(A.mod, NUMPY_COUNT, A.startFrom,
                                A.endAt), 'wb') as f:
                    pickle.dump(numpy_list, f)
                    print('Have written 5000 numpy arrays!')
                NUMPY_COUNT += 1
                numpy_list = list()
        # end write

    print('\nLocation result:')
    sum_sr = 0.0
    for i in range(len(success_at_turn_list_location_rate)):
        success_rate = success_at_turn_list_location_rate[i] / A.endAt
        sum_sr += success_rate
        print('success rate is {} at turn {}, accumulated sum is {}'.format(
            success_rate, i + 1, sum_sr))
    print('Average turn is: ', total_turn / A.endAt + 1)
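
The episode loop above increments `success_at_turn_list_item`, `success_at_turn_list_location_random` and `success_at_turn_list_location_rate`, which are never defined in this snippet (they are presumably module-level counters). A minimal sketch of how they might be initialized, assuming one success bucket per allowed conversation turn:

MAX_TURN = 10  # assumed to match the -mt default above

# one counter per turn index: how many episodes succeeded at exactly that turn
success_at_turn_list_item = [0] * MAX_TURN
success_at_turn_list_location_random = [0] * MAX_TURN
success_at_turn_list_location_rate = [0] * MAX_TURN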
Example 7
def main(epoch):
    parser = argparse.ArgumentParser(description="Run conversational recommendation.")
    parser.add_argument('-mt', type=int, dest='mt', help='MAX_TURN', default=5)
    parser.add_argument('-playby', type=str, dest='playby', help='playby', default='policy')
    # options include:
    # AO: (Ask Only and recommend by probability)
    # RO: (Recommend Only)
    # policy: (action decided by our policy network)
    parser.add_argument('-fmCommand', type=str, dest='fmCommand', help='fmCommand', default='8')  # default as a string to match type=str
    # the command used for FM, check out /EAR/lastfm/FM/
    parser.add_argument('-optim', type=str, dest='optim', help='optimizer', default='SGD')
    # the optimizer for policy network
    parser.add_argument('-lr', type=float, dest='lr', help='lr', default=0.001)
    # learning rate of policy network
    parser.add_argument('-decay', type=float, dest='decay', help='decay', default=0)
    # weight decay
    parser.add_argument('-TopKTaxo', type=int, dest='TopKTaxo', help='TopKTaxo', default=3)
    # how many second-layer features represent a big feature. Only the Yelp dataset uses this param; it has no effect on lastFM.
    parser.add_argument('-gamma', type=float, dest='gamma', help='gamma', default=0.7)
    # gamma of training policy network
    parser.add_argument('-trick', type=int, dest='trick', help='trick', default=0)
    # whether to use normalization when training the policy network
    parser.add_argument('-startFrom', type=int, dest='startFrom', help='startFrom', default=0)  # 85817
    # startFrom which user-item interaction pair
    parser.add_argument('-endAt', type=int, dest='endAt', help='endAt', default=20000)
    # endAt which user-item interaction pair
    parser.add_argument('-strategy', type=str, dest='strategy', help='strategy', default='maxent')
    # strategy used to choose which question to ask (only takes effect in certain settings)
    parser.add_argument('-eval', type=int, dest='eval', help='eval', default=1)
    # whether current run is for evaluation
    parser.add_argument('-mini', type=int, dest='mini', help='mini', default=0)
    # whether to `mini`-batch update the FM
    parser.add_argument('-alwaysupdate', type=int, dest='alwaysupdate', help='alwaysupdate', default=0)
    # whether to always mini-batch update the FM; the alternative is to update only once per session.
    # we leave this exploration to followers of our work.
    parser.add_argument('-initeval', type=int, dest='initeval', help='initeval', default=0)
    # whether to evaluate the `init`ial version of the policy network (directly after pre-training)
    parser.add_argument('-upoptim', type=str, dest='upoptim', help='upoptim', default='SGD')
    # optimizer for the reflection stage
    parser.add_argument('-upcount', type=int, dest='upcount', help='upcount', default=0)
    # how many times to do reflection
    parser.add_argument('-upreg', type=float, dest='upreg', help='upreg', default=0.001)
    # regularization weight used in the reflection (update) stage
    parser.add_argument('-code', type=str, dest='code', help='code', default='stable')
    # We use it to give each run a unique identifier.
    parser.add_argument('-purpose', type=str, dest='purpose', help='purpose', default='train')
    # options: pretrain, others
    parser.add_argument('-mod', type=str, dest='mod', help='mod', default='ear')
    # options: CRM, EAR
    parser.add_argument('-mask', type=int, dest='mask', help='mask', default=0)
    # used for the ablation study; 1, 2, 3, 4 represent our four segments {ent, sim, his, len}

    A = parser.parse_args()

    cfg.change_param(playby=A.playby, eval=A.eval, update_count=A.upcount, update_reg=A.upreg, purpose=A.purpose,
                     mod=A.mod, mask=A.mask)

    random.seed(1)

    # the commented-out lines below would randomly shuffle and split the combined valid and test sets, for Action-Stage training and evaluation respectively, to avoid bias in the dataset.
    # all_list = cfg.valid_list + cfg.test_list
    # print('The length of all list is: {}'.format(len(all_list)))
    # random.shuffle(all_list)
    # the_valid_list = all_list[: int(len(all_list) / 2.0)]
    # the_test_list = all_list[int(len(all_list) / 2.0):]
    the_valid_list = cfg.valid_list
    the_test_list = cfg.test_list

    gamma = A.gamma
    FM_model = cfg.FM_model

    if A.mod == 'ear':
        # fp = '../../data/PN-model-crm/PN-model-crm.txt'
        if epoch == 0 and cfg.eval == 0:
            fp = '../../data/PN-model-ear/pretrain-model.pt'
        else:
            fp = '../../data/PN-model-ear/model-epoch0'
    INPUT_DIM = 0
    if A.mod == 'ear':
        INPUT_DIM = 4667
    if A.mod == 'crm':
        INPUT_DIM = 4382
    PN_model = PolicyNetwork(input_dim=INPUT_DIM, dim1=5000, output_dim=2323)
    start = time.time()

    try:
        print('fp is: {}'.format(fp))
        PN_model.load_state_dict(torch.load(fp))
        print('load PN model success.')
    except Exception:
        print('Cannot load the model!!!!!!!!!\n fp is: {}'.format(fp))
        # if A.playby == 'policy':
        #     sys.exit()

    if A.optim == 'Adam':
        optimizer = torch.optim.Adam(PN_model.parameters(), lr=A.lr, weight_decay=A.decay)
    if A.optim == 'SGD':
        optimizer = torch.optim.SGD(PN_model.parameters(), lr=A.lr, weight_decay=A.decay)
    if A.optim == 'RMS':
        optimizer = torch.optim.RMSprop(PN_model.parameters(), lr=A.lr, weight_decay=A.decay)

    numpy_list = list()
    NUMPY_COUNT = 0

    sample_dict = defaultdict(list)
    conversation_length_list = list()
    # endAt = len(the_valid_list) if cfg.eval == 0 else len(the_test_list)
    # endAt = 500
    # print(f'endAt: {endAt}')
    print('-'*10)
    print('Train mode' if cfg.eval == 0 else 'Test mode')
    print('-' * 10)

    # cfg.actionProb = epi_count/endAt
    # if A.test == 1 or A.eval == 1:
    u, item = the_test_list[0]
    current_FM_model = copy.deepcopy(FM_model)
    param1, param2 = list(), list()
    param3 = list()
    param4 = list()
    i = 0
    for name, param in current_FM_model.named_parameters():
        param4.append(param)
        # print(name, param)
        if i == 0:
            param1.append(param)
        else:
            param2.append(param)
        if i == 2:
            param3.append(param)
        i += 1
    optimizer1_fm = torch.optim.Adagrad(param1, lr=0.01, weight_decay=A.decay)
    optimizer2_fm = torch.optim.SGD(param4, lr=0.001, weight_decay=A.decay)

    user_id = int(u)
    item_id = int(item)

    write_fp = '../../data/interaction-log/{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
        A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma, A.playby, A.strategy, A.TopKTaxo, A.trick,
        A.eval, A.initeval,
        A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask)

    choose_pool = cfg.item_dict[str(item_id)]['categories']
    print(choose_pool)
    if A.purpose not in ['pretrain', 'fmdata']:
        # when not collecting data for pretraining or FM training,
        # we only randomly choose one start attribute to ask about
        choose_pool = [random.choice(choose_pool)]
    # for item_id in cfg.item_dict:
    #     choose_pool = [k for k in cfg.item_dict_rel[str(item_id)] if len(cfg.item_dict_rel[str(item_id)][k]) != 0]
    #     if choose_pool == None:
    #         print(item_id)
    print(f'user id: {user_id}\titem id: {item_id}')
    # choose_pool = [k for k in cfg.item_dict_rel[str(item_id)] if len(cfg.item_dict_rel[str(item_id)][k]) != 0]
    # choose_pool = random.choice(choose_pool)

    for c in choose_pool:
        # with open(write_fp, 'a+') as f:
        #     f.write(
        #         'Starting new\nuser ID: {}, item ID: {} episode count: {}, feature: {}\n'.format(user_id, item_id,
        #                                                                                          epi_count,
        #                                                                                          cfg.item_dict[
        #                                                                                              str(item_id)][
        #                                                                                              'categories']))
        start_facet = c
        log_prob_list, rewards, hl = run_one_episode(current_FM_model, user_id, item_id, A.mt, False, write_fp,
                                                     A.strategy, A.TopKTaxo,
                                                     PN_model, gamma, A.trick, A.mini,
                                                     optimizer1_fm, optimizer2_fm, A.alwaysupdate, start_facet,
                                                     A.mask, sample_dict)
        # if cfg.eval == 0:
        #     with open(f'../../data/train_rec{epoch}.txt', 'a+') as f:
        #         # f.writelines(str(rewards.tolist()))
        #         f.writelines(str(hl))
        #         f.writelines('\n')
        # else:
        #     with open('../../data/test_rec.txt', 'a+') as f:
        #         # f.writelines(str(rewards.tolist()))
        #         f.writelines(str(hl))
        #         f.writelines('\n')

        # update PN model
        if A.playby == 'policy' and A.eval != 1:
            update_PN_model(PN_model, log_prob_list, rewards, optimizer)
            print('updated PN model')
            current_length = len(log_prob_list)
            conversation_length_list.append(current_length)
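
This variant of `main` takes an `epoch` argument that selects which checkpoint to load, but the snippet does not show how it is driven. A minimal driver sketch under the assumption that a caller simply iterates over epochs (the epoch count of 10 is illustrative, mirroring the pretraining loop in Example 1):

if __name__ == '__main__':
    for epoch in range(10):  # illustrative epoch count
        main(epoch)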
Example 8
def main():
    parser = argparse.ArgumentParser(
        description="Run conversational recommendation.")
    parser.add_argument('-mt', type=int, dest='mt', help='MAX_TURN')
    parser.add_argument('-playby', type=str, dest='playby', help='playby')
    # options include:
    # AO: (Ask Only and recommend by probability)
    # RO: (Recommend Only)
    # policy: (action decided by our policy network)
    # sac: (action decided by our SAC network)
    parser.add_argument('-fmCommand',
                        type=str,
                        dest='fmCommand',
                        help='fmCommand')
    # the command used for FM, check out /EAR/lastfm/FM/
    parser.add_argument('-optim', type=str, dest='optim', help='optimizer')
    # the optimizer for policy network
    parser.add_argument('-lr', type=float, dest='lr', help='learning rate')
    # learning rate of the policy network (referenced below as A.lr)
    parser.add_argument('-actor_lr',
                        type=float,
                        dest='actor_lr',
                        help='actor learning rate')
    # learning rate of Actor network
    parser.add_argument('-critic_lr',
                        type=float,
                        dest='critic_lr',
                        help='critic learning rate')
    # learning rate of the Critic network
    parser.add_argument('-actor_decay',
                        type=float,
                        dest='actor_decay',
                        help='actor weight decay')
    parser.add_argument('-decay',
                        type=float,
                        dest='decay',
                        help="weight decay for FM model")
    # weight decay
    parser.add_argument('-critic_decay',
                        type=float,
                        dest='critic_decay',
                        help='critic weight decay')
    parser.add_argument('-TopKTaxo',
                        type=int,
                        dest='TopKTaxo',
                        help='TopKTaxo')
    # how many second-layer features represent a big feature; only the Yelp dataset uses this parameter, it has no effect on lastFM.
    parser.add_argument('-gamma', type=float, dest='gamma', help='gamma')
    # discount factor (gamma) for training the policy network
    parser.add_argument('-trick', type=int, dest='trick', help='trick')
    # whether to use normalization when training the policy network
    parser.add_argument('-startFrom',
                        type=int,
                        dest='startFrom',
                        help='startFrom')
    # index of the user-item interaction pair to start from
    parser.add_argument('-endAt', type=int, dest='endAt', help='endAt')
    # index of the user-item interaction pair to end at
    parser.add_argument('-strategy',
                        type=str,
                        dest='strategy',
                        help='strategy')
    # strategy for choosing which question to ask (only takes effect in certain modes)
    parser.add_argument('-eval', type=int, dest='eval', help='eval')
    # whether current run is for evaluation
    parser.add_argument('-mini', type=int, dest='mini', help='mini')
    # whether to `mini`-batch update the FM
    parser.add_argument('-alwaysupdate',
                        type=int,
                        dest='alwaysupdate',
                        help='alwaysupdate')
    # whether to always mini-batch update the FM; the alternative is to update only once per session.
    # We leave this exploration to followers of our work.
    parser.add_argument('-initeval',
                        type=int,
                        dest='initeval',
                        help='initeval')
    # whether to evaluate the `init`ial version of the policy network (directly after pre-training)
    parser.add_argument('-upoptim', type=str, dest='upoptim', help='upoptim')
    # optimizer for the reflection stage
    parser.add_argument('-upcount', type=int, dest='upcount', help='upcount')
    # how many times to do reflection
    parser.add_argument('-upreg', type=float, dest='upreg', help='upreg')
    # regularization term used in the reflection update
    parser.add_argument('-code', type=str, dest='code', help='code')
    # We use it to give each run a unique identifier.
    parser.add_argument('-purpose', type=str, dest='purpose', help='purpose')
    # options: pretrain, others
    parser.add_argument('-mod', type=str, dest='mod', help='mod')
    # options: CRM, EAR
    parser.add_argument('-mask', type=int, dest='mask', help='mask')
    # used for the ablation study; 1, 2, 3, 4 mask our four state segments {ent, sim, his, len}
    parser.add_argument('-use_sac', type=bool, dest='use_sac', help='use_sac')
    # true if the RL module uses SAC; note that `type=bool` treats any non-empty
    # string as True, so omit this flag entirely to use the plain policy network
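    # A hypothetical invocation combining these flags (the script name and all
    # values are illustrative only, not taken from the original run scripts);
    # omit -use_sac to keep the plain policy network:
    #   python run.py -mt 15 -playby policy -fmCommand 8 -optim SGD -lr 0.001 \
    #       -actor_lr 0.001 -critic_lr 0.001 -actor_decay 0 -critic_decay 0 \
    #       -decay 0 -TopKTaxo 3 -gamma 0.7 -trick 0 -startFrom 0 -endAt 5000 \
    #       -strategy maxent -eval 0 -mini 0 -alwaysupdate 0 -initeval 0 \
    #       -upoptim SGD -upcount 1 -upreg 0.001 -code stable -purpose train \
    #       -mod ear -mask 0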

    A = parser.parse_args()

    cfg.change_param(playby=A.playby,
                     eval=A.eval,
                     update_count=A.upcount,
                     update_reg=A.upreg,
                     purpose=A.purpose,
                     mod=A.mod,
                     mask=A.mask)

    random.seed(1)

    # We randomly shuffle and re-split the combined valid and test sets, used for Action Stage training and evaluation respectively, to avoid bias in the dataset.
    all_list = cfg.valid_list + cfg.test_list
    print('The length of all list is: {}'.format(len(all_list)))
    random.shuffle(all_list)
    the_valid_list = all_list[:int(len(all_list) / 2.0)]
    the_test_list = all_list[int(len(all_list) / 2.0):]

    gamma = A.gamma
    FM_model = cfg.FM_model

    if A.eval == 1:
        if A.mod == 'ear':
            fp = '../../data/PN-model-ear/PN-model-ear.txt'
        if A.mod == 'crm':
            fp = '../../data/PN-model-crm/PN-model-crm.txt'
        if A.initeval == 1:
            if A.mod == 'ear':
                fp = '../../data/PN-model-ear/pretrain-model.pt'
            if A.mod == 'crm':
                fp = '../../data/PN-model-crm/pretrain-model.pt'
    else:
        # means training
        if A.mod == 'ear':
            fp = '../../data/PN-model-ear/pretrain-model.pt'
        if A.mod == 'crm':
            fp = '../../data/PN-model-crm/pretrain-model.pt'
    INPUT_DIM = 0
    if A.mod == 'ear':
        INPUT_DIM = 81
    if A.mod == 'crm':
        INPUT_DIM = 590
    print('fp is: {}'.format(fp))

    # Initialize the policy network as either PolicyNetwork or SAC_Net
    if not A.use_sac:
        PN_model = PolicyNetwork(input_dim=INPUT_DIM, dim1=64, output_dim=34)
        start = time.time()

        try:
            PN_model.load_state_dict(torch.load(fp))
            print('Now Load PN pretrain from {}, takes {} seconds.'.format(
                fp,
                time.time() - start))
        except:
            print('Cannot load the model!!!!!!!!!\n fp is: {}'.format(fp))
            if A.playby == 'policy':
                sys.exit()

        if A.optim == 'Adam':
            optimizer = torch.optim.Adam(PN_model.parameters(),
                                         lr=A.lr,
                                         weight_decay=A.decay)
        if A.optim == 'SGD':
            optimizer = torch.optim.SGD(PN_model.parameters(),
                                        lr=A.lr,
                                        weight_decay=A.decay)
        if A.optim == 'RMS':
            optimizer = torch.optim.RMSprop(PN_model.parameters(),
                                            lr=A.lr,
                                            weight_decay=A.decay)

    else:
        PN_model = SAC_Net(input_dim=INPUT_DIM,
                           dim1=64,
                           output_dim=34,
                           actor_lr=A.actor_lr,
                           critic_lr=A.critic_lr,
                           discount_rate=gamma,
                           actor_w_decay=A.actor_decay,
                           critic_w_decay=A.critic_decay)

    numpy_list = list()
    rewards_list = list()
    NUMPY_COUNT = 0

    sample_dict = defaultdict(list)
    conversation_length_list = list()
    for epi_count in range(A.startFrom, A.endAt):
        if epi_count % 1 == 0:
            print('-----\nIt has processed {} episodes'.format(epi_count))
        start = time.time()

        u, item, r = the_valid_list[epi_count]

        # if A.test == 1 or A.eval == 1:
        if A.eval == 1:
            u, item, r = the_test_list[epi_count]

        if A.purpose == 'fmdata':
            u, item = 0, epi_count

        if A.purpose == 'pretrain':
            u, item, r = cfg.train_list[epi_count]

        current_FM_model = copy.deepcopy(FM_model)
        param1, param2 = list(), list()
        i = 0
        for name, param in current_FM_model.named_parameters():
            # print(name, param)
            if i == 0:
                param1.append(param)
            else:
                param2.append(param)
            i += 1

        optimizer1_fm, optimizer2_fm = None, None
        if A.purpose != 'fmdata':
            optimizer1_fm = torch.optim.Adagrad(param1,
                                                lr=0.01,
                                                weight_decay=A.decay)
            if A.upoptim == 'Ada':
                optimizer2_fm = torch.optim.Adagrad(param2,
                                                    lr=0.01,
                                                    weight_decay=A.decay)
            if A.upoptim == 'SGD':
                optimizer2_fm = torch.optim.SGD(param2,
                                                lr=0.001,
                                                weight_decay=A.decay)

        user_id = int(u)
        item_id = int(item)

        big_feature_list = list()

        for k, v in cfg.taxo_dict.items():
            if len(
                    set(v).intersection(
                        set(cfg.item_dict[str(item_id)]['categories']))) > 0:
                big_feature_list.append(k)

        write_fp = '../../data/interaction-log/{}/v5-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
            A.mod.lower(), A.code, A.startFrom, A.endAt, A.actor_lr, A.gamma,
            A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval, A.initeval,
            A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask)

        if cfg.item_dict[str(item_id)]['RestaurantsPriceRange2'] is not None:
            choose_pool = ['stars', 'city', 'RestaurantsPriceRange2'
                           ] + big_feature_list
        else:
            choose_pool = ['stars', 'city'] + big_feature_list

        choose_pool_original = choose_pool

        if A.purpose not in ['pretrain', 'fmdata']:
            # This means we are not collecting data for pretraining or for the FM,
            # so we randomly choose only one start attribute to ask about.
            choose_pool = [random.choice(choose_pool)]

        for c in choose_pool:
            with open(write_fp, 'a') as f:
                f.write(
                    'Starting new\nuser ID: {}, item ID: {} episode count: {}, feature: {}\n'
                    .format(user_id, item_id, epi_count,
                            cfg.item_dict[str(item_id)]['categories']))
            start_facet = c
            if A.purpose != 'pretrain' and A.playby != 'sac':
                log_prob_list, rewards = run_one_episode(
                    current_FM_model, user_id, item_id, A.mt, False, write_fp,
                    A.strategy, A.TopKTaxo, PN_model, gamma, A.trick, A.mini,
                    optimizer1_fm, optimizer2_fm, A.alwaysupdate, start_facet,
                    A.mask, sample_dict, choose_pool_original)
            else:
                if A.playby != 'sac':
                    current_np = run_one_episode(
                        current_FM_model, user_id, item_id, A.mt, False,
                        write_fp, A.strategy, A.TopKTaxo, PN_model, gamma,
                        A.trick, A.mini, optimizer1_fm, optimizer2_fm,
                        A.alwaysupdate, start_facet, A.mask, sample_dict,
                        choose_pool_original)
                    numpy_list += current_np

                else:
                    current_np, current_reward = run_one_episode(
                        current_FM_model, user_id, item_id, A.mt, False,
                        write_fp, A.strategy, A.TopKTaxo, PN_model, gamma,
                        A.trick, A.mini, optimizer1_fm, optimizer2_fm,
                        A.alwaysupdate, start_facet, A.mask, sample_dict,
                        choose_pool_original)
                    rewards_list += current_reward
                    numpy_list += current_np

            # update PN model
            if A.playby == 'policy' and A.eval != 1:
                update_PN_model(PN_model, log_prob_list, rewards, optimizer)
                print('updated PN model')
                current_length = len(log_prob_list)
                conversation_length_list.append(current_length)
            # end update
            # For SAC runs, the episode returns its reward list separately;
            # guard the assignment so non-SAC runs keep the rewards returned by run_one_episode above.
            if A.playby == 'sac':
                rewards = current_reward
            if A.purpose != 'pretrain':
                with open(write_fp, 'a') as f:
                    f.write('Big features are: {}\n'.format(choose_pool))
                    if rewards is not None:
                        f.write('reward is: {}\n'.format(
                            rewards.data.numpy().tolist()))
                    f.write('WHOLE PROCESS TAKES: {} SECONDS\n'.format(
                        time.time() - start))

        # Write to pretrain numpy.
        if A.purpose == 'pretrain':
            if A.playby != 'sac':
                if len(numpy_list) > 5000:
                    with open(
                            '../../data/pretrain-numpy-data-{}/segment-{}-start-{}-end-{}.pk'
                            .format(A.mod, NUMPY_COUNT, A.startFrom,
                                    A.endAt), 'wb') as f:
                        pickle.dump(numpy_list, f)
                        print('Have written 5000 numpy arrays!')
                    NUMPY_COUNT += 1
                    numpy_list = list()
            else:
                # In SAC mode, collect both numpy_list and rewards_list as training data
                if len(numpy_list) > 5000 or len(rewards_list) > 5000:
                    assert len(rewards_list) == len(
                        numpy_list
                    ), "rewards and state-action pairs have different size!"
                    directory = '../../data/pretrain-sac-numpy-data-{}/segment-{}-start-{}-end-{}.pk'.format(
                        A.mod, NUMPY_COUNT, A.startFrom, A.endAt)
                    rewards_directory = '../../data/pretrain-sac-reward-data-{}/segment-{}-start-{}-end-{}.pk'.format(
                        A.mod, NUMPY_COUNT, A.startFrom, A.endAt)
                    with open(directory, 'wb') as f:
                        pickle.dump(numpy_list, f)
                        print('Have written 5000 numpy arrays for SAC!')
                    with open(rewards_directory, 'wb') as f:
                        pickle.dump(rewards_list, f)
                        print('Have written 5000 rewards for SAC!')
                    NUMPY_COUNT += 1
                    numpy_list = list()
                    rewards_list = list()

        # numpy_list is a list of lists.
        # e.g. numpy_list[0][0]: int, indicating the action.
        # numpy_list[0][1]: a 1-D array of length 89 for EAR, and 33 for CRM.
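        # For illustration, one such record might look like this (values made up):
        #   record = [
        #       3,                          # chosen action index
        #       np.zeros(89, dtype=float),  # state vector (89-d for EAR, 33-d for CRM)
        #   ]
        #   numpy_list.append(record)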
        # end write

        # Write sample dict:
        if A.purpose == 'fmdata' and A.playby != 'AOO_valid':
            if epi_count % 100 == 1:
                with open(
                        '../../data/sample-dict/start-{}-end-{}.json'.format(
                            A.startFrom, A.endAt), 'w') as f:
                    json.dump(sample_dict, f, indent=4)
        # end write
        if A.purpose == 'fmdata' and A.playby == 'AOO_valid':
            if epi_count % 100 == 1:
                with open(
                        '../../data/sample-dict/valid-start-{}-end-{}.json'.
                        format(A.startFrom, A.endAt), 'w') as f:
                    json.dump(sample_dict, f, indent=4)

        check_span = 500
        if epi_count % check_span == 0 and epi_count >= 3 * check_span and cfg.eval != 1 and A.purpose != 'pretrain':
            # We use AT (average turn of conversation) as our stopping criterion
            # in training mode, save RL model periodically
            # save model first
            PATH = '../../data/PN-model-{}/v5-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}-epi-{}.txt'.format(
                A.mod.lower(), A.code, A.startFrom, A.endAt, A.actor_lr,
                A.gamma, A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval,
                A.initeval, A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask,
                epi_count)
            torch.save(PN_model.state_dict(), PATH)
            print('Model saved at {}'.format(PATH))

            # a0 = conversation_length_list[epi_count - 4 * check_span: epi_count - 3 * check_span]
            a1 = conversation_length_list[epi_count -
                                          3 * check_span:epi_count -
                                          2 * check_span]
            a2 = conversation_length_list[epi_count -
                                          2 * check_span:epi_count -
                                          1 * check_span]
            a3 = conversation_length_list[epi_count - 1 * check_span:]
            a1 = np.mean(np.array(a1))
            a2 = np.mean(np.array(a2))
            a3 = np.mean(np.array(a3))

            with open(write_fp, 'a') as f:
                f.write('$$$current turn: {}, a3: {}, a2: {}, a1: {}\n'.format(
                    epi_count, a3, a2, a1))
            print('current turn: {}, a3: {}, a2: {}, a1: {}'.format(
                epi_count, a3, a2, a1))

            num_interval = int(epi_count / check_span)
            for i in range(num_interval):
                ave = np.mean(
                    np.array(conversation_length_list[i * check_span:(i + 1) *
                                                      check_span]))
                print('start: {}, end: {}, average: {}'.format(
                    i * check_span, (i + 1) * check_span, ave))
                PATH = '../../data/PN-model-{}/v5-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}-epi-{}.txt'.format(
                    A.mod.lower(), A.code, A.startFrom, A.endAt, A.actor_lr,
                    A.gamma, A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval,
                    A.initeval, A.mini, A.alwaysupdate, A.upcount, A.upreg,
                    A.mask, (i + 1) * check_span)
                print('Model saved at: {}'.format(PATH))

            if a3 > a1 and a3 > a2:
                print('Early stop of RL!')
                exit()
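The early-stop rule above compares the mean conversation length over the three most recent check_span windows. A simplified standalone sketch of that criterion is given below (it slices the last three windows of the length list rather than indexing by epi_count); the helper name is illustrative.

import numpy as np

def should_early_stop(lengths, check_span=500):
    # Not enough history yet to compare three full windows.
    if len(lengths) < 3 * check_span:
        return False
    a1 = np.mean(lengths[-3 * check_span:-2 * check_span])
    a2 = np.mean(lengths[-2 * check_span:-1 * check_span])
    a3 = np.mean(lengths[-1 * check_span:])
    # Stop once the latest window averages more turns than both earlier windows,
    # i.e. the average conversation length has stopped improving.
    return a3 > a1 and a3 > a2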
Ejemplo n.º 9
0
def main():
    parser = argparse.ArgumentParser(description="Pretrain Policy Network")
    parser.add_argument('-inputdim',
                        type=int,
                        dest='inputdim',
                        help='input dimension')
    parser.add_argument('-hiddendim',
                        type=int,
                        dest='hiddendim',
                        help='hidden dimension',
                        default=1500)
    parser.add_argument('-outputdim',
                        type=int,
                        dest='outputdim',
                        help='output dimension')
    parser.add_argument('-bs',
                        type=int,
                        dest='bs',
                        help='batch size',
                        default=512)
    parser.add_argument('-optim',
                        type=str,
                        dest='optim',
                        help='optimizer choice',
                        default='Adam')
    parser.add_argument('-lr',
                        type=float,
                        dest='lr',
                        help='learning rate',
                        default=0.001)
    parser.add_argument('-decay',
                        type=float,
                        dest='decay',
                        help='weight decay',
                        default=0)
    parser.add_argument('-mod',
                        type=str,
                        dest='mod',
                        help='mod',
                        default='ear')  # ear crm

    A = parser.parse_args()
    print('Arguments loaded!')
    if A.mod == 'ear':
        #inputdim = 89
        #inputdim = 8787
        inputdim = len(cfg.tag_map) * 2 + cfg.MAX_TURN + 8
        with open('../../data/FM-train-data/tag_question_map.json', 'r') as f:
            outputdim = len(json.load(f)) + 1
    else:
        inputdim = 33
        # Note: this branch does not set `outputdim`; it must be defined for the
        # CRM mode before the PolicyNetwork below is constructed.
    print("input dim: ", inputdim)
    PN = PolicyNetwork(input_dim=inputdim,
                       dim1=A.hiddendim,
                       output_dim=outputdim)
    print(
        f"input_dim: {inputdim}\thidden_dim: {A.hiddendim}\toutput_dim: {outputdim}"
    )
    cuda_(PN)
    print('Model on GPU')
    data_list = list()

    #dir = '../../data/pretrain-numpy-data-{}'.format(A.mod)
    dir = '../../data/RL-pretrain-data-{}'.format(A.mod)

    files = os.listdir(dir)
    file_paths = [dir + '/' + f for f in files]

    i = 0
    for fp in file_paths:
        with open(fp, 'rb') as f:
            try:
                data_list += pickle.load(f)
                i += 1
            except:
                pass
    print('total files: {}'.format(i))
    data_list = data_list[:int(len(data_list) / 1.5)]
    print('length of data list is: {}'.format(len(data_list)))

    random.shuffle(data_list)

    train_list = data_list
    valid_list = []
    test_list = []
    # train_list = data_list[: int(len(data_list) * 0.7)]
    # valid_list = data_list[int(len(data_list) * 0.7): int(len(data_list) * 0.9)]
    #
    # test_list = data_list[int(len(data_list) * 0.9):]
    # print('train length: {}, valid length: {}, test length: {}'.format(len(train_list), len(valid_list), len(test_list)))
    # sleep(1)  # let you see this

    if A.optim == 'Ada':
        optimizer = torch.optim.Adagrad(PN.parameters(),
                                        lr=A.lr,
                                        weight_decay=A.decay)
    if A.optim == 'Adam':
        optimizer = torch.optim.Adam(PN.parameters(),
                                     lr=A.lr,
                                     weight_decay=A.decay)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(8):
        print(f"epoch:{epoch}")
        random.shuffle(train_list)
        model_name = '../../data/PN-model-{}/pretrain-model.pt'.format(A.mod)
        train(A.bs, train_list, valid_list, test_list, optimizer, PN,
              criterion, epoch, model_name)
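The train() call above is defined elsewhere in the repository. As a rough sketch, one supervised pretraining step over a batch of (action, state) records, using the CrossEntropyLoss configured above, could look like the following; the record layout follows the numpy_list comments earlier in this document, and the helper name and tensor handling are assumptions.

import numpy as np
import torch

def pretrain_step_sketch(PN, batch, optimizer, criterion):
    # Each record is assumed to be [action_index, state_vector].
    actions = torch.LongTensor([rec[0] for rec in batch])
    states = torch.FloatTensor(np.array([rec[1] for rec in batch]))
    logits = PN(states)                 # policy-network scores over actions
    loss = criterion(logits, actions)   # cross-entropy against the logged actions
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()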
Ejemplo n.º 10
0
    def __init__(self, user_name, item_name):
        self.mt = 5
        self.playby = "policy"
        # options include:
        # AO: (Ask Only and recommend by probability)
        # RO: (Recommend Only)
        # policy: (action decided by our policy network)
        self.fmCommand = 8
        # the command used for FM, check out /EAR/lastfm/FM/
        self.optim = "SGD"
        # the optimizer for policy network
        self.lr = 0.001
        # learning rate of policy network
        self.decay = 0
        # weight decay
        self.TopKTaxo = 3
        # how many second-layer features represent a big feature; only the Yelp dataset uses this parameter, it has no effect on lastFM.
        self.gamma = 0.7
        # discount factor (gamma) for training the policy network
        self.trick = 0
        # whether to use normalization when training the policy network
        self.strategy = "maxent"
        # strategy for choosing which question to ask (only takes effect in certain modes)
        self.eval = 1
        # whether current run is for evaluation
        self.mini = 0
        # whether to `mini`-batch update the FM
        self.alwaysupdate = 0
        # whether to always mini-batch update the FM; the alternative is to update only once per session.
        # We leave this exploration to followers of our work.
        self.initeval = 0
        # whether to evaluate the `init`ial version of the policy network (directly after pre-training)
        self.upoptim = "SGD"
        # optimizer for the reflection stage
        self.upcount = 0
        # how many times to do reflection
        self.upreg = 0.001
        # regularization term used in the reflection update
        self.code = "stable"
        # We use it to give each run a unique identifier.
        self.purpose = "train"
        # options: pretrain, others
        self.mod = "ear"
        # options: CRM, EAR
        self.mask = 0
        # used for the ablation study; 1, 2, 3, 4 mask our four state segments {ent, sim, his, len}

        cfg.change_param(playby=self.playby,
                         eval=self.eval,
                         update_count=self.upcount,
                         update_reg=self.upreg,
                         purpose=self.purpose,
                         mod=self.mod,
                         mask=self.mask)

        gamma = self.gamma
        FM_model = cfg.FM_model

        if self.mod == 'ear':
            fp = '../../data/PN-model-ear/model-epoch0'
            INPUT_DIM = len(cfg.tag_map) * 2 + self.mt + 8
        self.PN_model = PolicyNetwork(input_dim=INPUT_DIM,
                                      dim1=1500,
                                      output_dim=len(cfg.tag_map) + 1)

        try:
            print('fp is: {}'.format(fp))
            self.PN_model.load_state_dict(torch.load(fp))
            print('load PN model success.')
        except:
            print('Cannot load the model!!!!!!!!!\n fp is: {}'.format(fp))
            sys.exit()

        if self.optim == 'Adam':
            optimizer = torch.optim.Adam(self.PN_model.parameters(),
                                         lr=self.lr,
                                         weight_decay=self.decay)
        if self.optim == 'SGD':
            optimizer = torch.optim.SGD(self.PN_model.parameters(),
                                        lr=self.lr,
                                        weight_decay=self.decay)
        if self.optim == 'RMS':
            optimizer = torch.optim.RMSprop(self.PN_model.parameters(),
                                            lr=self.lr,
                                            weight_decay=self.decay)

        self.sample_dict = defaultdict(list)
        self.conversation_length_list = list()
        # print('-'*10)
        # print('Train mode' if cfg.eval == 0 else 'Test mode')
        # print('-' * 10)

        # cfg.actionProb = epi_count/endAt
        # if A.test == 1 or A.eval == 1:

        # input
        self.u = user_name
        self.item = item_name

        self.current_FM_model = copy.deepcopy(FM_model)
        param1, param2 = list(), list()
        param3 = list()
        param4 = list()
        i = 0
        for name, param in self.current_FM_model.named_parameters():
            param4.append(param)
            # print(name, param)
            if i == 0:
                param1.append(param)
            else:
                param2.append(param)
            if i == 2:
                param3.append(param)
            i += 1
        self.optimizer1_fm = torch.optim.Adagrad(param1,
                                                 lr=0.01,
                                                 weight_decay=self.decay)
        self.optimizer2_fm = torch.optim.SGD(param4,
                                             lr=0.001,
                                             weight_decay=self.decay)

        self.user_id = int(self.u)
        self.item_id = int(self.item)

        # Initialize the user
        self.the_user = user(self.user_id, self.item_id)

        self.numpy_list = list()
        self.log_prob_list, self.reward_list = Variable(torch.Tensor()), list()
        self.action_tracker, self.candidate_length_tracker = list(), list()

        self.the_agent = agent.agent(
            self.current_FM_model, self.user_id, self.item_id, False, "",
            self.strategy, self.TopKTaxo, self.numpy_list, self.PN_model,
            self.log_prob_list, self.action_tracker,
            self.candidate_length_tracker, self.mini, self.optimizer1_fm,
            self.optimizer2_fm, self.alwaysupdate, self.sample_dict)

        self.agent_utterance = None

        choose_pool = cfg.item_dict[str(self.item_id)]['categories']
        c = random.choice(choose_pool)
        # print(f'user id: {user_id}\titem id: {item_id}')

        self.start_facet = c
Ejemplo n.º 11
0
class chat_model:
    def __init__(self, user_name, item_name):
        self.mt = 5
        self.playby = "policy"
        # options include:
        # AO: (Ask Only and recommend by probability)
        # RO: (Recommend Only)
        # policy: (action decided by our policy network)
        self.fmCommand = 8
        # the command used for FM, check out /EAR/lastfm/FM/
        self.optim = "SGD"
        # the optimizer for policy network
        self.lr = 0.001
        # learning rate of policy network
        self.decay = 0
        # weight decay
        self.TopKTaxo = 3
        # how many second-layer features represent a big feature; only the Yelp dataset uses this parameter, it has no effect on lastFM.
        self.gamma = 0.7
        # discount factor (gamma) for training the policy network
        self.trick = 0
        # whether to use normalization when training the policy network
        self.strategy = "maxent"
        # strategy for choosing which question to ask (only takes effect in certain modes)
        self.eval = 1
        # whether current run is for evaluation
        self.mini = 0
        # whether to `mini`-batch update the FM
        self.alwaysupdate = 0
        # whether to always mini-batch update the FM; the alternative is to update only once per session.
        # We leave this exploration to followers of our work.
        self.initeval = 0
        # whether to evaluate the `init`ial version of the policy network (directly after pre-training)
        self.upoptim = "SGD"
        # optimizer for the reflection stage
        self.upcount = 0
        # how many times to do reflection
        self.upreg = 0.001
        # regularization term used in the reflection update
        self.code = "stable"
        # We use it to give each run a unique identifier.
        self.purpose = "train"
        # options: pretrain, others
        self.mod = "ear"
        # options: CRM, EAR
        self.mask = 0
        # used for the ablation study; 1, 2, 3, 4 mask our four state segments {ent, sim, his, len}

        cfg.change_param(playby=self.playby,
                         eval=self.eval,
                         update_count=self.upcount,
                         update_reg=self.upreg,
                         purpose=self.purpose,
                         mod=self.mod,
                         mask=self.mask)

        gamma = self.gamma
        FM_model = cfg.FM_model

        if self.mod == 'ear':
            fp = '../../data/PN-model-ear/model-epoch0'
            INPUT_DIM = len(cfg.tag_map) * 2 + self.mt + 8
        self.PN_model = PolicyNetwork(input_dim=INPUT_DIM,
                                      dim1=1500,
                                      output_dim=len(cfg.tag_map) + 1)

        try:
            print('fp is: {}'.format(fp))
            self.PN_model.load_state_dict(torch.load(fp))
            print('load PN model success.')
        except:
            print('Cannot load the model!!!!!!!!!\n fp is: {}'.format(fp))
            sys.exit()

        if self.optim == 'Adam':
            optimizer = torch.optim.Adam(self.PN_model.parameters(),
                                         lr=self.lr,
                                         weight_decay=self.decay)
        if self.optim == 'SGD':
            optimizer = torch.optim.SGD(self.PN_model.parameters(),
                                        lr=self.lr,
                                        weight_decay=self.decay)
        if self.optim == 'RMS':
            optimizer = torch.optim.RMSprop(self.PN_model.parameters(),
                                            lr=self.lr,
                                            weight_decay=self.decay)

        self.sample_dict = defaultdict(list)
        self.conversation_length_list = list()
        # print('-'*10)
        # print('Train mode' if cfg.eval == 0 else 'Test mode')
        # print('-' * 10)

        # cfg.actionProb = epi_count/endAt
        # if A.test == 1 or A.eval == 1:

        # input
        self.u = user_name
        self.item = item_name

        self.current_FM_model = copy.deepcopy(FM_model)
        param1, param2 = list(), list()
        param3 = list()
        param4 = list()
        i = 0
        for name, param in self.current_FM_model.named_parameters():
            param4.append(param)
            # print(name, param)
            if i == 0:
                param1.append(param)
            else:
                param2.append(param)
            if i == 2:
                param3.append(param)
            i += 1
        self.optimizer1_fm = torch.optim.Adagrad(param1,
                                                 lr=0.01,
                                                 weight_decay=self.decay)
        self.optimizer2_fm = torch.optim.SGD(param4,
                                             lr=0.001,
                                             weight_decay=self.decay)

        self.user_id = int(self.u)
        self.item_id = int(self.item)

        # Initialize the user
        self.the_user = user(self.user_id, self.item_id)

        self.numpy_list = list()
        self.log_prob_list, self.reward_list = Variable(torch.Tensor()), list()
        self.action_tracker, self.candidate_length_tracker = list(), list()

        self.the_agent = agent.agent(
            self.current_FM_model, self.user_id, self.item_id, False, "",
            self.strategy, self.TopKTaxo, self.numpy_list, self.PN_model,
            self.log_prob_list, self.action_tracker,
            self.candidate_length_tracker, self.mini, self.optimizer1_fm,
            self.optimizer2_fm, self.alwaysupdate, self.sample_dict)

        self.agent_utterance = None

        choose_pool = cfg.item_dict[str(self.item_id)]['categories']
        c = random.choice(choose_pool)
        # print(f'user id: {user_id}\titem id: {item_id}')

        self.start_facet = c

    def first_conversation(self, user_response):  # Yes/No/Hit/Reject
        data = dict()
        data['facet'] = self.start_facet
        start_signal = message(cfg.AGENT, cfg.USER, cfg.EPISODE_START, data)
        user_utterance = self.the_user.response(start_signal, user_response)

        # user_utterance.message_type: cfg.REJECT_REC, cfg.ACCEPT_REC, cfg.INFORM_FACET

        # if user_utterance.message_type == cfg.ACCEPT_REC:
        #     self.the_agent.history_list.append(2)
        #     s = 'Rec Success! in Turn:' + str(self.the_agent.turn_count) + '.'
        print(user_utterance.message_type)
        self.agent_utterance = self.the_agent.response(user_utterance)

        print(self.agent_utterance.message_type)

        # agent_utterance.message_type: cfg.ASK_FACET, cfg.MAKE_REC

        if self.agent_utterance.message_type == cfg.ASK_FACET:
            s = "D" + str(self.agent_utterance.data['facet'])

        elif self.agent_utterance.message_type == cfg.MAKE_REC:
            s = []
            s.append("r")
            k = []
            for i, j in enumerate(self.agent_utterance.data["rec_list"][:5]):
                # j = meta_data[r_item_map[int(j)]]["title"]
                j = str(j)
                #j = j.split(" ")
                # if type(j) == list:
                #     j = j[-3]+" "+j[-2]+" "+j[-1]
                # # if i != len(self.agent_utterance.data["rec_list"])-1:
                # if i != 4:
                #     s += str(i+1)+". "+str(j) + "\n"
                # else:
                #     s += str(i+1)+". "+str(j)
                k.append(j)
            s.append(k)

        self.the_agent.turn_count += 1

        return s

    def conversation(self, user_response):  # Yes/No/Hit/Reject
        if self.the_agent.turn_count < self.mt:
            user_utterance = self.the_user.response(self.agent_utterance,
                                                    user_response)

            if user_utterance.message_type == cfg.ACCEPT_REC:
                self.the_agent.history_list.append(2)
                s = 'Rec Success! in Turn:' + str(
                    self.the_agent.turn_count) + '.'
                return s

            self.agent_utterance = self.the_agent.response(user_utterance)

            if self.agent_utterance.message_type == cfg.ASK_FACET:
                s = "D" + str(self.agent_utterance.data['facet'])

            elif self.agent_utterance.message_type == cfg.MAKE_REC:
                s = []
                s.append("r")
                k = []
                for i, j in enumerate(
                        self.agent_utterance.data["rec_list"][:5]):
                    # j = meta_data[r_item_map[int(j)]]["title"]
                    j = str(j)
                    # j = j.split(" ")
                    # if type(j) == list:
                    #     j = j[-3]+" "+j[-2]+" "+j[-1]
                    # # if i != len(self.agent_utterance.data["rec_list"])-1:
                    # if i != 4:
                    #     s += str(i+1)+". "+str(j) + "\n"
                    # else:
                    #     s += str(i+1)+". "+str(j)
                    k.append(j)
                s.append(k)

            self.the_agent.turn_count += 1

        if self.the_agent.turn_count >= self.mt:  # If this were changed to >, mt would effectively be 5, but the model would need retraining and the agent-side max_turn < 5-2 check would also need to change.
            self.the_agent.history_list.append(-2)
            s = "Already reached the max turn of conversation: " + str(self.mt)

        return s
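For context, the chat_model class above could be driven from a simple console loop along the lines below; the user/item identifiers and the response strings expected by the user simulator ('Yes', 'No', and so on) are assumptions for illustration.

# Minimal driver sketch for chat_model; identifiers and response strings are
# illustrative assumptions, not values from the original project.
model = chat_model(user_name='0', item_name='123')

reply = model.first_conversation(user_response='Yes')
print(reply)

for _ in range(model.mt):
    answer = input('Your response: ')
    reply = model.conversation(user_response=answer)
    print(reply)
    if isinstance(reply, str) and ('Rec Success' in reply or 'max turn' in reply):
        break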