def _process_dialogue(self, data):
    """Turn raw dialogues into lists of speaker-tagged turns.

    Every dialogue is framed by two synthetic user turns: ``[BOS, BOD, EOS]``
    at the start and ``[BOS, EOD, EOS]`` at the end. Both carry all-zero
    ``bs``/``db`` vectors of length ``self.bs_size``/``self.db_size``. Each
    real exchange adds one USR and one SYS turn that share the same ``db``
    and ``bs`` entries.

    Args:
        data: mapping from dialogue key to a raw dialogue that is indexed
            per exchange under 'usr', 'sys', 'db' and 'bs', and has a 'goal'.

    Returns:
        A list with one ``Pack(dlg=..., goal=..., key=...)`` per dialogue.
    """

    def boundary_turn(marker):
        # Fresh zero lists on every call so boundary turns never share storage.
        return Pack(speaker=USR, utt=[BOS, marker, EOS],
                    bs=[0.0] * self.bs_size, db=[0.0] * self.db_size)

    dialogues = []
    utt_lengths = []
    dlg_lengths = []

    for key, raw_dlg in data.items():
        turns = [boundary_turn(BOD)]
        num_exchanges = len(raw_dlg['db'])

        for t_id in range(num_exchanges):
            usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS]
            sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS]
            db_vec = raw_dlg['db'][t_id]
            bs_vec = raw_dlg['bs'][t_id]
            for speaker, utt in ((USR, usr_utt), (SYS, sys_utt)):
                turns.append(Pack(speaker=speaker, utt=utt, db=db_vec, bs=bs_vec))
            utt_lengths.append(len(usr_utt))
            utt_lengths.append(len(sys_utt))

        # Closing marker that tells the model the dialogue is over.
        turns.append(boundary_turn(EOD))

        # Dialogue length is counted in exchanges, not in individual turns.
        dlg_lengths.append(num_exchanges)
        dialogues.append(Pack(dlg=turns,
                              goal=self._process_goal(raw_dlg['goal']),
                              key=key))

    longest_utt = np.max(utt_lengths)
    mean_utt = float(np.mean(utt_lengths))
    self.logger.info('Max utt len = %d, mean utt len = %.2f' % (longest_utt, mean_utt))

    longest_dlg = np.max(dlg_lengths)
    mean_dlg = float(np.mean(dlg_lengths))
    self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % (longest_dlg, mean_dlg))
    return dialogues
def forward(self, data_feed, mode, clf=False, gen_type='greedy', return_latent=False):
    """Encode the dialogue context and decode (or score) the response.

    Args:
        data_feed: batch mapping with 'context_lens', 'contexts', 'outputs',
            'bs' and 'db'.
        mode: decoder mode; when it equals ``GEN`` the generation results
            are returned instead of a loss.
        clf: not read by this method.
        gen_type: forwarded to the decoder.
        return_latent: when true, the decoder initial state is returned as
            ``latent_action`` next to the loss.

    Returns:
        ``(ret_dict, labels)`` in ``GEN`` mode, otherwise a ``Pack`` holding
        ``nll`` and, if requested, ``latent_action``.
    """
    context_lens = data_feed['context_lens']  # one entry per batch element
    batch_size = len(context_lens)

    recent_ctx = self.extract_short_ctx(data_feed['contexts'], context_lens)
    short_ctx_utts = self.np2var(recent_ctx, LONG)
    out_utts = self.np2var(data_feed['outputs'], LONG)
    # NOTE(review): bs/db are concatenated with the utterance summary along
    # dim=1 below, so they are presumably 2-D (batch, feature) - confirm.
    bs_label = self.np2var(data_feed['bs'], FLOAT)
    db_label = self.np2var(data_feed['db'], FLOAT)

    utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

    # Teacher forcing: the decoder reads tokens [0, T-1) and predicts [1, T).
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    # Encoder outputs are only exposed to the decoder when attention is on.
    attn_context = enc_outs if self.config.dec_use_attn else None

    # The policy maps [belief state; db pointer; utterance summary] to the
    # decoder's initial hidden state.
    policy_input = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
    dec_init_state = self.policy(policy_input).unsqueeze(0)

    if self.config.dec_rnn_cell == 'lstm':
        # An LSTM needs (h, c); the same tensor is used for both.
        dec_init_state = (dec_init_state, dec_init_state)

    dec_outputs, dec_hidden_state, ret_dict = self.decoder(
        batch_size=batch_size,
        dec_inputs=dec_inputs,
        dec_init_state=dec_init_state,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size)

    if mode == GEN:
        return ret_dict, labels

    result = Pack(nll=self.nll(dec_outputs, labels))
    if return_latent:
        result['latent_action'] = dec_init_state
    return result
def _to_id_corpus(self, name, data):
    """Map the tokens and goal of every dialogue to ids.

    Dialogues without any turn are skipped. The speaker, ``db`` and ``bs``
    fields of each turn, and the dialogue key, are carried over unchanged.

    Args:
        name: not read by this method.
        data: iterable of dialogues exposing ``dlg``, ``goal`` and ``key``.

    Returns:
        A list of ``Pack(dlg=..., goal=..., key=...)`` with id-encoded turns.
    """
    id_corpus = []
    for dialogue in data:
        if len(dialogue.dlg) == 0:
            continue
        id_turns = [
            Pack(utt=self._sent2id(turn.utt),
                 speaker=turn.speaker,
                 db=turn.db,
                 bs=turn.bs)
            for turn in dialogue.dlg
        ]
        id_corpus.append(Pack(dlg=id_turns,
                              goal=self._goal2id(dialogue.goal),
                              key=dialogue.key))
    return id_corpus
def _to_id_corpus(self, name, data):
    """Map the tokens, goal and outcome of every dialogue to ids.

    Dialogues without any turn are skipped; the speaker of each turn is
    carried over unchanged.

    Args:
        name: not read by this method.
        data: iterable of dialogues exposing ``dlg``, ``goal`` and ``out``.

    Returns:
        A list of ``Pack(dlg=..., goal=..., out=...)`` with id-encoded content.
    """
    id_corpus = []
    for dialogue in data:
        if len(dialogue.dlg) == 0:
            continue
        id_turns = [
            Pack(utt=self._sent2id(turn.utt), speaker=turn.speaker)
            for turn in dialogue.dlg
        ]
        # Goal is encoded before the outcome, as the encoders may be stateful.
        encoded_goal = self._goal2id(dialogue.goal)
        encoded_out = self._outcome2id(dialogue.out)
        id_corpus.append(Pack(dlg=id_turns, goal=encoded_goal, out=encoded_out))
    return id_corpus
def flatten_dialog(self, data, backward_size):
    """Expand dialogues into (context, response) training samples.

    One sample is produced for every non-user turn after the first turn. Its
    context is made of up to ``backward_size`` preceding turns.

    Note that the context turns are the dialogue's own turn objects: their
    'utt' field is overwritten in place with the result of ``self.pad_to``.
    Only the response is a copy.

    Args:
        data: iterable of dialogues exposing ``dlg``, ``goal`` and ``key``.
        backward_size: maximum number of preceding turns kept as context.

    Returns:
        ``(results, indexes, batch_indexes)``: the samples, their running
        indices, and those indices grouped per dialogue (dialogues that
        yield no sample are left out of the grouping).
    """
    results, indexes, batch_indexes = [], [], []
    unique_responses = set()

    for dialogue in data:
        goal, key = dialogue.goal, dialogue.key
        turns = dialogue.dlg
        dialogue_sample_ids = []

        for e_idx in range(1, len(turns)):
            if turns[e_idx].speaker == USR:
                continue
            s_idx = max(0, e_idx - backward_size)

            response = turns[e_idx].copy()
            response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False)
            # Serialise the token list so it can be stored in a set.
            unique_responses.add(json.dumps(response.utt))

            context = []
            for ctx_turn in turns[s_idx:e_idx]:
                ctx_turn['utt'] = self.pad_to(self.max_utt_len, ctx_turn.utt, do_pad=False)
                context.append(ctx_turn)

            results.append(Pack(context=context, response=response, goal=goal, key=key))
            sample_id = len(indexes)
            indexes.append(sample_id)
            dialogue_sample_ids.append(sample_id)

        if len(dialogue_sample_ids) > 0:
            batch_indexes.append(dialogue_sample_ids)

    print("Unique resp {}".format(len(unique_responses)))
    return results, indexes, batch_indexes
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
    """Run the Gaussian latent-action model on one batch.

    The context (belief state, db pointer and utterance summary) is mapped
    to a Gaussian over ``z``. A sample of ``z`` is embedded and used as the
    decoder's initial state.

    Args:
        data_feed: batch mapping with 'context_lens', 'contexts', 'outputs',
            'bs' and 'db'.
        mode: decoder mode; ``GEN`` returns generation results and, without
            ``simple_posterior``, samples from the prior.
        clf: not read by this method.
        gen_type: forwarded to the decoder.
        use_py: when truthy (and ``simple_posterior`` is off), sample from
            the prior even outside ``GEN`` mode.
        return_latent: not read by this method.

    Returns:
        ``(ret_dict, labels)`` in ``GEN`` mode, with ``ret_dict['sample_z']``
        set; otherwise a ``Pack`` with ``nll`` and ``pi_kl``.
    """
    ctx_lens = data_feed['context_lens']  # one entry per batch element
    short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
    out_utts = self.np2var(data_feed['outputs'], LONG)
    # NOTE(review): bs/db are concatenated with the utterance summary along
    # dim=1 below, so they are presumably 2-D (batch, feature) - confirm.
    bs_label = self.np2var(data_feed['bs'], FLOAT)
    db_label = self.np2var(data_feed['db'], FLOAT)
    batch_size = len(ctx_lens)

    utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

    # Teacher forcing: the decoder reads tokens [0, T-1) and predicts [1, T).
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

    if self.simple_posterior:
        # The posterior depends on the context only; the prior parameters
        # are both taken from self.zero.
        q_mu, q_logvar = self.c2z(enc_last)
        sample_z = self.gauss_connector(q_mu, q_logvar)
        p_mu, p_logvar = self.zero, self.zero
    else:
        p_mu, p_logvar = self.c2z(enc_last)
        # Encode the response and use it for the posterior q(z | x, c).
        x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
        q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))

        # Prior at inference time (or on request), posterior otherwise.
        if mode == GEN or use_py:
            sample_z = self.gauss_connector(p_mu, p_logvar)
        else:
            sample_z = self.gauss_connector(q_mu, q_logvar)

    dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
    attn_context = None  # this model never attends over encoder outputs

    if self.config.dec_rnn_cell == 'lstm':
        # An LSTM needs (h, c); the same tensor is used for both.
        dec_init_state = (dec_init_state, dec_init_state)

    dec_outputs, dec_hidden_state, ret_dict = self.decoder(
        batch_size=batch_size,
        dec_inputs=dec_inputs,
        dec_init_state=dec_init_state,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size)

    if mode == GEN:
        ret_dict['sample_z'] = sample_z
        return ret_dict, labels

    # The loss is computed once; the previous code evaluated it twice and
    # discarded the first result.
    result = Pack(nll=self.nll(dec_outputs, labels))
    result['pi_kl'] = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
    return result
def _prepare_batch(self, selected_index):
    """Assemble the numpy arrays for one batch of samples.

    The batch dimension follows the number of selected rows. It used to be
    taken from ``self.batch_size``, which raised an IndexError for a short
    batch and silently dropped rows for a long one. For a batch of exactly
    ``self.batch_size`` rows the output is unchanged.

    Side effects:
        Sets ``self.goal_len`` to the goal length of the batch.
        Prints 'FATAL ERROR!' and exits the process when goals do not all
        have length 6.

    Args:
        selected_index: indices into ``self.data`` of the rows to batch.

    Returns:
        ``Pack`` with ``context_lens``, ``contexts``, ``context_confs``,
        ``output_lens``, ``outputs`` and ``goals``.
    """
    rows = [self.data[idx] for idx in selected_index]

    ctx_utts, ctx_lens = [], []
    out_utts, out_lens = [], []
    goals, goal_lens = [], []
    for row in rows:
        in_row, out_row, goal_row = row.context, row.response, row.goal

        # Source context: every turn is padded to the same utterance length.
        batch_ctx = [self.pad_to(self.max_utt_len, turn.utt, do_pad=True)
                     for turn in in_row]
        ctx_utts.append(batch_ctx)
        ctx_lens.append(len(batch_ctx))

        # Target response: copied so the stored sample is not aliased.
        out_utt = list(out_row.utt)
        out_utts.append(out_utt)
        out_lens.append(len(out_utt))

        goals.append(goal_row)
        goal_lens.append(len(goal_row))

    batch_size = len(rows)

    vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
    max_ctx_len = np.max(vec_ctx_lens)
    vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
    # confs is used to add some hand-crafted features
    vec_ctx_confs = np.ones((batch_size, max_ctx_len), dtype=np.float32)

    vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
    max_out_len = np.max(vec_out_lens)
    vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)

    max_goal_len, min_goal_len = max(goal_lens), min(goal_lens)
    if max_goal_len != min_goal_len or max_goal_len != 6:
        print('FATAL ERROR!')
        exit(-1)
    self.goal_len = max_goal_len
    vec_goals = np.zeros((batch_size, self.goal_len), dtype=np.int32)

    for b_id in range(batch_size):
        vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
        vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
        vec_goals[b_id, :] = goals[b_id]

    return Pack(context_lens=vec_ctx_lens,
                contexts=vec_ctx_utts,
                context_confs=vec_ctx_confs,
                output_lens=vec_out_lens,
                outputs=vec_out_utts,
                goals=vec_goals)
def transform(token_list):
    """Split a flat token stream into user turns and system turns.

    A turn is every token up to and including the next ``EOS``; its first
    token names the speaker. The length of each turn is appended to
    ``all_sent_lens`` and the total number of turns to ``all_dlg_lens``;
    both lists come from the enclosing scope.

    Raises:
        ValueError: if a turn starts with neither ``USR`` nor ``SYS``.
        IndexError: if the stream ends before a turn is closed by ``EOS``.

    Returns:
        ``(usr_turns, sys_turns)`` as two lists of ``Pack(utt=, speaker=)``.
    """
    usr_turns, sys_turns = [], []
    cursor = 0
    while cursor < len(token_list):
        turn_tokens = []
        reached_eos = False
        while not reached_eos:
            token = token_list[cursor]
            turn_tokens.append(token)
            cursor += 1
            reached_eos = token == EOS

        all_sent_lens.append(len(turn_tokens))

        speaker = turn_tokens[0]
        if speaker == USR:
            usr_turns.append(Pack(utt=turn_tokens, speaker=USR))
        elif speaker == SYS:
            sys_turns.append(Pack(utt=turn_tokens, speaker=SYS))
        else:
            raise ValueError('Invalid speaker')

    all_dlg_lens.append(len(usr_turns) + len(sys_turns))
    return usr_turns, sys_turns
def flatten_dialog(self, data, backward_size):
    """Expand dialogues into (context, response) training samples.

    One sample is produced for every non-user turn after the first turn. Its
    context is made of up to ``backward_size`` preceding turns.

    Note that the context turns are the dialogue's own turn objects: their
    'utt' field is overwritten in place with the result of ``self.pad_to``.
    Only the response is a copy.

    Args:
        data: iterable of dialogues exposing ``dlg`` and ``goal``.
        backward_size: maximum number of preceding turns kept as context.

    Returns:
        A list of ``Pack(context=..., response=..., goal=...)``.
    """
    samples = []
    for dialogue in data:
        goal = dialogue.goal
        turns = dialogue.dlg

        for e_idx in range(1, len(turns)):
            if turns[e_idx].speaker == USR:
                continue
            s_idx = max(0, e_idx - backward_size)

            response = turns[e_idx].copy()
            response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False)

            context = []
            for ctx_turn in turns[s_idx:e_idx]:
                ctx_turn['utt'] = self.pad_to(self.max_utt_len, ctx_turn.utt, do_pad=False)
                context.append(ctx_turn)

            samples.append(Pack(context=context, response=response, goal=goal))
    return samples
def prepare_batch_gen(rows, config):
    """Build the generation-time batch arrays from prepared rows.

    Only what the model reads at inference time is produced: the padded
    context utterances, their turn counts, and the belief-state and
    db-pointer vectors taken from each row's 'response' entry.

    Args:
        rows: iterable of mappings with a 'context' (list of turns holding
            'utt') and a 'response' (holding 'bs' and 'db').
        config: settings object; only ``max_utt_len`` is read.

    Returns:
        ``Pack`` with ``context_lens``, ``contexts``, ``bs`` and ``db``.

    Raises:
        ValueError: from ``np.max`` when ``rows`` is empty.
    """
    ctx_utts = []
    out_bs, out_db = [], []
    for row in rows:
        in_row, out_row = row['context'], row['response']

        # Source context: every turn is padded to the same utterance length.
        ctx_utts.append([pad_to(config.max_utt_len, turn['utt'], do_pad=True)
                         for turn in in_row])
        out_bs.append(out_row['bs'])
        out_db.append(out_row['db'])

    batch_size = len(ctx_utts)
    vec_ctx_lens = np.array([len(batch_ctx) for batch_ctx in ctx_utts])  # turns per row
    max_ctx_len = np.max(vec_ctx_lens)
    vec_ctx_utts = np.zeros((batch_size, max_ctx_len, config.max_utt_len), dtype=np.int32)
    vec_out_bs = np.array(out_bs)
    vec_out_db = np.array(out_db)

    for b_id in range(batch_size):
        # Rows with fewer turns than the longest one stay zero-filled.
        vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]

    return Pack(
        context_lens=vec_ctx_lens,  # (batch_size, )
        contexts=vec_ctx_utts,  # (batch_size, max_ctx_len, max_utt_len)
        bs=vec_out_bs,
        db=vec_out_db,
    )
def __init__(self,
             archive_file=DEFAULT_ARCHIVE_FILE,
             cuda_device=DEFAULT_CUDA_DEVICE,
             model_file=None):
    """Load the LaRL system policy: vocabularies, corpus, model and lexicon.

    The model archive is unpacked into a temporary directory that is removed
    again once the weights and the delexicalisation dictionary are loaded.

    Args:
        archive_file: path to the zipped model; when it is not an existing
            file, ``model_file`` is fetched through ``cached_path`` instead.
        cuda_device: not read by this method; the GPU is used whenever
            ``torch.cuda.is_available()`` reports one.
        model_file: fallback location of the model archive.

    Raises:
        Exception: when neither ``archive_file`` nor ``model_file`` yields
            a model archive.
    """
    SysPolicy.__init__(self)

    if not os.path.isfile(archive_file):
        if not model_file:
            raise Exception("No model for LaRL is specified!")
        archive_file = cached_path(model_file)

    # The directory object is kept so it can be removed at the end; should
    # construction fail half-way, its finalizer still removes the directory.
    temp_dir = tempfile.TemporaryDirectory()
    temp_path = temp_dir.name
    with zipfile.ZipFile(archive_file, 'r') as zip_ref:
        zip_ref.extractall(temp_path)

    self.prev_state = init_state()
    self.prev_active_domain = None

    domain_name = 'object_division'
    domain_info = domain.get_domain(domain_name)

    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
    train_data_path = os.path.join(data_path, 'norm-multi-woz', 'train_dials.json')
    if not os.path.exists(train_data_path):
        # First use: unpack the bundled normalised MultiWOZ data.
        zipped_file = os.path.join(data_path, 'norm-multi-woz.zip')
        with zipfile.ZipFile(zipped_file, 'r') as archive:
            archive.extractall(data_path)

    norm_multiwoz_path = os.path.join(data_path, 'norm-multi-woz')
    with open(os.path.join(norm_multiwoz_path, 'input_lang.index2word.json')) as f:
        self.input_lang_index2word = json.load(f)
    with open(os.path.join(norm_multiwoz_path, 'input_lang.word2index.json')) as f:
        self.input_lang_word2index = json.load(f)
    with open(os.path.join(norm_multiwoz_path, 'output_lang.index2word.json')) as f:
        self.output_lang_index2word = json.load(f)
    with open(os.path.join(norm_multiwoz_path, 'output_lang.word2index.json')) as f:
        self.output_lang_word2index = json.load(f)

    config = Pack(
        seed=10,
        train_path=train_data_path,
        max_vocab_size=1000,
        last_n_model=5,
        max_utt_len=50,
        max_dec_len=50,
        backward_size=2,
        batch_size=1,
        use_gpu=True,
        op='adam',
        init_lr=0.001,
        l2_norm=1e-05,
        momentum=0.0,
        grad_clip=5.0,
        dropout=0.5,
        max_epoch=100,
        embed_size=100,
        num_layers=1,
        utt_rnn_cell='gru',
        utt_cell_size=300,
        bi_utt_cell=True,
        enc_use_attn=True,
        dec_use_attn=True,
        dec_rnn_cell='lstm',
        dec_cell_size=300,
        dec_attn_mode='cat',
        y_size=10,
        k_size=20,
        beta=0.001,
        simple_posterior=True,
        contextual_posterior=True,
        use_mi=False,
        use_pr=True,
        use_diversity=False,
        #
        beam_size=20,
        fix_batch=True,
        fix_train_batch=False,
        avg_type='word',
        print_step=300,
        ckpt_step=1416,
        improve_threshold=0.996,
        patient_increase=2.0,
        save_model=True,
        early_stop=False,
        gen_type='greedy',
        preview_batch_num=None,
        k=domain_info.input_length(),
        init_range=0.1,
        pretrain_folder='2019-09-20-21-43-06-sl_cat',
        forward_only=False
    )

    # Item assignment keeps the mapping entry itself up to date, so that
    # config['use_gpu'] and config.use_gpu cannot disagree.
    config['use_gpu'] = config.use_gpu and torch.cuda.is_available()

    self.corpus = corpora_inference.NormMultiWozCorpus(config)
    self.model = SysPerfectBD2Cat(self.corpus, config)
    self.config = config

    model_path = os.path.join(temp_path, 'larl_model/best-model')
    if config.use_gpu:
        self.model.load_state_dict(torch.load(model_path))
        self.model.cuda()
    else:
        # Remap every storage to the CPU when no GPU is available.
        self.model.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
    self.model.eval()

    # NOTE(security): pickle executes arbitrary code while loading; only
    # use model archives that come from a trusted source.
    with open(os.path.join(temp_path, 'larl_model/svdic.pkl'), 'rb') as f:
        self.dic = pickle.load(f)

    # Everything needed is in memory now; drop the extracted archive.
    temp_dir.cleanup()
def _prepare_batch(self, selected_index):
    """Assemble the numpy arrays for one batch of samples.

    Side effects:
        Sets ``self.goal_lens`` to the per-domain goal lengths of the batch.
        Prints 'Fatal Error!' and exits the process when, for some domain,
        the goal length differs between rows.

    Args:
        selected_index: indices into ``self.data`` of the rows to batch.

    Returns:
        ``Pack`` with ``context_lens``, ``contexts``, ``output_lens``,
        ``outputs``, ``bs``, ``db``, ``goals_list`` (one array per domain)
        and ``keys``.
    """
    rows = [self.data[idx] for idx in selected_index]

    keys, contexts, responses = [], [], []
    belief_states, db_pointers, goals = [], [], []
    per_domain_lens = [[] for _ in range(len(self.domains))]

    for row in rows:
        in_row, out_row, goal_row = row.context, row.response, row.goal
        keys.append(row.key)

        # Source context: every turn is padded to the same utterance length.
        contexts.append([self.pad_to(self.max_utt_len, turn.utt, do_pad=True)
                         for turn in in_row])

        # Target response, copied token by token.
        responses.append(list(out_row.utt))
        belief_states.append(out_row.bs)
        db_pointers.append(out_row.db)

        goals.append(goal_row)
        for lens, domain_name in zip(per_domain_lens, self.domains):
            lens.append(len(goal_row[domain_name]))

    batch_size = len(contexts)

    vec_ctx_lens = np.array([len(ctx) for ctx in contexts])  # turns per row
    max_ctx_len = np.max(vec_ctx_lens)
    vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
    vec_out_bs = np.array(belief_states)
    vec_out_db = np.array(db_pointers)
    vec_out_lens = np.array([len(resp) for resp in responses])  # tokens per row
    max_out_len = np.max(vec_out_lens)
    vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)

    # Within a domain every row must have the same goal length.
    longest = [max(lens) for lens in per_domain_lens]
    shortest = [min(lens) for lens in per_domain_lens]
    if longest != shortest:
        print('Fatal Error!')
        exit(-1)
    self.goal_lens = longest
    vec_goals_list = [np.zeros((batch_size, width), dtype=np.float32)
                      for width in self.goal_lens]

    for b_id in range(batch_size):
        vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = contexts[b_id]
        vec_out_utts[b_id, :vec_out_lens[b_id]] = responses[b_id]
        for vec_goal, domain_name in zip(vec_goals_list, self.domains):
            vec_goal[b_id, :] = goals[b_id][domain_name]

    return Pack(
        context_lens=vec_ctx_lens,  # (batch_size, )
        contexts=vec_ctx_utts,  # (batch_size, max_ctx_len, max_utt_len)
        output_lens=vec_out_lens,  # (batch_size, )
        outputs=vec_out_utts,  # (batch_size, max_out_len)
        bs=vec_out_bs,
        db=vec_out_db,
        goals_list=vec_goals_list,  # one (batch_size, goal_len) array per domain
        keys=keys)
def _process_dialogue(self, data):
    """Parse raw tagged dialogue lines into turns, goal and outcome.

    Each raw line is whitespace-tokenised. The span between ``<dialogue>``
    tags becomes alternating USR/SYS turns, with a synthetic ``BOD`` turn
    from the other speaker prepended. The goal is read from between the
    ``<partner_input>`` tags and the outcome from between the ``<output>``
    tags; both must have exactly 6 tokens.

    Malformed input prints 'FATAL ERROR!!! (...)' and exits the process.

    Args:
        data: iterable of raw dialogue strings.

    Returns:
        A list of ``Pack(dlg=..., goal=..., out=...)``.
    """
    sent_lens = []
    dlg_lens = []

    def split_turns(tokens):
        # Cut the stream after every EOS; the first token of a turn names
        # its speaker.
        usr_turns, sys_turns = [], []
        cursor = 0
        while cursor < len(tokens):
            turn_tokens = []
            reached_eos = False
            while not reached_eos:
                token = tokens[cursor]
                turn_tokens.append(token)
                cursor += 1
                reached_eos = token == EOS
            sent_lens.append(len(turn_tokens))
            speaker = turn_tokens[0]
            if speaker == USR:
                usr_turns.append(Pack(utt=turn_tokens, speaker=USR))
            elif speaker == SYS:
                sys_turns.append(Pack(utt=turn_tokens, speaker=SYS))
            else:
                raise ValueError('Invalid speaker')
        dlg_lens.append(len(usr_turns) + len(sys_turns))
        return usr_turns, sys_turns

    def tagged_span(tokens, open_tag, close_tag):
        # Tokens strictly between the first opening and first closing tag.
        return tokens[tokens.index(open_tag) + 1:tokens.index(close_tag)]

    dialogues = []
    for raw_dlg in data:
        raw_words = raw_dlg.split()

        # --- dialogue text ---
        words = tagged_span(raw_words, '<dialogue>', '</dialogue>')
        words += [EOS]

        # Prepend a begin-of-dialogue turn from whoever does NOT speak first.
        if words[0] == SYS:
            words = [USR, BOD, EOS] + words
            usr_first = True
        elif words[0] == USR:
            words = [SYS, BOD, EOS] + words
            usr_first = False
        else:
            print('FATAL ERROR!!! ({})'.format(words))
            exit(-1)

        usr_utts, sys_utts = split_turns(words)
        if usr_first:
            leading, trailing = usr_utts, sys_utts
        else:
            leading, trailing = sys_utts, usr_utts

        cur_dlg = []
        for first_turn, second_turn in zip(leading, trailing):
            cur_dlg.append(first_turn)
            cur_dlg.append(second_turn)

        # One speaker may have a single unpaired turn at the end.
        surplus = len(usr_utts) - len(sys_utts)
        if surplus == 1:
            cur_dlg.append(usr_utts[-1])
        elif surplus == -1:
            cur_dlg.append(sys_utts[-1])

        # --- goal (6 digits) ---
        # FIXME FATAL ERROR HERE !!! (the goal is read from <partner_input>,
        # not from <input>)
        cur_goal = tagged_span(raw_words, '<partner_input>', '</partner_input>')
        if len(cur_goal) != 6:
            print('FATAL ERROR!!! ({})'.format(cur_goal))
            exit(-1)

        # --- outcome (6 tokens) ---
        cur_out = tagged_span(raw_words, '<output>', '</output>')
        if len(cur_out) != 6:
            print('FATAL ERROR!!! ({})'.format(cur_out))
            exit(-1)

        dialogues.append(Pack(dlg=cur_dlg, goal=cur_goal, out=cur_out))

    print('Max utt len = %d, mean utt len = %.2f' %
          (np.max(sent_lens), float(np.mean(sent_lens))))
    print('Max dlg len = %d, mean dlg len = %.2f' %
          (np.max(dlg_lens), float(np.mean(dlg_lens))))
    return dialogues
def main():
    """Fine-tune a pretrained system model with offline reinforcement learning.

    Creates a time-stamped experiment folder below the pretrained model's
    folder, writes the RL configuration there, loads the supervised model
    and its corpus, runs ``OfflineTaskReinforce`` and saves the resulting
    weights.
    """
    timestamp_format = '%Y-%m-%d-%H-%M-%S'
    start_time = time.strftime(timestamp_format)
    print('[START]', start_time, '=' * 30)

    env = 'gpu'
    pretrained_folder = '2018-11-13-21-27-21-sys_sl_bdu2resp'
    pretrained_model_id = 61

    exp_dir = os.path.join('sys_config_log_model', pretrained_folder,
                           'rl-' + start_time)
    # create exp folder
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    # RL configuration
    rl_config = Pack(
        train_path='../data/norm-multi-woz/train_dials.json',
        valid_path='../data/norm-multi-woz/val_dials.json',
        test_path='../data/norm-multi-woz/test_dials.json',
        sv_config_path=os.path.join('sys_config_log_model', pretrained_folder,
                                    'config.json'),
        sv_model_path=os.path.join('sys_config_log_model', pretrained_folder,
                                   '{}-model'.format(pretrained_model_id)),
        rl_config_path=os.path.join(exp_dir, 'rl_config.json'),
        rl_model_path=os.path.join(exp_dir, 'rl_model'),
        ppl_best_model_path=os.path.join(exp_dir, 'ppl_best.model'),
        reward_best_model_path=os.path.join(exp_dir, 'reward_best.model'),
        record_path=exp_dir,
        record_freq=200,
        # TODO pay attention to main.py, cuz it is also controlled there
        sv_train_freq=1000,
        use_gpu=env == 'gpu',
        nepoch=10,
        nepisode=0,
        max_words=100,
        episode_repeat=1.0,
        temperature=1.0,
        rl_lr=0.01,
        momentum=0.0,
        nesterov=False,
        gamma=0.99,
        rl_clip=5.0,
        random_seed=10,
    )

    # save configuration
    with open(rl_config.rl_config_path, 'w') as f:
        json.dump(rl_config, f, indent=4)

    # set random seed
    set_seed(rl_config.random_seed)

    # load previous supervised learning configuration and corpus; the file
    # is closed as soon as it has been parsed
    with open(rl_config.sv_config_path) as f:
        sv_config = Pack(json.load(f))
    sv_config['use_gpu'] = rl_config.use_gpu

    corpus = NormMultiWozCorpus(sv_config)

    # TARGET AGENT
    sys_model = SysPerfectBD2Word(corpus, sv_config)
    if sv_config.use_gpu:
        sys_model.cuda()
    # Storages are mapped to the CPU while reading the checkpoint.
    sys_model.load_state_dict(
        th.load(rl_config.sv_model_path,
                map_location=lambda storage, location: storage))
    sys_model.eval()
    # Named sys_agent rather than sys so the standard module name is not
    # shadowed inside this function.
    sys_agent = OfflineRlAgent(sys_model, corpus, rl_config, name='System',
                               tune_pi_only=False)

    # start RL
    reinforce = OfflineTaskReinforce(sys_agent, corpus, sv_config, sys_model,
                                     rl_config, task_generate)
    reinforce.run()

    # save sys model
    th.save(sys_model.state_dict(), rl_config.rl_model_path)

    end_time = time.strftime(timestamp_format)
    print('[END]', end_time, '=' * 30)
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
    """Run the categorical latent-action model on one batch.

    The context (belief state, db pointer and utterance summary) is mapped
    to ``y_size`` categorical variables with ``k_size`` classes each. A
    Gumbel sample of them initialises the decoder and, when
    ``config.dec_use_attn`` is set, also serves as the attention memory.

    Args:
        data_feed: batch mapping with 'context_lens', 'contexts', 'outputs',
            'bs' and 'db'.
        mode: decoder mode; ``GEN`` returns generation results.
        clf: not read by this method.
        gen_type: forwarded to the decoder.
        use_py: when exactly ``True`` (and ``simple_posterior`` is off),
            sample from the prior even outside ``GEN`` mode.
        return_latent: not read by this method.

    Returns:
        ``(ret_dict, labels)`` in ``GEN`` mode, with ``sample_z`` and
        ``log_qy`` added to ``ret_dict``; otherwise a ``Pack`` with ``nll``,
        ``pi_kl``, ``diversity``, ``b_pr`` and ``mi``.
    """
    ctx_lens = data_feed['context_lens']  # one entry per batch element
    short_ctx_utts = self.np2var(
        self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
    out_utts = self.np2var(data_feed['outputs'], LONG)
    # NOTE(review): bs/db are concatenated with the utterance summary along
    # dim=1 below, so they are presumably 2-D (batch, feature) - confirm.
    bs_label = self.np2var(data_feed['bs'], FLOAT)
    db_label = self.np2var(data_feed['db'], FLOAT)
    batch_size = len(ctx_lens)

    utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

    # Teacher forcing: the decoder reads tokens [0, T-1) and predicts [1, T).
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

    if self.simple_posterior:
        # The posterior depends on the context only; the prior is the
        # fixed self.log_uniform_y.
        logits_qy, log_qy = self.c2z(enc_last)
        sample_y = self.gumbel_connector(logits_qy, hard=mode == GEN)
        log_py = self.log_uniform_y
    else:
        logits_py, log_py = self.c2z(enc_last)
        # Encode the response and use it for the posterior q(z | x, c).
        x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
        if self.contextual_posterior:
            logits_qy, log_qy = self.xc2z(
                th.cat([enc_last, x_h.squeeze(1)], dim=1))
        else:
            logits_qy, log_qy = self.xc2z(x_h.squeeze(1))

        # Prior at inference time (or on request), posterior otherwise.
        # NOTE(review): hard=False for the prior and hard=True for the
        # posterior is the opposite of the simple_posterior branch above -
        # kept as found, confirm it is intended.
        if mode == GEN or use_py is True:
            sample_y = self.gumbel_connector(logits_py, hard=False)
        else:
            sample_y = self.gumbel_connector(logits_qy, hard=True)

    if self.config.dec_use_attn:
        # One embedding block of k_size rows per latent variable; every
        # variable contributes one attention slot.
        z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
        temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
        attn_slots = [
            th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)
            for z_id in range(self.y_size)
        ]
        attn_context = th.cat(attn_slots, dim=1)
        dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
    else:
        dec_init_state = self.z_embedding(
            sample_y.view(1, -1, self.config.y_size * self.config.k_size))
        attn_context = None

    if self.config.dec_rnn_cell == 'lstm':
        # An LSTM needs (h, c); the same tensor is used for both.
        dec_init_state = (dec_init_state, dec_init_state)

    dec_outputs, dec_hidden_state, ret_dict = self.decoder(
        batch_size=batch_size,
        dec_inputs=dec_inputs,
        dec_init_state=dec_init_state,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size)

    if mode == GEN:
        ret_dict['sample_z'] = sample_y
        ret_dict['log_qy'] = log_qy
        return ret_dict, labels

    # The loss is computed once; the previous code evaluated it twice and
    # discarded the first result.
    result = Pack(nll=self.nll(dec_outputs, labels))

    # Batch-averaged posterior, compared against log_uniform_y below.
    avg_log_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
    avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15)
    b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size,
                            unit_average=True)
    mi = (self.entropy_loss(avg_log_qy, unit_average=True)
          - self.entropy_loss(log_qy, unit_average=True))
    pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)

    # Squared deviation of q_y q_y^T from self.eye, averaged for 'diversity'.
    q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)
    p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)

    result['pi_kl'] = pi_kl
    result['diversity'] = th.mean(p)
    result['b_pr'] = b_pr
    result['mi'] = mi
    return result
# Settings for the supervised word-level system model.
# NOTE(review): the headings below group keys by their names only; the code
# that consumes most of them is not visible here - confirm before relying on
# any grouping.
config = Pack(
    random_seed=10,
    # --- data locations ---
    train_path='../data/norm-multi-woz/train_dials.json',
    valid_path='../data/norm-multi-woz/val_dials.json',
    test_path='../data/norm-multi-woz/test_dials.json',
    # --- vocabulary / sequence limits ---
    max_vocab_size=1000,
    max_utt_len=50,
    max_dec_len=50,
    last_n_model=5,
    backward_size=2,
    batch_size=32,
    use_gpu=True,
    # --- optimisation ---
    op='adam',
    init_lr=0.001,
    l2_norm=1e-05,
    momentum=0.0,
    grad_clip=5.0,
    dropout=0.5,
    max_epoch=100,
    # --- network ---
    embed_size=100,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=300,
    bi_utt_cell=True,
    enc_use_attn=True,
    dec_use_attn=False,
    dec_rnn_cell='lstm',  # must be same as ctx_cell_size due to the passed initial state
    dec_cell_size=300,  # must be same as ctx_cell_size due to the passed initial state
    dec_attn_mode='cat',
    #
    # beam_size is passed to the decoder as self.config.beam_size.
    beam_size=20,
    fix_batch=True,
    fix_train_batch=False,
    avg_type='word',
    # --- training loop bookkeeping ---
    print_step=500,
    ckpt_step=1771,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=None,
    # domain_info is defined outside this statement.
    k=domain_info.input_length(),
    init_range=0.1,
    pretrain_folder='2018-11-13-21-27-21-sys_sl_bdu2resp',
    forward_only=False)
# Settings for the categorical latent-action system model (batch of one).
# NOTE(review): the headings below group keys by their names only; the code
# that consumes most of them is not visible here - confirm before relying on
# any grouping.
config = Pack(
    # NOTE(review): this key is 'seed' while other configurations in the
    # project use 'random_seed' - confirm which one the consumer reads.
    seed=10,
    # train_data_path is defined outside this statement.
    train_path=train_data_path,
    # --- vocabulary / sequence limits ---
    max_vocab_size=1000,
    last_n_model=5,
    max_utt_len=50,
    max_dec_len=50,
    backward_size=2,
    batch_size=1,
    use_gpu=True,
    # --- optimisation ---
    op='adam',
    init_lr=0.001,
    l2_norm=1e-05,
    momentum=0.0,
    grad_clip=5.0,
    dropout=0.5,
    max_epoch=100,
    # --- network ---
    embed_size=100,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=300,
    bi_utt_cell=True,
    enc_use_attn=True,
    dec_use_attn=True,
    dec_rnn_cell='lstm',
    dec_cell_size=300,
    dec_attn_mode='cat',
    # --- latent action space: y_size variables with k_size classes each ---
    y_size=10,
    k_size=20,
    beta=0.001,
    simple_posterior=True,
    contextual_posterior=True,
    use_mi=False,
    use_pr=True,
    use_diversity=False,
    #
    # beam_size is passed to the decoder as self.config.beam_size.
    beam_size=20,
    fix_batch=True,
    fix_train_batch=False,
    avg_type='word',
    # --- training loop bookkeeping ---
    print_step=300,
    ckpt_step=1416,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=None,
    # domain_info is defined outside this statement.
    k=domain_info.input_length(),
    init_range=0.1,
    pretrain_folder='2019-09-20-21-43-06-sl_cat',
    forward_only=False
)
def predict_response(self, state):
    """Generate the system response for the current dialogue state.

    The most recent ``self.config.backward_size`` utterances are
    delexicalised, paired with the database/booking pointer and the
    belief-state summary, and fed to the model. The model output is then
    filled in through ``self.populate_template``.

    Args:
        state: mapping with 'history' (a list of turns, each a sequence of
            utterance strings) and 'belief_state'.

    Returns:
        ``(response, active_domain)``.

    Raises:
        IndexError: when the history holds no utterance at all.
    """
    # Flatten the per-turn utterance lists into one chronological list.
    history = [utt for turn in state['history'] for utt in turn]
    e_idx = len(history)
    s_idx = max(0, e_idx - self.config.backward_size)
    context = list(history[s_idx:e_idx])

    # A history of length one marks the start of a new dialogue.
    if len(state['history']) == 1:
        self.prev_state = init_state()

    prepared_data = {'context': [], 'response': {}}

    prev_bstate = deepcopy(self.prev_state['belief_state'])
    state_history = state['history']
    bstate = deepcopy(state['belief_state'])

    active_domain = self.get_active_domain(
        self.prev_active_domain, prev_bstate, bstate)
    domain_mark_not_mentioned(bstate, active_domain)

    # Compiled once, outside the loop; a raw string avoids the invalid
    # '\d' escape warning.
    digit_pattern = re.compile(r'\d+')

    top_results, num_results = None, None
    for usr in context:
        # Normalise whitespace, then replace entity values by slot tokens.
        usr = delexicalize.delexicalise(' '.join(usr.split()), self.dic)

        # parsing reference number GIVEN belief state
        usr = delexicaliseReferenceNumber(usr, bstate)

        # any remaining number becomes a generic count token
        usr = digit_pattern.sub('[value_count]', usr)

        # add database pointer
        pointer_vector, top_results, num_results = addDBPointer(bstate)
        # add booking pointer
        pointer_vector = addBookingPointer(bstate, pointer_vector)
        belief_summary = get_summary_bstate(bstate)

        usr_utt = [BOS] + usr.split() + [EOS]
        prepared_data['context'].append({
            'bs': belief_summary,
            'db': pointer_vector,
            'utt': self.corpus._sent2id(usr_utt),
        })

    # The response slot only carries the latest belief state / db pointer.
    last_turn = prepared_data['context'][-1]
    prepared_data['response']['bs'] = last_turn['bs']
    prepared_data['response']['db'] = last_turn['db']

    results = [Pack(context=prepared_data['context'],
                    response=prepared_data['response'])]
    data_feed = prepare_batch_gen(results, self.config)
    outputs = self.model_predict(data_feed)

    # Restrict the database results to the active domain, if there is one.
    if active_domain is not None and active_domain in num_results:
        num_results = num_results[active_domain]
    else:
        num_results = 0
    if active_domain is not None and active_domain in top_results:
        top_results = {active_domain: top_results[active_domain]}
    else:
        top_results = {}

    state_with_history = deepcopy(bstate)
    state_with_history['history'] = deepcopy(state_history)
    response = self.populate_template(
        outputs, top_results, num_results, state_with_history)

    # Debug trace of the exchange.
    # NOTE(review): the 'usr:' part of this trace was garbled in the file
    # (it did not parse); printing the last context utterance is a
    # reconstruction - confirm against version control.
    import pprint
    pprint.pprint("============")
    pprint.pprint('usr:')
    pprint.pprint(context[-1])
    pprint.pprint('agent:')
    pprint.pprint(response)
    pprint.pprint("============")

    return response, active_domain