config.saved_path = saved_path

# save configuration
with open(os.path.join(saved_path, 'config.json'), 'w') as f:
    json.dump(config, f, indent=4)  # sort_keys=True
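
# The saved JSON round-trips back into a config object later; a minimal
# sketch, assuming LaRL's Pack (a dict with attribute access), mirroring
# what the RL example below does with its sv_config:
#   with open(os.path.join(saved_path, 'config.json')) as f:
#       reloaded = Pack(json.load(f))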

corpus = corpora.NormMultiWozCorpus(config)
train_dial, val_dial, test_dial = corpus.get_corpus()

train_data = BeliefDbDataLoaders('Train', train_dial, config)
val_data = BeliefDbDataLoaders('Val', val_dial, config)
test_data = BeliefDbDataLoaders('Test', test_dial, config)
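
# A typical consumption loop, assuming the epoch_init/next_batch interface of
# LaRL's data loaders (check latent_dialog/data_loaders.py for the exact
# signatures):
#   train_data.epoch_init(config, shuffle=True)
#   while True:
#       batch = train_data.next_batch()
#       if batch is None:  # the loader signals the end of an epoch with None
#           break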

evaluator = MultiWozEvaluator('SysWoz')

model = SysPerfectBD2Cat(corpus, config)

# a single branch covers both device placement and checkpoint loading
if config.use_gpu:
    model.cuda()
    model.load_state_dict(torch.load(
        '/root/NeuralDialog-LaRL/larl_model/best-model'))
else:
    model.load_state_dict(torch.load(
        '/root/NeuralDialog-LaRL/larl_model/best-model',
        map_location=lambda storage, loc: storage))
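
# The map_location lambda remaps tensors saved on a GPU back onto the CPU, so
# the checkpoint also loads on machines without CUDA;
# torch.load(path, map_location='cpu') is an equivalent modern spelling.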

# (A data_feed example followed here in the source: a dict whose 'bs' entry is
# a binary belief-state vector. The array is truncated in the original, so only
# this placeholder remains. The next snippet is the __init__ of the LaRL policy
# class; its enclosing class statement is not part of the extract.)

    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
                 model_file=None):
        SysPolicy.__init__(self)

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for LaRL is specified!")
            archive_file = cached_path(model_file)

        temp_path = tempfile.mkdtemp()
        with zipfile.ZipFile(archive_file, 'r') as zip_ref:
            zip_ref.extractall(temp_path)
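
        # Note: tempfile.mkdtemp() never removes the directory on its own; once
        # the model and vocabularies below are loaded, shutil.rmtree(temp_path)
        # could reclaim the space (left out here to match the original behavior).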

        self.prev_state = init_state()
        self.prev_active_domain = None

        domain_name = 'object_division'
        domain_info = domain.get_domain(domain_name)

        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
        train_data_path = os.path.join(data_path, 'norm-multi-woz', 'train_dials.json')
        if not os.path.exists(train_data_path):
            zipped_file = os.path.join(data_path, 'norm-multi-woz.zip')
            with zipfile.ZipFile(zipped_file, 'r') as archive:
                archive.extractall(data_path)

        norm_multiwoz_path = os.path.join(data_path, 'norm-multi-woz')
        with open(os.path.join(norm_multiwoz_path, 'input_lang.index2word.json')) as f:
            self.input_lang_index2word = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'input_lang.word2index.json')) as f:
            self.input_lang_word2index = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'output_lang.index2word.json')) as f:
            self.output_lang_index2word = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'output_lang.word2index.json')) as f:
            self.output_lang_word2index = json.load(f)
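
        # json stores dict keys as strings, so the index2word tables are keyed
        # by string indices, e.g. self.output_lang_index2word['5'] (hypothetical
        # index); the word2index direction maps tokens to plain ints.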

        config = Pack(
            seed=10,
            train_path=train_data_path,
            max_vocab_size=1000,
            last_n_model=5,
            max_utt_len=50,
            max_dec_len=50,
            backward_size=2,
            batch_size=1,
            use_gpu=True,
            op='adam',
            init_lr=0.001,
            l2_norm=1e-05,
            momentum=0.0,
            grad_clip=5.0,
            dropout=0.5,
            max_epoch=100,
            embed_size=100,
            num_layers=1,
            utt_rnn_cell='gru',
            utt_cell_size=300,
            bi_utt_cell=True,
            enc_use_attn=True,
            dec_use_attn=True,
            dec_rnn_cell='lstm',
            dec_cell_size=300,
            dec_attn_mode='cat',
            y_size=10,
            k_size=20,
            beta=0.001,
            simple_posterior=True,
            contextual_posterior=True,
            use_mi=False,
            use_pr=True,
            use_diversity=False,
            #
            beam_size=20,
            fix_batch=True,
            fix_train_batch=False,
            avg_type='word',
            print_step=300,
            ckpt_step=1416,
            improve_threshold=0.996,
            patient_increase=2.0,
            save_model=True,
            early_stop=False,
            gen_type='greedy',
            preview_batch_num=None,
            k=domain_info.input_length(),
            init_range=0.1,
            pretrain_folder='2019-09-20-21-43-06-sl_cat',
            forward_only=False
        )
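
        # Pack is LaRL's dict subclass with attribute-style access, which is why
        # config.use_gpu and config['use_gpu'] are interchangeable below; a
        # minimal sketch of the idea (the real class carries a few more helpers):
        #   class Pack(dict):
        #       def __getattr__(self, name):
        #           return self[name]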

        config.use_gpu = config.use_gpu and torch.cuda.is_available()
        self.corpus = corpora_inference.NormMultiWozCorpus(config)
        self.model = SysPerfectBD2Cat(self.corpus, config)
        self.config = config
        if config.use_gpu:
            self.model.load_state_dict(torch.load(
                os.path.join(temp_path, 'larl_model/best-model')))
            self.model.cuda()
        else:
            self.model.load_state_dict(torch.load(os.path.join(
                temp_path, 'larl_model/best-model'), map_location=lambda storage, loc: storage))
        self.model.eval()
        with open(os.path.join(temp_path, 'larl_model/svdic.pkl'), 'rb') as f:
            self.dic = pickle.load(f)
Example #3
def main():
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                               time.localtime(time.time()))
    print('[START]', start_time, '=' * 30)
    # RL configuration
    env = 'gpu'
    pretrained_folder = '2019-09-07-01-03-54-sl_cat'
    # pretrained_folder = '2019-06-20-22-49-55-sl_cat'
    # pretrained_model_id = 41
    pretrained_model_id = 35

    exp_dir = os.path.join('sys_config_log_model', pretrained_folder,
                           'rl-' + start_time)
    # create exp folder
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    rl_config = Pack(
        train_path='../data/norm-multi-woz/train_dials.json',
        valid_path='../data/norm-multi-woz/val_dials.json',
        test_path='../data/norm-multi-woz/test_dials.json',
        sv_config_path=os.path.join('sys_config_log_model', pretrained_folder,
                                    'config.json'),
        sv_model_path=os.path.join('sys_config_log_model', pretrained_folder,
                                   '{}-model'.format(pretrained_model_id)),
        rl_config_path=os.path.join(exp_dir, 'rl_config.json'),
        rl_model_path=os.path.join(exp_dir, 'rl_model'),
        ppl_best_model_path=os.path.join(exp_dir, 'ppl_best.model'),
        reward_best_model_path=os.path.join(exp_dir, 'reward_best.model'),
        record_path=exp_dir,
        record_freq=200,
        sv_train_freq=0,  # TODO: this is also controlled in main.py; keep the two in sync
        use_gpu=env == 'gpu',
        nepoch=10,
        nepisode=0,
        tune_pi_only=False,
        max_words=100,
        temperature=1.0,
        episode_repeat=1.0,
        rl_lr=0.01,
        momentum=0.0,
        nesterov=False,
        gamma=0.99,
        rl_clip=5.0,
        random_seed=100,
    )

    # save configuration
    with open(rl_config.rl_config_path, 'w') as f:
        json.dump(rl_config, f, indent=4)

    # set random seed
    set_seed(rl_config.random_seed)
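
    # A minimal sketch of what a set_seed helper conventionally does (the actual
    # implementation ships with LaRL; these are the standard RNG entry points):
    #   def set_seed(seed):
    #       random.seed(seed)
    #       np.random.seed(seed)
    #       th.manual_seed(seed)
    #       if th.cuda.is_available():
    #           th.cuda.manual_seed_all(seed)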

    # load previous supervised learning configuration and corpus
    sv_config = Pack(json.load(open(rl_config.sv_config_path)))
    sv_config['dropout'] = 0.0
    sv_config['use_gpu'] = rl_config.use_gpu
    corpus = NormMultiWozCorpus(sv_config)

    # TARGET AGENT
    sys_model = SysPerfectBD2Cat(corpus, sv_config)
    if sv_config.use_gpu:
        sys_model.cuda()
    sys_model.load_state_dict(
        th.load(rl_config.sv_model_path,
                map_location=lambda storage, location: storage))
    sys_model.eval()
    sys = OfflineLatentRlAgent(sys_model,
                               corpus,
                               rl_config,
                               name='System',
                               tune_pi_only=rl_config.tune_pi_only)

    # start RL
    reinforce = OfflineTaskReinforce(sys, corpus, sv_config, sys_model,
                                     rl_config, task_generate)
    reinforce.run()

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)