Example #1
# -*- coding: utf-8 -*-
import pickle
import torch
import argparse

import time
import numpy as np
import json

from pn import PolicyNetwork
import copy


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        y = m.in_features
        m.weight.data.normal_(0.0, 1 / np.sqrt(y))
        m.bias.data.fill_(0)


PN_model = PolicyNetwork(input_dim=29, dim1=64, output_dim=12)
PN_model.apply(weights_init_normal)
torch.save(PN_model.state_dict(), '../PN-model-ours/PN-model-ours.txt')
torch.save(PN_model.state_dict(), '../PN-model-ours/pretrain-model.pt')
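The PolicyNetwork class above is imported from pn but not shown in this listing. A minimal sketch of a two-layer MLP matching the constructor calls used here (input_dim, dim1, output_dim) could look like the following; the hidden activation and the log-softmax output are assumptions, and the real pn.PolicyNetwork may differ.

import torch.nn as nn
import torch.nn.functional as F


class PolicyNetwork(nn.Module):
    """Hypothetical two-layer policy MLP; the actual pn.PolicyNetwork may differ."""

    def __init__(self, input_dim, dim1, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, dim1)   # state features -> hidden layer
        self.fc2 = nn.Linear(dim1, output_dim)  # hidden layer -> action logits

    def forward(self, x):
        h = F.relu(self.fc1(x))
        # log-probabilities over the action space, one entry per ask/recommend action
        return F.log_softmax(self.fc2(h), dim=-1)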
Example #2
def main():
    parser = argparse.ArgumentParser(
        description="Run conversational recommendation.")
    parser.add_argument('-mt',
                        type=int,
                        dest='mt',
                        help='MAX_TURN',
                        default=10)
    parser.add_argument('-playby',
                        type=str,
                        dest='playby',
                        help='playby',
                        default='policy')
    parser.add_argument('-optim',
                        type=str,
                        dest='optim',
                        help='optimizer',
                        default='SGD')
    parser.add_argument('-lr', type=float, dest='lr', help='lr', default=0.001)
    parser.add_argument('-decay',
                        type=float,
                        dest='decay',
                        help='decay',
                        default=0)
    parser.add_argument('-TopKTaxo',
                        type=int,
                        dest='TopKTaxo',
                        help='TopKTaxo',
                        default=3)
    parser.add_argument('-gamma',
                        type=float,
                        dest='gamma',
                        help='gamma',
                        default=0)
    parser.add_argument('-trick',
                        type=int,
                        dest='trick',
                        help='trick',
                        default=0)
    parser.add_argument('-startFrom',
                        type=int,
                        dest='startFrom',
                        help='startFrom',
                        default=0)
    parser.add_argument('-endAt',
                        type=int,
                        dest='endAt',
                        help='endAt',
                        default=171)  #test 171
    parser.add_argument('-strategy',
                        type=str,
                        dest='strategy',
                        help='strategy',
                        default='maxsim')
    parser.add_argument('-eval', type=int, dest='eval', help='eval', default=1)
    parser.add_argument('-mini', type=int, dest='mini', help='mini', default=1)
    parser.add_argument('-alwaysupdate',
                        type=int,
                        dest='alwaysupdate',
                        help='alwaysupdate',
                        default=1)
    parser.add_argument('-initeval',
                        type=int,
                        dest='initeval',
                        help='initeval',
                        default=0)
    parser.add_argument('-upoptim',
                        type=str,
                        dest='upoptim',
                        help='upoptim',
                        default='SGD')
    parser.add_argument('-upcount',
                        type=int,
                        dest='upcount',
                        help='upcount',
                        default=4)
    parser.add_argument('-upreg',
                        type=float,
                        dest='upreg',
                        help='upreg',
                        default=0.001)
    parser.add_argument('-code',
                        type=str,
                        dest='code',
                        help='code',
                        default='stable')
    parser.add_argument('-purpose',
                        type=str,
                        dest='purpose',
                        help='purpose',
                        default='train')
    parser.add_argument('-mod',
                        type=str,
                        dest='mod',
                        help='mod',
                        default='ours')
    parser.add_argument('-mask', type=int, dest='mask', help='mask', default=0)
    # use for ablation study

    A = parser.parse_args()

    cfg.change_param(playby=A.playby,
                     eval=A.eval,
                     update_count=A.upcount,
                     update_reg=A.upreg,
                     purpose=A.purpose,
                     mod=A.mod,
                     mask=A.mask)
    device = torch.device('cuda')
    random.seed(1)
    #    random.shuffle(cfg.valid_list)
    #    random.shuffle(cfg.test_list)
    the_valid_list_item = copy.copy(cfg.valid_list_item)
    the_valid_list_features = copy.copy(cfg.valid_list_features)
    the_valid_list_location = copy.copy(cfg.valid_list_location)

    the_test_list_item = copy.copy(cfg.test_list_item)
    the_test_list_features = copy.copy(cfg.test_list_features)
    the_test_list_location = copy.copy(cfg.test_list_location)

    the_train_list_item = copy.copy(cfg.train_list_item)
    the_train_list_features = copy.copy(cfg.train_list_features)
    the_train_list_location = copy.copy(cfg.train_list_location)
    #    random.shuffle(the_valid_list)
    #    random.shuffle(the_test_list)

    gamma = A.gamma
    transE_model = cfg.transE_model
    if A.eval == 1:

        if A.mod == 'ours':
            fp = '../data/PN-model-ours/PN-model-ours.txt'
        if A.initeval == 1:

            if A.mod == 'ours':
                fp = '../data/PN-model-ours/pretrain-model.pt'
    else:
        # means training
        if A.mod == 'ours':
            fp = '../data/PN-model-ours/pretrain-model.pt'

    INPUT_DIM = 0

    if A.mod == 'ours':
        INPUT_DIM = 29  #11+10+8
    PN_model = PolicyNetwork(input_dim=INPUT_DIM, dim1=64, output_dim=12)
    start = time.time()

    try:
        PN_model.load_state_dict(torch.load(fp))
        print('Now Load PN pretrain from {}, takes {} seconds.'.format(
            fp,
            time.time() - start))
    except:
        print('Cannot load the model!!!!!!!!!\n fp is: {}'.format(fp))
        if cfg.play_by == 'policy':
            sys.exit()

    if A.optim == 'Adam':
        optimizer = torch.optim.Adam(PN_model.parameters(),
                                     lr=A.lr,
                                     weight_decay=A.decay)
    if A.optim == 'SGD':
        optimizer = torch.optim.SGD(PN_model.parameters(),
                                    lr=A.lr,
                                    weight_decay=A.decay)
    if A.optim == 'RMS':
        optimizer = torch.optim.RMSprop(PN_model.parameters(),
                                        lr=A.lr,
                                        weight_decay=A.decay)

    numpy_list = list()
    NUMPY_COUNT = 0

    sample_dict = defaultdict(list)
    conversation_length_list = list()

    combined_num = 0
    total_turn = 0
    # start episode
    for epi_count in range(A.startFrom, A.endAt):
        if epi_count % 1 == 0:
            print('-----\nIt has processed {} episodes'.format(epi_count))

        start = time.time()

        current_transE_model = copy.deepcopy(transE_model)
        current_transE_model.to(device)

        param1, param2 = list(), list()
        i = 0
        for name, param in current_transE_model.named_parameters():
            if i == 0 or i == 1:
                param1.append(param)
                # param1: head, tail
            else:
                param2.append(param)
                # param2: time, category, cluster, type
            i += 1
        '''change to transE embedding'''
        optimizer1_transE, optimizer2_transE = None, None
        if A.purpose != 'fmdata':
            optimizer1_transE = torch.optim.Adagrad(param1,
                                                    lr=0.01,
                                                    weight_decay=A.decay)
            if A.upoptim == 'Ada':
                optimizer2_transE = torch.optim.Adagrad(param2,
                                                        lr=0.01,
                                                        weight_decay=A.decay)
            if A.upoptim == 'SGD':
                optimizer2_transE = torch.optim.SGD(param2,
                                                    lr=0.001,
                                                    weight_decay=A.decay)

        if A.purpose != 'pretrain':
            items = the_valid_list_item[epi_count]  #0 18 10 3
            features = the_valid_list_features[
                epi_count]  #3,21,2,1    21,12,2,1   22,7,2,1
            location = the_valid_list_location[epi_count]
            item_list = items.strip().split(' ')
            u = item_list[0]
            item = item_list[-1]
            if A.eval == 1:
                #                u, item, l = the_test_list_item[epi_count]
                items = the_test_list_item[epi_count]  #0 18 10 3
                features = the_test_list_features[
                    epi_count]  #3,21,2,1    21,12,2,1   22,7,2,1
                location = the_test_list_location[epi_count]
                item_list = items.strip().split(' ')
                u = item_list[0]
                item = item_list[-1]

            user_id = int(u)
            item_id = int(item)
            location_id = int(location)
        else:
            user_id = 0
            item_id = epi_count

        if A.purpose == 'pretrain':
            items = the_train_list_item[epi_count]  #0 18 10 3
            features = the_train_list_features[
                epi_count]  #3,21,2,1    21,12,2,1   22,7,2,1
            location = the_train_list_location[epi_count]
            item_list = items.strip().split(' ')
            u = item_list[0]
            item = item_list[-1]
            user_id = int(u)
            item_id = int(item)
            location_id = int(location)
        print("----target item: ", item_id)
        big_feature_list = list()
        '''update L2.json'''
        for k, v in cfg.taxo_dict.items():
            #            print (k,v)
            if len(
                    set(v).intersection(
                        set(cfg.item_dict[str(item_id)]
                            ['L2_Category_name']))) > 0:
                #                print(user_id, item_id) #433,122
                #                print (k)
                big_feature_list.append(k)

        write_fp = '../data/interaction-log/{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
            A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma,
            A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval, A.initeval,
            A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask)
        '''mind the order of items in the facet pool'''
        if cfg.item_dict[str(item_id)]['POI_Type'] is not None:
            choose_pool = ['clusters', 'POI_Type'] + big_feature_list

        choose_pool_original = choose_pool

        if A.purpose not in ['pretrain', 'fmdata']:
            choose_pool = [random.choice(choose_pool)]

        # run the episode
        for c in choose_pool:
            start_facet = c
            with open(write_fp, 'a') as f:
                f.write(
                    'Starting new\nuser ID: {}, item ID: {} episode count: {}\n'
                    .format(user_id, item_id, epi_count))
            if A.purpose != 'pretrain':
                log_prob_list, rewards, success, turn_count, known_feature_category = run_one_episode(
                    current_transE_model, user_id, item_id, A.mt, False,
                    write_fp, A.strategy, A.TopKTaxo, PN_model, gamma, A.trick,
                    A.mini, optimizer1_transE, optimizer2_transE,
                    A.alwaysupdate, start_facet, A.mask, sample_dict,
                    choose_pool_original, features, items)
            else:
                current_np = run_one_episode(
                    current_transE_model, user_id, item_id, A.mt, False,
                    write_fp, A.strategy, A.TopKTaxo, PN_model, gamma, A.trick,
                    A.mini, optimizer1_transE, optimizer2_transE,
                    A.alwaysupdate, start_facet, A.mask, sample_dict,
                    choose_pool_original, features, items)
                numpy_list += current_np
        # end run

        # Check the POI type, recommend a location id by stars, and check the success.
        if A.purpose != 'pretrain':
            total_turn += turn_count

        if A.purpose != 'pretrain':
            if success == True:
                if cfg.poi_dict[str(item_id)]['POI'] == 'Combined':
                    combined_num += 1
                    L2_Category_name_list = cfg.poi_dict[str(
                        item_id)]['L2_Category_name']
                    category_list = known_feature_category
                    #                    print (category_list)
                    possible_L2_list = []
                    for i in range(len(L2_Category_name_list)):
                        for j in range(len(category_list)):
                            if L2_Category_name_list[i] == category_list[j]:
                                temp_location = cfg.poi_dict[str(
                                    item_id)]['Location_id'][i]
                                if temp_location not in possible_L2_list:
                                    possible_L2_list.append(temp_location)
                    location_list = []
                    location_list = possible_L2_list
                    random_location_list = location_list.copy()
                    star_location_list = location_list.copy()
                    if location_id not in location_list:
                        #                        random_location_list += random.sample(cfg.poi_dict[str(item_id)]['Location_id'], int(len(cfg.poi_dict[str(item_id)]['Location_id'])/2))
                        #Can test random selection.
                        random_location_list = np.random.choice(
                            category_list,
                            int(len(category_list) / 2),
                            replace=False)
                        star_list = cfg.poi_dict[str(item_id)]['stars']
                        location_list = cfg.poi_dict[str(
                            item_id)]['Location_id']
                        prob_star_list = [
                            float(i) / sum(star_list) for i in star_list
                        ]
                        if len(location_list) > 5:
                            random_location_list = np.random.choice(
                                location_list,
                                len(location_list) // 2,
                                p=prob_star_list,
                                replace=False)
                        else:
                            random_location_list = location_list
                        star_list = cfg.poi_dict[str(item_id)]['stars']

                        star_list = [
                            b[0] for b in sorted(enumerate(star_list),
                                                 key=lambda i: i[1])
                        ]
                        print(star_list)
                        location_stars = star_list[:3]
                        for index in location_stars:
                            l = cfg.poi_dict[str(
                                item_id)]['Location_id'][index]
                            star_location_list.append(l)
                    print('star_location_list: ', star_location_list)
                    print('random_location_list', random_location_list)
                    if location_id in random_location_list:
                        print('Random Combined Rec Success! in episode: {}.'.
                              format(epi_count))
                        success_at_turn_list_location_random[turn_count] += 1
                    else:
                        print('Random Combined Rec failed! in episode: {}.'.
                              format(epi_count))

                    if location_id in star_location_list:
                        print('Star Combined Rec Success! in episode: {}.'.
                              format(epi_count))
                        success_at_turn_list_location_rate[turn_count] += 1
                    else:
                        print('Star Combined Rec failed! in episode: {}.'.
                              format(epi_count))

                else:
                    print('Independent Rec Success! in episode: {}.'.format(
                        epi_count))
                    success_at_turn_list_location_random[turn_count] += 1
                    success_at_turn_list_location_rate[turn_count] += 1

        if A.purpose != 'pretrain':
            if success == True:
                print('Rec Success! in episode: {}.'.format(epi_count))
                success_at_turn_list_item[turn_count] += 1


        print('total combined: ', combined_num)
        # update PN model
        if A.playby == 'policy' and A.eval != 1 and A.purpose != 'pretrain':
            update_PN_model(PN_model, log_prob_list, rewards, optimizer)
            print('updated PN model')
            current_length = len(log_prob_list)
            conversation_length_list.append(current_length)
        # end update

        check_span = 10
        if epi_count % check_span == 0 and epi_count >= 3 * check_span and cfg.eval != 1 and A.purpose != 'pretrain':
            PATH = '../data/PN-model-{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}-epi-{}.txt'.format(
                A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma,
                A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval, A.initeval,
                A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask, epi_count)
            torch.save(PN_model.state_dict(), PATH)
            print('Model saved at {}'.format(PATH))

            a1 = conversation_length_list[epi_count -
                                          3 * check_span:epi_count -
                                          2 * check_span]
            a2 = conversation_length_list[epi_count -
                                          2 * check_span:epi_count -
                                          1 * check_span]
            a3 = conversation_length_list[epi_count - 1 * check_span:]
            a1 = np.mean(np.array(a1))
            a2 = np.mean(np.array(a2))
            a3 = np.mean(np.array(a3))

            with open(write_fp, 'a') as f:
                f.write('$$$current turn: {}, a3: {}, a2: {}, a1: {}\n'.format(
                    epi_count, a3, a2, a1))
            print('current turn: {}, a3: {}, a2: {}, a1: {}'.format(
                epi_count, a3, a2, a1))

            num_interval = int(epi_count / check_span)
            for i in range(num_interval):
                ave = np.mean(
                    np.array(conversation_length_list[i * check_span:(i + 1) *
                                                      check_span]))
                print('start: {}, end: {}, average: {}'.format(
                    i * check_span, (i + 1) * check_span, ave))
                PATH = '../data/PN-model-{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}-epi-{}.txt'.format(
                    A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma,
                    A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval,
                    A.initeval, A.mini, A.alwaysupdate, A.upcount, A.upreg,
                    A.mask, (i + 1) * check_span)
                print('Model saved at: {}'.format(PATH))

            if a3 > a1 and a3 > a2:
                print('Early stop of RL!')
                exit()

        # write control information
        if A.purpose != 'pretrain':
            with open(write_fp, 'a') as f:
                f.write('Big features are: {}\n'.format(choose_pool))
                if rewards is not None:
                    f.write('reward is: {}\n'.format(
                        rewards.data.numpy().tolist()))
                f.write(
                    'WHOLE PROCESS TAKES: {} SECONDS\n'.format(time.time() -
                                                               start))
        # end write

        # Write to pretrain numpy which is the pretrain data.
        if A.purpose == 'pretrain':
            if len(numpy_list) > 5000:
                with open(
                        '../data/pretrain-numpy-data-{}/segment-{}-start-{}-end-{}.pk'
                        .format(A.mod, NUMPY_COUNT, A.startFrom,
                                A.endAt), 'wb') as f:
                    pickle.dump(numpy_list, f)
                    print('Have written 5000 numpy arrays!')
                NUMPY_COUNT += 1
                numpy_list = list()
        # end write

    print('\nLocation result:')
    sum_sr = 0.0
    for i in range(len(success_at_turn_list_location_rate)):
        success_rate = success_at_turn_list_location_rate[i] / A.endAt
        sum_sr += success_rate
        print('success rate is {} at turn {}, accumulated sum is {}'.format(
            success_rate, i + 1, sum_sr))
    print('Average turn is: ', total_turn / A.endAt + 1)
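The episode loop above calls update_PN_model, which is defined elsewhere in the repository. A minimal sketch, assuming a plain REINFORCE update over the per-step log-probabilities and the (already discounted) reward tensor returned by run_one_episode; the actual implementation may normalize returns or clip gradients differently.

import torch


def update_PN_model(PN_model, log_prob_list, rewards, optimizer):
    """Hypothetical REINFORCE-style policy-gradient step; the repository's version may differ."""
    if log_prob_list is None or len(log_prob_list) == 0:
        return
    optimizer.zero_grad()
    log_probs = torch.stack(list(log_prob_list))  # one log-probability per turn
    returns = rewards.detach()                    # assumed: tensor of per-turn returns
    loss = -(log_probs * returns).sum()           # maximize expected return
    loss.backward()
    optimizer.step()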
Example #3
def main(epoch):
    parser = argparse.ArgumentParser(
        description="Run conversational recommendation.")
    parser.add_argument('-mt', type=int, dest='mt', help='MAX_TURN', default=5)
    parser.add_argument('-playby',
                        type=str,
                        dest='playby',
                        help='playby',
                        default='policy')
    # options include:
    # AO: (Ask Only and recommend by probability)
    # RO: (Recommend Only)
    # policy: (action decided by our policy network)
    parser.add_argument('-fmCommand',
                        type=str,
                        dest='fmCommand',
                        help='fmCommand',
                        default='8')
    # the command used for FM, check out /EAR/lastfm/FM/
    parser.add_argument('-optim',
                        type=str,
                        dest='optim',
                        help='optimizer',
                        default='SGD')
    # the optimizer for policy network
    parser.add_argument('-lr', type=float, dest='lr', help='lr', default=0.001)
    # learning rate of policy network
    parser.add_argument('-decay',
                        type=float,
                        dest='decay',
                        help='decay',
                        default=0)
    # weight decay
    parser.add_argument('-TopKTaxo',
                        type=int,
                        dest='TopKTaxo',
                        help='TopKTaxo',
                        default=3)
    # how many layer-2 features represent a big feature. Only the Yelp dataset uses this param; it has no effect on lastFM.
    parser.add_argument('-gamma',
                        type=float,
                        dest='gamma',
                        help='gamma',
                        default=0.7)
    # gamma of training policy network
    parser.add_argument('-trick',
                        type=int,
                        dest='trick',
                        help='trick',
                        default=0)
    # whether use normalization in training policy network
    parser.add_argument('-startFrom',
                        type=int,
                        dest='startFrom',
                        help='startFrom',
                        default=0)  # 85817
    # startFrom which user-item interaction pair
    parser.add_argument('-endAt',
                        type=int,
                        dest='endAt',
                        help='endAt',
                        default=20000)
    # endAt which user-item interaction pair
    parser.add_argument('-strategy',
                        type=str,
                        dest='strategy',
                        help='strategy',
                        default='maxent')
    # strategy to choose question to ask, only have effect
    parser.add_argument('-eval', type=int, dest='eval', help='eval', default=0)
    # whether current run is for evaluation
    parser.add_argument('-mini', type=int, dest='mini', help='mini', default=0)
    # means `mini`-batch update the FM
    parser.add_argument('-alwaysupdate',
                        type=int,
                        dest='alwaysupdate',
                        help='alwaysupdate',
                        default=0)
    # whether to always mini-batch update the FM; the alternative is to update only once per session.
    # We leave this exploration to followers of our work.
    parser.add_argument('-initeval',
                        type=int,
                        dest='initeval',
                        help='initeval',
                        default=0)
    # whether to evaluate the `init`ial version of the policy network (directly after pre-training)
    parser.add_argument('-upoptim',
                        type=str,
                        dest='upoptim',
                        help='upoptim',
                        default='SGD')
    # optimizer for the reflection stage
    parser.add_argument('-upcount',
                        type=int,
                        dest='upcount',
                        help='upcount',
                        default=0)
    # how many times to do reflection
    parser.add_argument('-upreg',
                        type=float,
                        dest='upreg',
                        help='upreg',
                        default=0.001)
    # regularization term in
    parser.add_argument('-code',
                        type=str,
                        dest='code',
                        help='code',
                        default='stable')
    # We use it to give each run a unique identifier.
    parser.add_argument('-purpose',
                        type=str,
                        dest='purpose',
                        help='purpose',
                        default='train')
    # options: pretrain, others
    parser.add_argument('-mod',
                        type=str,
                        dest='mod',
                        help='mod',
                        default='ear')
    # options: CRM, EAR
    parser.add_argument('-mask', type=int, dest='mask', help='mask', default=0)
    # used for the ablation study; 1, 2, 3, 4 represent our four segments: {ent, sim, his, len}

    A = parser.parse_args()

    cfg.change_param(playby=A.playby,
                     eval=A.eval,
                     update_count=A.upcount,
                     update_reg=A.upreg,
                     purpose=A.purpose,
                     mod=A.mod,
                     mask=A.mask)

    random.seed(1)

    # We randomly shuffle and split the combined valid and test set, for Action Stage training and evaluation respectively, to avoid bias in the dataset.
    all_list = cfg.valid_list + cfg.test_list
    print('The length of all list is: {}'.format(len(all_list)))
    random.shuffle(all_list)
    the_valid_list = all_list[:int(len(all_list) / 2.0)]
    the_test_list = all_list[int(len(all_list) / 2.0):]
    # the_valid_list = cfg.valid_list
    # the_test_list = cfg.test_list

    gamma = A.gamma
    FM_model = cfg.FM_model

    if A.mod == 'ear':
        # fp = '../../data/PN-model-crm/PN-model-crm.txt'
        if epoch == 0 and cfg.eval == 0:
            fp = '../../data/PN-model-ear/pretrain-model.pt'
        else:
            fp = '../../data/PN-model-ear/model-epoch0'
    INPUT_DIM = 0
    if A.mod == 'ear':
        INPUT_DIM = len(cfg.tag_map) * 2 + cfg.MAX_TURN + 8
    if A.mod == 'crm':
        INPUT_DIM = 4382
    PN_model = PolicyNetwork(input_dim=INPUT_DIM,
                             dim1=1500,
                             output_dim=len(cfg.tag_map) + 1)
    start = time.time()

    try:
        print('fp is: {}'.format(fp))
        PN_model.load_state_dict(torch.load(fp))
        print('load PN model success. ')
    except:
        print('Cannot load the model!!!!!!!!!\nfp is: {}'.format(fp))
        # if A.playby == 'policy':
        #     sys.exit()

    if A.optim == 'Adam':
        optimizer = torch.optim.Adam(PN_model.parameters(),
                                     lr=A.lr,
                                     weight_decay=A.decay)
    if A.optim == 'SGD':
        optimizer = torch.optim.SGD(PN_model.parameters(),
                                    lr=A.lr,
                                    weight_decay=A.decay)
    if A.optim == 'RMS':
        optimizer = torch.optim.RMSprop(PN_model.parameters(),
                                        lr=A.lr,
                                        weight_decay=A.decay)

    numpy_list = list()
    NUMPY_COUNT = 0

    sample_dict = defaultdict(list)
    conversation_length_list = list()
    # endAt = len(the_valid_list) if cfg.eval == 0 else len(the_test_list)
    endAt = len(the_valid_list)
    print(f'endAt: {endAt}')
    print('-' * 10)
    print('Train mode' if cfg.eval == 0 else 'Test mode')
    print('-' * 10)
    for epi_count in range(A.startFrom, endAt):
        if epi_count % 1 == 0:
            print('-----\nEpoch: {}\tIt has processed {}/{} episodes'.format(
                epoch, epi_count, endAt))
        start = time.time()
        cfg.actionProb = epi_count / endAt

        # if A.test == 1 or A.eval == 1:
        if A.eval == 1:
            # u, item = the_test_list[epi_count]
            u, item = the_valid_list[epi_count]
        else:
            u, item = the_valid_list[epi_count]

        if A.purpose == 'fmdata':
            u, item = 0, epi_count

        if A.purpose == 'pretrain':
            u, item = cfg.train_list[epi_count]

        current_FM_model = copy.deepcopy(FM_model)
        param1, param2 = list(), list()
        param3 = list()
        param4 = list()
        i = 0
        for name, param in current_FM_model.named_parameters():
            param4.append(param)
            # print(name, param)
            if i == 0:
                param1.append(param)
            else:
                param2.append(param)
            if i == 2:
                param3.append(param)
            i += 1
        optimizer1_fm = torch.optim.Adagrad(param1,
                                            lr=0.01,
                                            weight_decay=A.decay)
        optimizer2_fm = torch.optim.SGD(param4, lr=0.001, weight_decay=A.decay)

        user_id = int(u)
        item_id = int(item)

        write_fp = '../../data/interaction-log/{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
            A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma,
            A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval, A.initeval,
            A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask)

        choose_pool = cfg.item_dict[str(item_id)]['categories']

        if A.purpose not in ['pretrain', 'fmdata']:
            # this means that: we are not collecting data for pretraining or fm data
            # then we only randomly choose one start attribute to ask!
            choose_pool = [random.choice(choose_pool)]
        # for item_id in cfg.item_dict:
        #     choose_pool = [k for k in cfg.item_dict_rel[str(item_id)] if len(cfg.item_dict_rel[str(item_id)][k]) != 0]
        #     if choose_pool == None:
        #         print(item_id)
        print(f'user id: {user_id}\titem id: {item_id}')
        # choose_pool = [k for k in cfg.item_dict_rel[str(item_id)] if len(cfg.item_dict_rel[str(item_id)][k]) != 0]
        # choose_pool = random.choice(choose_pool)
        for c in choose_pool:
            with open(write_fp, 'a+') as f:
                f.write(
                    'Starting new\nuser ID: {}, item ID: {} episode count: {}, feature: {}\n'
                    .format(user_id, item_id, epi_count,
                            cfg.item_dict[str(item_id)]['categories']))
            start_facet = c
            if A.purpose != 'pretrain':
                log_prob_list, rewards, hl = run_one_episode(
                    current_FM_model, user_id, item_id, A.mt, False, write_fp,
                    A.strategy, A.TopKTaxo, PN_model, gamma, A.trick, A.mini,
                    optimizer1_fm, optimizer2_fm, A.alwaysupdate, start_facet,
                    A.mask, sample_dict)
                if cfg.eval == 0:
                    with open(f'../../data/train_rec{epoch}.txt', 'a+') as f:
                        # f.writelines(str(rewards.tolist()))
                        f.writelines(str(hl))
                        f.writelines('\n')
                else:
                    with open('../../data/test_rec.txt', 'a+') as f:
                        # f.writelines(str(rewards.tolist()))
                        f.writelines(str(hl))
                        f.writelines('\n')
            else:
                current_np = run_one_episode(
                    current_FM_model, user_id, item_id, A.mt, False, write_fp,
                    A.strategy, A.TopKTaxo, PN_model, gamma, A.trick, A.mini,
                    optimizer1_fm, optimizer2_fm, A.alwaysupdate, start_facet,
                    A.mask, sample_dict)
                numpy_list += current_np

            # update PN model
            if A.playby == 'policy' and A.eval != 1:
                update_PN_model(PN_model, log_prob_list, rewards, optimizer)
                print('updated PN model')
                current_length = len(log_prob_list)
                conversation_length_list.append(current_length)
            # end update

            if A.purpose != 'pretrain':
                with open(write_fp, 'a') as f:
                    f.write('Big features are: {}\n'.format(choose_pool))
                    if rewards is not None:
                        f.write('reward is: {}\n'.format(
                            rewards.data.numpy().tolist()))
                    f.write('WHOLE PROCESS TAKES: {} SECONDS\n'.format(
                        time.time() - start))

        # Write to pretrain numpy.
        if A.purpose == 'pretrain':
            if len(numpy_list) > 5000:
                with open(
                        '../../data/pretrain-numpy-data-{}/segment-{}-start-{}-end-{}.pk'
                        .format(A.mod, NUMPY_COUNT, A.startFrom,
                                A.endAt), 'wb') as f:
                    pickle.dump(numpy_list, f)
                    print('Have written 5000 numpy arrays!')
                NUMPY_COUNT += 1
                numpy_list = list()
        # numpy_list is a list of list.
        # e.g. numpy_list[0][0]: int, indicating the action.
        # numpy_list[0][1]: a one-d array of length 89 for EAR, and 33 for CRM.
        # end write

        # Write sample dict:
        if A.purpose == 'fmdata' and A.playby != 'AOO_valid':
            if epi_count % 100 == 1:
                with open(
                        '../../data/sample-dict/start-{}-end-{}.json'.format(
                            A.startFrom, A.endAt), 'w') as f:
                    json.dump(sample_dict, f, indent=4)
        # end write
        if A.purpose == 'fmdata' and A.playby == 'AOO_valid':
            if epi_count % 100 == 1:
                with open(
                        '../../data/sample-dict/valid-start-{}-end-{}.json'.
                        format(A.startFrom, A.endAt), 'w') as f:
                    json.dump(sample_dict, f, indent=4)

        check_span = 500
        if epi_count % check_span == 0 and epi_count >= 3 * check_span and cfg.eval != 1 and A.purpose != 'pretrain':
            # We use AT (average turn of conversation) as our stopping criterion
            # in training mode, save RL model periodically
            # save model first
            # PATH = '../../data/PN-model-{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
            #     A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma, A.playby, A.strategy, A.TopKTaxo, A.trick,
            #     A.eval, A.initeval,
            #     A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask, epi_count)
            PATH = f'../../data/PN-model-{A.mod}/model-epoch{epoch}'
            torch.save(PN_model.state_dict(), PATH)
            print('Model saved at {}'.format(PATH))

            # a0 = conversation_length_list[epi_count - 4 * check_span: epi_count - 3 * check_span]
            a1 = conversation_length_list[epi_count -
                                          3 * check_span:epi_count -
                                          2 * check_span]
            a2 = conversation_length_list[epi_count -
                                          2 * check_span:epi_count -
                                          1 * check_span]
            a3 = conversation_length_list[epi_count - 1 * check_span:]
            a1 = np.mean(np.array(a1))
            a2 = np.mean(np.array(a2))
            a3 = np.mean(np.array(a3))

            with open(write_fp, 'a') as f:
                f.write('$$$current turn: {}, a3: {}, a2: {}, a1: {}\n'.format(
                    epi_count, a3, a2, a1))
            print('current turn: {}, a3: {}, a2: {}, a1: {}'.format(
                epi_count, a3, a2, a1))

            num_interval = int(epi_count / check_span)
            for i in range(num_interval):
                ave = np.mean(
                    np.array(conversation_length_list[i * check_span:(i + 1) *
                                                      check_span]))
                print('start: {}, end: {}, average: {}'.format(
                    i * check_span, (i + 1) * check_span, ave))
                # PATH = '../../data/PN-model-{}/v4-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
                #     A.mod.lower(), A.code, A.startFrom, A.endAt, A.lr, A.gamma, A.playby, A.strategy, A.TopKTaxo,
                #     A.trick,
                #     A.eval, A.initeval,
                #     A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask, (i + 1) * check_span)
                PATH = f'../../data/PN-model-{A.mod}/model-epoch{epoch}'
                print('Model saved at: {}'.format(PATH))

            if a3 > a1 and a3 > a2:
                print('Early stop of RL!')
                if cfg.eval == 1:
                    exit()
                else:
                    return
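The pretrain branch above pickles numpy_list segments whose structure is described in the comments (numpy_list[i][0] is the action index, numpy_list[i][1] the one-dimensional state feature array). A minimal loader sketch for turning those segments into tensors for supervised pre-training of the policy network; the glob pattern and the downstream training loop are assumptions, not part of the original code.

import glob
import pickle

import numpy as np
import torch


def load_pretrain_segments(pattern='../../data/pretrain-numpy-data-ear/segment-*.pk'):
    """Hypothetical loader for the pickled pretrain segments written by the loop above."""
    states, actions = [], []
    for path in sorted(glob.glob(pattern)):
        with open(path, 'rb') as f:
            numpy_list = pickle.load(f)
        for record in numpy_list:
            actions.append(int(record[0]))        # record[0]: action index
            states.append(np.asarray(record[1]))  # record[1]: 1-d state feature array
    X = torch.tensor(np.stack(states), dtype=torch.float)
    y = torch.tensor(actions, dtype=torch.long)
    return X, y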
Example #4
def main():
    parser = argparse.ArgumentParser(
        description="Run conversational recommendation.")
    parser.add_argument('-mt', type=int, dest='mt', help='MAX_TURN')
    parser.add_argument('-playby', type=str, dest='playby', help='playby')
    # options include:
    # AO: (Ask Only and recommend by probability)
    # RO: (Recommend Only)
    # policy: (action decided by our policy network)
    # sac: (action decided by our SAC network)
    parser.add_argument('-fmCommand',
                        type=str,
                        dest='fmCommand',
                        help='fmCommand')
    # the command used for FM, check out /EAR/lastfm/FM/
    parser.add_argument('-optim', type=str, dest='optim', help='optimizer')
    # the optimizer for policy network
    parser.add_argument('-actor_lr',
                        type=float,
                        dest='actor_lr',
                        help='actor learning rate')
    # learning rate of Actor network
    parser.add_argument('-critic_lr',
                        type=float,
                        dest='critic_lr',
                        help='critic learning rate')
    # learning rate of the Critic network
    parser.add_argument('-actor_decay',
                        type=float,
                        dest='actor_decay',
                        help='actor weight decay')
    parser.add_argument('-decay',
                        type=float,
                        dest='decay',
                        help="weight decay for FM model")
    # weight decay
    parser.add_argument('-critic_decay',
                        type=float,
                        dest='critic_decay',
                        help='critic weight decay')
    parser.add_argument('-TopKTaxo',
                        type=int,
                        dest='TopKTaxo',
                        help='TopKTaxo')
    # how many layer-2 features represent a big feature. Only the Yelp dataset uses this param; it has no effect on lastFM.
    parser.add_argument('-gamma', type=float, dest='gamma', help='gamma')
    # gamma of training policy network
    parser.add_argument('-trick', type=int, dest='trick', help='trick')
    # whether use normalization in training policy network
    parser.add_argument('-startFrom',
                        type=int,
                        dest='startFrom',
                        help='startFrom')
    # startFrom which user-item interaction pair
    parser.add_argument('-endAt', type=int, dest='endAt', help='endAt')
    # endAt which user-item interaction pair
    parser.add_argument('-strategy',
                        type=str,
                        dest='strategy',
                        help='strategy')
    # strategy to choose question to ask, only have effect
    parser.add_argument('-eval', type=int, dest='eval', help='eval')
    # whether current run is for evaluation
    parser.add_argument('-mini', type=int, dest='mini', help='mini')
    # means `mini`-batch update the FM
    parser.add_argument('-alwaysupdate',
                        type=int,
                        dest='alwaysupdate',
                        help='alwaysupdate')
    # whether to always mini-batch update the FM; the alternative is to update only once per session.
    # We leave this exploration to followers of our work.
    parser.add_argument('-initeval',
                        type=int,
                        dest='initeval',
                        help='initeval')
    # whether to evaluate the `init`ial version of the policy network (directly after pre-training)
    parser.add_argument('-upoptim', type=str, dest='upoptim', help='upoptim')
    # optimizer for the reflection stage
    parser.add_argument('-upcount', type=int, dest='upcount', help='upcount')
    # how many times to do reflection
    parser.add_argument('-upreg', type=float, dest='upreg', help='upreg')
    # regularization term in
    parser.add_argument('-code', type=str, dest='code', help='code')
    # We use it to give each run a unique identifier.
    parser.add_argument('-purpose', type=str, dest='purpose', help='purpose')
    # options: pretrain, others
    parser.add_argument('-mod', type=str, dest='mod', help='mod')
    # options: CRM, EAR
    parser.add_argument('-mask', type=int, dest='mask', help='mask')
    # used for the ablation study; 1, 2, 3, 4 represent our four segments: {ent, sim, his, len}
    parser.add_argument('-use_sac', type=bool, dest='use_sac', help='use_sac')
    # true if the RL module uses SAC; note that argparse's type=bool turns any
    # non-empty string into True (see the str2bool sketch after this example)

    A = parser.parse_args()

    cfg.change_param(playby=A.playby,
                     eval=A.eval,
                     update_count=A.upcount,
                     update_reg=A.upreg,
                     purpose=A.purpose,
                     mod=A.mod,
                     mask=A.mask)

    random.seed(1)

    # We randomly shuffle and split the combined valid and test set, for Action Stage training and evaluation respectively, to avoid bias in the dataset.
    all_list = cfg.valid_list + cfg.test_list
    print('The length of all list is: {}'.format(len(all_list)))
    random.shuffle(all_list)
    the_valid_list = all_list[:int(len(all_list) / 2.0)]
    the_test_list = all_list[int(len(all_list) / 2.0):]

    gamma = A.gamma
    FM_model = cfg.FM_model

    if A.eval == 1:
        if A.mod == 'ear':
            fp = '../../data/PN-model-ear/PN-model-ear.txt'
        if A.mod == 'crm':
            fp = '../../data/PN-model-crm/PN-model-crm.txt'
        if A.initeval == 1:
            if A.mod == 'ear':
                fp = '../../data/PN-model-ear/pretrain-model.pt'
            if A.mod == 'crm':
                fp = '../../data/PN-model-crm/pretrain-model.pt'
    else:
        # means training
        if A.mod == 'ear':
            fp = '../../data/PN-model-ear/pretrain-model.pt'
        if A.mod == 'crm':
            fp = '../../data/PN-model-crm/pretrain-model.pt'
    INPUT_DIM = 0
    if A.mod == 'ear':
        INPUT_DIM = 81
    if A.mod == 'crm':
        INPUT_DIM = 590
    print('fp is: {}'.format(fp))

    # Initialize the policy network as either PolicyNetwork or SAC_Net
    if not A.use_sac:
        PN_model = PolicyNetwork(input_dim=INPUT_DIM, dim1=64, output_dim=34)
        start = time.time()

        try:
            PN_model.load_state_dict(torch.load(fp))
            print('Now Load PN pretrain from {}, takes {} seconds.'.format(
                fp,
                time.time() - start))
        except:
            print('Cannot load the model!!!!!!!!!\n fp is: {}'.format(fp))
            if A.playby == 'policy':
                sys.exit()

        # This script defines no -lr flag, so the actor learning rate is reused
        # for the policy-network optimizer.
        if A.optim == 'Adam':
            optimizer = torch.optim.Adam(PN_model.parameters(),
                                         lr=A.actor_lr,
                                         weight_decay=A.decay)
        if A.optim == 'SGD':
            optimizer = torch.optim.SGD(PN_model.parameters(),
                                        lr=A.actor_lr,
                                        weight_decay=A.decay)
        if A.optim == 'RMS':
            optimizer = torch.optim.RMSprop(PN_model.parameters(),
                                            lr=A.actor_lr,
                                            weight_decay=A.decay)

    else:
        PN_model = SAC_Net(input_dim=INPUT_DIM,
                           dim1=64,
                           output_dim=34,
                           actor_lr=A.actor_lr,
                           critic_lr=A.critic_lr,
                           discount_rate=gamma,
                           actor_w_decay=A.actor_decay,
                           critic_w_decay=A.critic_decay)

    numpy_list = list()
    rewards_list = list()
    NUMPY_COUNT = 0

    sample_dict = defaultdict(list)
    conversation_length_list = list()
    for epi_count in range(A.startFrom, A.endAt):
        if epi_count % 1 == 0:
            print('-----\nIt has processed {} episodes'.format(epi_count))
        start = time.time()

        u, item, r = the_valid_list[epi_count]

        # if A.test == 1 or A.eval == 1:
        if A.eval == 1:
            u, item, r = the_test_list[epi_count]

        if A.purpose == 'fmdata':
            u, item = 0, epi_count

        if A.purpose == 'pretrain':
            u, item, r = cfg.train_list[epi_count]

        current_FM_model = copy.deepcopy(FM_model)
        param1, param2 = list(), list()
        i = 0
        for name, param in current_FM_model.named_parameters():
            # print(name, param)
            if i == 0:
                param1.append(param)
            else:
                param2.append(param)
            i += 1

        optimizer1_fm, optimizer2_fm = None, None
        if A.purpose != 'fmdata':
            optimizer1_fm = torch.optim.Adagrad(param1,
                                                lr=0.01,
                                                weight_decay=A.decay)
            if A.upoptim == 'Ada':
                optimizer2_fm = torch.optim.Adagrad(param2,
                                                    lr=0.01,
                                                    weight_decay=A.decay)
            if A.upoptim == 'SGD':
                optimizer2_fm = torch.optim.SGD(param2,
                                                lr=0.001,
                                                weight_decay=A.decay)

        user_id = int(u)
        item_id = int(item)

        big_feature_list = list()

        for k, v in cfg.taxo_dict.items():
            if len(
                    set(v).intersection(
                        set(cfg.item_dict[str(item_id)]['categories']))) > 0:
                big_feature_list.append(k)

        write_fp = '../../data/interaction-log/{}/v5-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}.txt'.format(
            A.mod.lower(), A.code, A.startFrom, A.endAt, A.actor_lr, A.gamma,
            A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval, A.initeval,
            A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask)

        if cfg.item_dict[str(item_id)]['RestaurantsPriceRange2'] is not None:
            choose_pool = ['stars', 'city', 'RestaurantsPriceRange2'
                           ] + big_feature_list
        else:
            choose_pool = ['stars', 'city'] + big_feature_list

        choose_pool_original = choose_pool

        if A.purpose not in ['pretrain', 'fmdata']:
            # this means that: we are not collecting data for pretraining or fm data
            # then we only randomly choose one start attribute to ask!
            choose_pool = [random.choice(choose_pool)]

        for c in choose_pool:
            with open(write_fp, 'a') as f:
                f.write(
                    'Starting new\nuser ID: {}, item ID: {} episode count: {}, feature: {}\n'
                    .format(user_id, item_id, epi_count,
                            cfg.item_dict[str(item_id)]['categories']))
            start_facet = c
            if A.purpose != 'pretrain' and A.playby != 'sac':
                log_prob_list, rewards = run_one_episode(
                    current_FM_model, user_id, item_id, A.mt, False, write_fp,
                    A.strategy, A.TopKTaxo, PN_model, gamma, A.trick, A.mini,
                    optimizer1_fm, optimizer2_fm, A.alwaysupdate, start_facet,
                    A.mask, sample_dict, choose_pool_original)
            else:
                if A.playby != 'sac':
                    current_np = run_one_episode(
                        current_FM_model, user_id, item_id, A.mt, False,
                        write_fp, A.strategy, A.TopKTaxo, PN_model, gamma,
                        A.trick, A.mini, optimizer1_fm, optimizer2_fm,
                        A.alwaysupdate, start_facet, A.mask, sample_dict,
                        choose_pool_original)
                    numpy_list += current_np

                else:
                    current_np, current_reward = run_one_episode(
                        current_FM_model, user_id, item_id, A.mt, False,
                        write_fp, A.strategy, A.TopKTaxo, PN_model, gamma,
                        A.trick, A.mini, optimizer1_fm, optimizer2_fm,
                        A.alwaysupdate, start_facet, A.mask, sample_dict,
                        choose_pool_original)
                    rewards_list += current_reward
                    numpy_list += current_np

            # update PN model
            if A.playby == 'policy' and A.eval != 1:
                update_PN_model(PN_model, log_prob_list, rewards, optimizer)
                print('updated PN model')
                current_length = len(log_prob_list)
                conversation_length_list.append(current_length)
            # end update
            if A.playby == 'sac':
                # only the SAC branch returns per-step rewards
                rewards = current_reward
            # Update SAC
            if A.purpose != 'pretrain':
                with open(write_fp, 'a') as f:
                    f.write('Big features are: {}\n'.format(choose_pool))
                    if rewards is not None:
                        f.write('reward is: {}\n'.format(
                            rewards.data.numpy().tolist()))
                    f.write('WHOLE PROCESS TAKES: {} SECONDS\n'.format(
                        time.time() - start))

        # Write to pretrain numpy.
        if A.purpose == 'pretrain':
            if A.playby != 'sac':
                if len(numpy_list) > 5000:
                    with open(
                            '../../data/pretrain-numpy-data-{}/segment-{}-start-{}-end-{}.pk'
                            .format(A.mod, NUMPY_COUNT, A.startFrom,
                                    A.endAt), 'wb') as f:
                        pickle.dump(numpy_list, f)
                        print('Have written 5000 numpy arrays!')
                    NUMPY_COUNT += 1
                    numpy_list = list()
            else:
                # In SAC mode, collect both numpy_list and rewards_list as training data
                if len(numpy_list) > 5000 or len(rewards_list) > 5000:
                    assert len(rewards_list) == len(
                        numpy_list
                    ), "rewards and state-action pairs have different size!"
                    directory = '../../data/pretrain-sac-numpy-data-{}/segment-{}-start-{}-end-{}.pk'.format(
                        A.mod, NUMPY_COUNT, A.startFrom, A.endAt)
                    rewards_directory = '../../data/pretrain-sac-reward-data-{}/segment-{}-start-{}-end-{}.pk'.format(
                        A.mod, NUMPY_COUNT, A.startFrom, A.endAt)
                    with open(directory, 'wb') as f:
                        pickle.dump(numpy_list, f)
                        print('Have written 5000 numpy arrays for SAC!')
                    with open(rewards_directory, 'wb') as f:
                        pickle.dump(rewards_list, f)
                        print('Have written 5000 rewards for SAC!')
                    NUMPY_COUNT += 1
                    numpy_list = list()
                    rewards_list = list()

        # numpy_list is a list of list.
        # e.g. numpy_list[0][0]: int, indicating the action.
        # numpy_list[0][1]: a one-d array of length 89 for EAR, and 33 for CRM.
        # end write

        # Write sample dict:
        if A.purpose == 'fmdata' and A.playby != 'AOO_valid':
            if epi_count % 100 == 1:
                with open(
                        '../../data/sample-dict/start-{}-end-{}.json'.format(
                            A.startFrom, A.endAt), 'w') as f:
                    json.dump(sample_dict, f, indent=4)
        # end write
        if A.purpose == 'fmdata' and A.playby == 'AOO_valid':
            if epi_count % 100 == 1:
                with open(
                        '../../data/sample-dict/valid-start-{}-end-{}.json'.
                        format(A.startFrom, A.endAt), 'w') as f:
                    json.dump(sample_dict, f, indent=4)

        check_span = 500
        if epi_count % check_span == 0 and epi_count >= 3 * check_span and cfg.eval != 1 and A.purpose != 'pretrain':
            # We use AT (average turn of conversation) as our stopping criterion
            # in training mode, save RL model periodically
            # save model first
            PATH = '../../data/PN-model-{}/v5-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}-epi-{}.txt'.format(
                A.mod.lower(), A.code, A.startFrom, A.endAt, A.actor_lr,
                A.gamma, A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval,
                A.initeval, A.mini, A.alwaysupdate, A.upcount, A.upreg, A.mask,
                epi_count)
            torch.save(PN_model.state_dict(), PATH)
            print('Model saved at {}'.format(PATH))

            # a0 = conversation_length_list[epi_count - 4 * check_span: epi_count - 3 * check_span]
            a1 = conversation_length_list[epi_count -
                                          3 * check_span:epi_count -
                                          2 * check_span]
            a2 = conversation_length_list[epi_count -
                                          2 * check_span:epi_count -
                                          1 * check_span]
            a3 = conversation_length_list[epi_count - 1 * check_span:]
            a1 = np.mean(np.array(a1))
            a2 = np.mean(np.array(a2))
            a3 = np.mean(np.array(a3))

            with open(write_fp, 'a') as f:
                f.write('$$$current turn: {}, a3: {}, a2: {}, a1: {}\n'.format(
                    epi_count, a3, a2, a1))
            print('current turn: {}, a3: {}, a2: {}, a1: {}'.format(
                epi_count, a3, a2, a1))

            num_interval = int(epi_count / check_span)
            for i in range(num_interval):
                ave = np.mean(
                    np.array(conversation_length_list[i * check_span:(i + 1) *
                                                      check_span]))
                print('start: {}, end: {}, average: {}'.format(
                    i * check_span, (i + 1) * check_span, ave))
                PATH = '../../data/PN-model-{}/v5-code-{}-s-{}-e-{}-lr-{}-gamma-{}-playby-{}-stra-{}-topK-{}-trick-{}-eval-{}-init-{}-mini-{}-always-{}-upcount-{}-upreg-{}-m-{}-epi-{}.txt'.format(
                    A.mod.lower(), A.code, A.startFrom, A.endAt, A.actor_lr,
                    A.gamma, A.playby, A.strategy, A.TopKTaxo, A.trick, A.eval,
                    A.initeval, A.mini, A.alwaysupdate, A.upcount, A.upreg,
                    A.mask, (i + 1) * check_span)
                print('Model saved at: {}'.format(PATH))

            if a3 > a1 and a3 > a2:
                print('Early stop of RL!')
                exit()
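The -use_sac flag above is declared with type=bool, but argparse applies bool() to the raw command-line string, so any non-empty value (including '0' or 'False') parses as True. A small converter like the following is a common workaround; wiring it in is shown only as a hypothetical usage, not a change to the original interface.

import argparse


def str2bool(value):
    """Parse common textual booleans for argparse flags such as -use_sac."""
    if isinstance(value, bool):
        return value
    if value.lower() in ('yes', 'true', 't', '1'):
        return True
    if value.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


# Hypothetical usage: parser.add_argument('-use_sac', type=str2bool, dest='use_sac', help='use_sac')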