def _process_dialogue(self, data):
    """Normalize raw dialogs into lists of speaker-tagged ``Pack`` turns.

    Every dialog is framed by a synthetic opening user turn
    (``[BOS, BOD, EOS]``) and a synthetic closing user turn
    (``[BOS, EOD, EOS]``), both carrying zero vectors of length
    ``self.bs_size`` / ``self.db_size``. In between, each index of
    ``raw_dlg['db']`` produces one USR turn followed by one SYS turn, and
    both turns share that index's ``db`` and ``bs`` entries.

    Args:
        data: mapping from dialog key to a raw dialog exposing the parallel
            sequences ``'usr'``, ``'sys'``, ``'db'``, ``'bs'`` plus ``'goal'``.

    Returns:
        list of ``Pack(dlg=..., goal=..., key=...)``, one per dialog, in the
        iteration order of ``data``.
    """
    new_dlgs = []
    all_sent_lens = []
    all_dlg_lens = []

    for key, raw_dlg in data.items():
        num_turns = len(raw_dlg['db'])
        norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS],
                         bs=[0.0] * self.bs_size, db=[0.0] * self.db_size)]
        for t_id in range(num_turns):
            usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS]
            sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS]
            norm_dlg.append(Pack(speaker=USR, utt=usr_utt,
                                 db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id]))
            norm_dlg.append(Pack(speaker=SYS, utt=sys_utt,
                                 db=raw_dlg['db'][t_id], bs=raw_dlg['bs'][t_id]))
            all_sent_lens.extend([len(usr_utt), len(sys_utt)])

        # Synthetic end-of-dialog user turn so the model can learn to stop.
        norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS],
                             bs=[0.0] * self.bs_size, db=[0.0] * self.db_size))

        all_dlg_lens.append(num_turns)
        processed_goal = self._process_goal(raw_dlg['goal'])
        new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key))

    # np.max raises ValueError on an empty sequence, so only report the
    # statistics that actually have samples behind them.
    if all_sent_lens:
        self.logger.info('Max utt len = %d, mean utt len = %.2f',
                         np.max(all_sent_lens), float(np.mean(all_sent_lens)))
    if all_dlg_lens:
        self.logger.info('Max dlg len = %d, mean dlg len = %.2f',
                         np.max(all_dlg_lens), float(np.mean(all_dlg_lens)))
    return new_dlgs
def _to_id_corpus(self, name, data):
    """Map every dialog in ``data`` from tokens to ids, keeping parsed proposals.

    Dialogs without any turn are skipped. ``name`` is accepted but not used
    by this method.

    Returns:
        list of ``Pack`` with the id-encoded turns (``dlg``), the id-encoded
        ``goal``, outcome (``out``) and ``partner_goal``, the values of
        ``get_valid_contexts_ints`` / ``get_latent_powerset`` for the encoded
        goal, and the original token-level turns under ``dlg_text``.
    """
    encoded = []
    for dialog in data:
        if len(dialog.dlg) == 0:
            continue

        # Turns and their parsed proposals are walked in lockstep.
        encoded_turns = [
            Pack(
                utt=self._sent2id(turn.utt),
                speaker=turn.speaker,
                parsed=self._goal2id(parsed_turn),
            )
            for turn, parsed_turn in zip(dialog.dlg, dialog.parsed_dlg)
        ]

        goal_ids = self._goal2id(dialog.goal)
        outcome_ids = self._outcome2id(dialog.out)
        # data added for debugging and PR
        partner_goal_ids = self._goal2id(dialog.usr_goal)

        encoded.append(
            Pack(
                dlg=encoded_turns,
                goal=goal_ids,
                out=outcome_ids,
                partner_goal=partner_goal_ids,
                valid_partner_goals=get_valid_contexts_ints(goal_ids),
                partitions=get_latent_powerset(goal_ids),
                dlg_text=dialog.dlg,
            ))
    return encoded
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None,
            return_latent=False, get_marginals=False):
    """Encode the goal and dialog context, then decode the response.

    The original body overwrote ``clf`` with ``False`` and then branched on
    ``if not clf:``, which was always true. That dead guard is removed;
    ``clf`` and ``use_py`` are accepted for signature compatibility but are
    not read.

    Args:
        data_feed: batch mapping with ``'context_lens'``, ``'contexts'``,
            ``'context_confs'``, ``'outputs'`` and ``'goals'``.
        mode: decoder mode; ``GEN`` selects generation.
        gen_type: generation strategy forwarded to the decoder.
        return_latent: also return the decoder initial state as
            ``latent_action``.
        get_marginals: return raw decoder outputs and labels instead of a
            loss (checked before ``mode``).

    Returns:
        ``Pack(dec_outputs, labels)`` when ``get_marginals`` is set;
        ``(ret_dict, labels)`` in ``GEN`` mode; otherwise ``Pack(nll=...)``,
        with ``latent_action`` added when ``return_latent`` is set.
    """
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    ctx_utts = self.np2var(data_feed['contexts'], LONG)  # (batch_size, max_ctx_len, max_utt_len)
    ctx_confs = self.np2var(data_feed['context_confs'], FLOAT)  # (batch_size, max_ctx_len)
    out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
    goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
    batch_size = len(ctx_lens)

    # Encode the goal, then each context utterance conditioned on it.
    goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)
    enc_inputs, _, _ = self.utt_encoder(ctx_utts, feats=ctx_confs, goals=goals_h)

    # Context-level encoder over the utterance encodings.
    enc_outs, enc_last = self.ctx_encoder(enc_inputs, input_lengths=ctx_lens, goals=None)

    # Teacher forcing: the decoder reads tokens [0, T-1) and predicts [1, T).
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    attn_context = enc_outs if self.config.dec_use_attn else None
    dec_init_state = self.connector(enc_last)

    dec_outputs, _, ret_dict = self.decoder(
        batch_size=batch_size,
        dec_inputs=dec_inputs,
        dec_init_state=dec_init_state,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size,
        goal_hid=goals_h)

    if get_marginals:
        return Pack(
            dec_outputs=dec_outputs,
            labels=labels,
        )
    if mode == GEN:
        return ret_dict, labels
    if return_latent:
        return Pack(nll=self.nll(dec_outputs, labels),
                    latent_action=dec_init_state)
    return Pack(nll=self.nll(dec_outputs, labels))
def forward(self, data_feed, mode, clf=False, gen_type='greedy',
            return_latent=False, use_py=True):
    """Decode a response from the belief state, DB vector and short context.

    The decoder's initial state is ``self.policy`` applied to the
    concatenation (along dim 1) of the ``'bs'`` features, the ``'db'``
    features and the utterance-encoder summary of the short context.
    ``clf`` and ``use_py`` are accepted but not read.

    Returns:
        ``(ret_dict, labels)`` in ``GEN`` mode; otherwise ``Pack(nll=...)``,
        with ``latent_action`` (the decoder initial state) added when
        ``return_latent`` is set.
    """
    context_lengths = data_feed['context_lens']  # (batch_size, )
    recent_context = self.extract_short_ctx(data_feed['contexts'], context_lengths)
    short_ctx_utts = self.np2var(recent_context, LONG)
    response_tokens = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
    belief_state = self.np2var(data_feed['bs'], FLOAT)
    db_state = self.np2var(data_feed['db'], FLOAT)
    num_examples = len(context_lengths)

    utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

    # Teacher forcing split: inputs drop the last token, labels drop the first.
    decoder_in = response_tokens[:, :-1]
    labels = response_tokens[:, 1:].contiguous()

    attn_context = enc_outs if self.config.dec_use_attn else None

    policy_in = th.cat([belief_state, db_state, utt_summary.squeeze(1)], dim=1)
    dec_init_state = self.policy(policy_in).unsqueeze(0)
    if self.config.dec_rnn_cell == 'lstm':
        # An LSTM decoder takes an (h, c) pair; the same tensor seeds both.
        dec_init_state = (dec_init_state, dec_init_state)

    dec_outputs, _, ret_dict = self.decoder(
        batch_size=num_examples,
        dec_inputs=decoder_in,
        dec_init_state=dec_init_state,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size)

    if mode == GEN:
        return ret_dict, labels
    if not return_latent:
        return Pack(nll=self.nll(dec_outputs, labels))
    return Pack(nll=self.nll(dec_outputs, labels),
                latent_action=dec_init_state)
def transform(token_list, usr_goal, sys_goal):
    """Split a flat token stream into USR/SYS turns and parse their proposals.

    A turn runs up to and including the next ``EOS`` token, and its first
    token must be ``USR`` or ``SYS``. Each turn is parsed with ``parse_c``
    against the speaker's own goal. The parsed proposal is flattened to six
    strings (speaker's book/hat/ball counts first, then the partner's), or
    six ``'-1'`` strings when no proposal was found.

    Appends to the enclosing-scope lists ``all_sent_lens`` (one entry per
    turn) and ``all_dlg_lens`` (one entry per call).

    Returns:
        ``(usr, sys, parsed_usr, parsed_sys, num_proposals)``.

    Raises:
        ValueError: if a turn starts with neither ``USR`` nor ``SYS``.
    """
    item_names = ("book", "hat", "ball")

    def flatten_proposal(prop, own_idx, other_idx):
        # Speaker's own counts first, then the partner's; -1 marks "no proposal".
        if prop is None:
            counts = [-1] * 6
        else:
            counts = ([prop[own_idx][item] for item in item_names]
                      + [prop[other_idx][item] for item in item_names])
        return [str(count) for count in counts]

    usr, sys = [], []
    parsed_usr, parsed_sys = [], []
    num_proposals = 0

    cursor = 0
    while cursor < len(token_list):
        # Collect tokens up to and including the next EOS.
        turn_list = []
        while True:
            token = token_list[cursor]
            turn_list.append(token)
            cursor += 1
            if token == EOS:
                break
        all_sent_lens.append(len(turn_list))

        speaker = turn_list[0]
        if speaker == USR:
            usr.append(Pack(utt=turn_list, speaker=USR))
            # assume usr is agent 0
            prop = parse_c(" ".join(turn_list), usr_goal, merge=True).proposal
            parsed_usr.append(flatten_proposal(prop, 0, 1))
        elif speaker == SYS:
            sys.append(Pack(utt=turn_list, speaker=SYS))
            # assume sys is agent 1
            prop = parse_c(" ".join(turn_list), sys_goal, merge=True).proposal
            parsed_sys.append(flatten_proposal(prop, 1, 0))
        else:
            raise ValueError('Invalid speaker')

        if prop is not None:
            num_proposals += 1

    all_dlg_lens.append(len(usr) + len(sys))
    return usr, sys, parsed_usr, parsed_sys, num_proposals
def _to_id_corpus(self, name, data):
    """Map dialogs from tokens to ids, carrying ``db``/``bs`` along per turn.

    Dialogs without any turn are skipped. ``name`` is accepted but not used.

    Returns:
        list of ``Pack(dlg=..., goal=..., key=...)`` where each turn holds the
        id-encoded utterance plus the original ``speaker``, ``db`` and ``bs``.
    """
    encoded = []
    for dialog in data:
        if len(dialog.dlg) == 0:
            continue
        encoded_turns = [
            Pack(utt=self._sent2id(turn.utt), speaker=turn.speaker,
                 db=turn.db, bs=turn.bs)
            for turn in dialog.dlg
        ]
        goal_ids = self._goal2id(dialog.goal)
        encoded.append(Pack(dlg=encoded_turns, goal=goal_ids, key=dialog.key))
    return encoded
def _to_id_corpus(self, name, data):
    """Map dialogs from tokens to ids, including the goal and the outcome.

    Dialogs without any turn are skipped. ``name`` is accepted but not used.

    Returns:
        list of ``Pack(dlg=..., goal=..., out=...)``.
    """
    encoded = []
    for dialog in data:
        if len(dialog.dlg) == 0:
            continue
        encoded_turns = [
            Pack(utt=self._sent2id(turn.utt), speaker=turn.speaker)
            for turn in dialog.dlg
        ]
        goal_ids = self._goal2id(dialog.goal)
        outcome_ids = self._outcome2id(dialog.out)
        encoded.append(Pack(dlg=encoded_turns, goal=goal_ids, out=outcome_ids))
    return encoded
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None,
            return_latent=False):
    """Gaussian-latent response model conditioned on belief state and DB.

    The context vector is the concatenation (dim 1) of the ``'bs'``
    features, the ``'db'`` features and the utterance-encoder summary of the
    short context. A latent ``z`` is sampled from the prior in ``GEN`` mode
    or when ``use_py`` is truthy, and from the posterior otherwise. With
    ``self.simple_posterior`` the sample always comes from ``c2z`` and
    ``self.zero`` serves as both prior parameters.

    The original computed ``self.nll`` twice and overwrote the first result;
    it is now computed once. ``clf`` and ``return_latent`` are accepted but
    not read.

    Returns:
        ``(ret_dict, labels)`` in ``GEN`` mode, with ``ret_dict['sample_z']``
        set; otherwise ``Pack(nll=..., pi_kl=...)``.
    """
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    short_ctx_utts = self.np2var(
        self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
    out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
    bs_label = self.np2var(data_feed['bs'], FLOAT)
    db_label = self.np2var(data_feed['db'], FLOAT)
    batch_size = len(ctx_lens)

    utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

    # Teacher forcing: the decoder reads tokens [0, T-1) and predicts [1, T).
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

    if self.simple_posterior:
        q_mu, q_logvar = self.c2z(enc_last)
        sample_z = self.gauss_connector(q_mu, q_logvar)
        p_mu, p_logvar = self.zero, self.zero
    else:
        p_mu, p_logvar = self.c2z(enc_last)
        # Encode the response and use it to form the posterior q(z|x, c).
        x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
        q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
        # Use the prior at inference time, otherwise the posterior.
        if mode == GEN or use_py:
            sample_z = self.gauss_connector(p_mu, p_logvar)
        else:
            sample_z = self.gauss_connector(q_mu, q_logvar)

    # This model never attends over the encoder outputs.
    dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
    attn_context = None

    if self.config.dec_rnn_cell == 'lstm':
        # An LSTM decoder takes an (h, c) pair; the same tensor seeds both.
        dec_init_state = (dec_init_state, dec_init_state)

    dec_outputs, _, ret_dict = self.decoder(
        batch_size=batch_size,
        dec_inputs=dec_inputs,
        dec_init_state=dec_init_state,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size)

    if mode == GEN:
        ret_dict['sample_z'] = sample_z
        return ret_dict, labels

    result = Pack(nll=self.nll(dec_outputs, labels))
    result['pi_kl'] = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
    return result
def _prepare_batch_flat(self, selected_index):
    """Assemble one zero-padded numpy mini-batch from flat examples.

    Arrays are sized by the number of rows actually selected. The original
    sized them by ``self.batch_size`` and looped over that range, so a batch
    with fewer rows raised ``IndexError``. For full batches the result is
    unchanged.

    Side effect: sets ``self.goal_len`` to the goal length of this batch.

    Args:
        selected_index: indices into ``self.data`` forming the batch.

    Returns:
        ``Pack`` with ``context_lens`` (turns per example), ``contexts``
        (int32, rows x max turns x ``self.max_utt_len``), ``context_confs``
        (float32 ones, rows x max turns), ``output_lens``, ``outputs``
        (int32, rows x max response length) and ``goals`` (int32).

    Raises:
        ValueError: if the goals in the batch differ in length or are not of
            length 6. The original printed ``FATAL ERROR!`` and called
            ``exit(-1)`` here.
    """
    rows = [self.data[idx] for idx in selected_index]
    num_rows = len(rows)

    ctx_utts, ctx_lens = [], []
    out_utts, out_lens = [], []
    goals, goal_lens = [], []
    for row in rows:
        # Source context: each turn is padded to max_utt_len.
        batch_ctx = [self.pad_to(self.max_utt_len, turn.utt, do_pad=True)
                     for turn in row.context]
        ctx_utts.append(batch_ctx)
        ctx_lens.append(len(batch_ctx))

        # Target response, copied so the stored example is not aliased.
        out_utt = list(row.response.utt)
        out_utts.append(out_utt)
        out_lens.append(len(out_utt))

        goals.append(row.goal)
        goal_lens.append(len(row.goal))

    vec_ctx_lens = np.array(ctx_lens)  # (num_rows, ), number of turns
    max_ctx_len = np.max(vec_ctx_lens)
    vec_ctx_utts = np.zeros((num_rows, max_ctx_len, self.max_utt_len),
                            dtype=np.int32)
    # confs is used to add some hand-crafted features
    vec_ctx_confs = np.ones((num_rows, max_ctx_len), dtype=np.float32)

    vec_out_lens = np.array(out_lens)  # (num_rows, ), number of tokens
    max_out_len = np.max(vec_out_lens)
    vec_out_utts = np.zeros((num_rows, max_out_len), dtype=np.int32)

    max_goal_len, min_goal_len = max(goal_lens), min(goal_lens)
    if max_goal_len != min_goal_len or max_goal_len != 6:
        raise ValueError(
            'FATAL ERROR! every goal must have length 6, got lengths '
            'between %d and %d' % (min_goal_len, max_goal_len))
    self.goal_len = max_goal_len
    vec_goals = np.zeros((num_rows, self.goal_len), dtype=np.int32)

    for b_id in range(num_rows):
        vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
        vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
        vec_goals[b_id, :] = goals[b_id]

    return Pack(context_lens=vec_ctx_lens,
                contexts=vec_ctx_utts,
                context_confs=vec_ctx_confs,
                output_lens=vec_out_lens,
                outputs=vec_out_utts,
                goals=vec_goals)
def transform(token_list):
    """Split a flat token stream into separate USR and SYS turn lists.

    A turn runs up to and including the next ``EOS`` token, and its first
    token must be ``USR`` or ``SYS``. Appends to the enclosing-scope lists
    ``all_sent_lens`` (one entry per turn) and ``all_dlg_lens`` (one entry
    per call).

    Returns:
        ``(usr, sys)``: lists of ``Pack(utt=..., speaker=...)``.

    Raises:
        ValueError: if a turn starts with neither ``USR`` nor ``SYS``.
    """
    usr, sys = [], []
    cursor = 0
    while cursor < len(token_list):
        # Collect tokens up to and including the next EOS.
        turn_list = []
        while True:
            token = token_list[cursor]
            turn_list.append(token)
            cursor += 1
            if token == EOS:
                break
        all_sent_lens.append(len(turn_list))

        speaker = turn_list[0]
        if speaker == USR:
            usr.append(Pack(utt=turn_list, speaker=USR))
        elif speaker == SYS:
            sys.append(Pack(utt=turn_list, speaker=SYS))
        else:
            raise ValueError('Invalid speaker')

    all_dlg_lens.append(len(usr) + len(sys))
    return usr, sys
def flatten_dialog_seq(self, data, backward_size):
    """Group (context, response) examples by dialog.

    For every non-USR turn after the first one, an example is built from up
    to ``backward_size`` preceding turns (the context) and a copy of that
    turn (the response). Utterances are passed through ``self.pad_to`` with
    ``do_pad=False``. Context turns are updated in place, so the turns in
    ``data`` are modified.

    Returns:
        list with one entry per dialog, each a list of ``Pack`` examples that
        carry the context, response, goal, their parsed counterparts and the
        dialog's ``partner_goal``, ``valid_partner_goals`` and ``partitions``.
    """
    per_dialog = []
    for dialog in data:
        turns = dialog.dlg
        examples = []
        for end in range(1, len(turns)):
            if turns[end].speaker == USR:
                continue
            start = max(0, end - backward_size)

            response = turns[end].copy()
            response['utt'] = self.pad_to(self.max_utt_len, response.utt,
                                          do_pad=False)

            context, parsed_context = [], []
            for turn in turns[start:end]:
                turn['utt'] = self.pad_to(self.max_utt_len, turn.utt,
                                          do_pad=False)
                context.append(turn)
                parsed_context.append(turn.parsed)

            examples.append(
                Pack(
                    context=context,
                    response=response,
                    goal=dialog.goal,
                    parsed_context=parsed_context,
                    parsed_response=response.parsed,
                    partner_goal=dialog.partner_goal,
                    valid_partner_goals=dialog.valid_partner_goals,
                    partitions=dialog.partitions,
                ))
        per_dialog.append(examples)
    return per_dialog
def flatten_dialog(self, data, backward_size):
    """Flatten dialogs into one list of (context, response, goal) examples.

    For every non-USR turn after the first one, the example holds up to
    ``backward_size`` preceding turns as context and a copy of the turn as
    the response. Utterances are passed through ``self.pad_to`` with
    ``do_pad=False``. Context turns are updated in place.

    Returns:
        list of ``Pack(context=..., response=..., goal=...)``.
    """
    examples = []
    for dialog in data:
        turns = dialog.dlg
        for end in range(1, len(turns)):
            if turns[end].speaker == USR:
                continue
            start = max(0, end - backward_size)

            response = turns[end].copy()
            response['utt'] = self.pad_to(self.max_utt_len, response.utt,
                                          do_pad=False)

            context = []
            for turn in turns[start:end]:
                turn['utt'] = self.pad_to(self.max_utt_len, turn.utt,
                                          do_pad=False)
                context.append(turn)

            examples.append(
                Pack(context=context, response=response, goal=dialog.goal))
    return examples
def flatten_dialog(self, data, backward_size):
    """Flatten dialogs into examples and record which examples share a dialog.

    Examples are built as in the plain flattening: one per non-USR turn after
    the first, with up to ``backward_size`` preceding turns as context. In
    addition, every example gets a running index, and the indices belonging
    to one dialog are grouped together. Prints the number of distinct
    response utterances seen.

    Returns:
        ``(results, indexes, batch_indexes)``: the flat example list, the
        running indices ``0..len(results)-1``, and one index list per dialog
        that produced at least one example.
    """
    results = []
    indexes = []
    batch_indexes = []
    unique_responses = set()

    for dialog in data:
        turns = dialog.dlg
        dialog_indexes = []
        for end in range(1, len(turns)):
            if turns[end].speaker == USR:
                continue
            start = max(0, end - backward_size)

            response = turns[end].copy()
            response['utt'] = self.pad_to(self.max_utt_len, response.utt,
                                          do_pad=False)
            # The utterance is serialized so it can be stored in a set.
            unique_responses.add(json.dumps(response.utt))

            context = []
            for turn in turns[start:end]:
                turn['utt'] = self.pad_to(self.max_utt_len, turn.utt,
                                          do_pad=False)
                context.append(turn)

            results.append(
                Pack(context=context, response=response, goal=dialog.goal,
                     key=dialog.key))
            example_id = len(indexes)
            indexes.append(example_id)
            dialog_indexes.append(example_id)

        if dialog_indexes:
            batch_indexes.append(dialog_indexes)

    print("Unique resp {}".format(len(unique_responses)))
    return results, indexes, batch_indexes
# Experiment configuration for the negotiation task (paths under
# ../data/negotiate). Values in "commented-out alternative" notes are
# inactive settings left in the file.
config = Pack(
    # --- reproducibility and data ---
    random_seed=10,
    train_path='../data/negotiate/train.txt',
    val_path='../data/negotiate/val.txt',
    test_path='../data/negotiate/test.txt',
    last_n_model=4,
    max_utt_len=20,
    backward_size=8,  # commented-out alternative: 14
    batch_size=4,  # commented-out alternative: 16
    # --- optimisation ---
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=0.00001,
    momentum=0.0,
    grad_clip=10.0,
    dropout=0.3,
    max_epoch=50,
    # --- encoders ---
    embed_size=256,
    num_layers=1,  # commented-out alternative: 2
    utt_rnn_cell='gru',
    utt_cell_size=128,
    bi_utt_cell=True,
    enc_use_attn=False,
    ctx_rnn_cell='gru',
    ctx_cell_size=256,
    bi_ctx_cell=False,
    # --- decoder ---
    dec_use_attn=False,  # commented-out alternative: True
    # The next two must be same as ctx_cell_size due to the passed initial state.
    dec_rnn_cell='gru',
    dec_cell_size=256,
    dec_attn_mode='cat',
    beam_size=20,
    # --- training loop ---
    fix_train_batch=False,
    avg_type='real_word',
    print_step=100,
    ckpt_step=400,  # commented-out alternative: 2523
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    # --- generation and goal encoding ---
    gen_type='greedy',
    preview_batch_num=50,
    max_dec_len=40,
    k=domain_info.input_length(),
    goal_embed_size=64,
    goal_nhid=64,
    init_range=0.1,
    pretrain_folder='2019-12-08-18-45-47-sl_word_dlg_num',
    forward_only=False,  # commented-out alternative: True
    # different batching style
    seq=True,
    # use oracle context and proposal parse
    oracle_context=True,  # commented-out alternative: False
    oracle_parse=True,  # commented-out alternative: False
    semisupervised=False,
    prop_weight=1,  # commented-out alternatives: 0.1, 0
    tie_prop_utt_enc=False,
)
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None,
            return_latent=False, get_marginals=False):
    """Response model that marginalises over candidate proposals.

    Each row of ``data_feed.partitions`` is a candidate proposal. A proposal
    distribution is predicted from the context; rows at or beyond
    ``data_feed.num_partitions`` are masked out. In training the decoder is
    run once per candidate and the word log-probabilities are combined with
    the proposal log-probabilities via ``logsumexp``. In ``GEN`` mode one
    proposal is picked (sampled or argmax) and decoded.

    Changes from the original body:
      * the live ``pdb.set_trace()`` breakpoint is removed;
      * ``goal_hid`` is flattened with the tensor's own last dimension
        instead of a hard-coded 128;
      * the ``get_marginals`` branch no longer computes debug-only values
        that were never returned; it still returns ``nll`` and ``nll_prop``;
      * the unreachable second ``mode == GEN`` return is dropped.

    ``clf`` and ``use_py`` are accepted but not read.

    Returns:
        ``(ret_dict, labels)`` in ``GEN`` mode; ``Pack(nll, latent_action)``
        when ``return_latent`` is set and ``get_marginals`` is not; otherwise
        ``Pack(nll, nll_prop)``.

    Raises:
        ValueError: in ``GEN`` mode when ``gen_type`` is neither
            ``"sampled"`` nor ``"greedy"``.
    """
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    ctx_utts = self.np2var(data_feed['contexts'], LONG)  # (batch_size, max_ctx_len, max_utt_len)
    ctx_confs = self.np2var(data_feed['context_confs'], FLOAT)  # (batch_size, max_ctx_len)
    out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
    goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
    batch_size = len(ctx_lens)

    # Encode the goal, then the dialog context conditioned on it.
    goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)
    enc_inputs, _, _ = self.utt_encoder(
        ctx_utts,
        feats=ctx_confs,
        goals=goals_h,
    )
    enc_outs, enc_last = self.ctx_encoder(enc_inputs,
                                          input_lengths=ctx_lens,
                                          goals=None)

    partitions = self.np2var(data_feed.partitions, LONG)
    num_partitions = self.np2var(data_feed.num_partitions, INT)

    # Oracle inputs: the partner's true goal and the parsed proposal.
    partner_goals = self.np2var(data_feed.true_partner_goals, LONG)
    parsed_outputs = self.np2var(data_feed.parsed_outputs, LONG)
    partner_goals_h = self.goal_encoder(partner_goals)

    # Proposal prediction uses its own encoder stack. The utterance encoder
    # is always run, even when its output is replaced by the tied one.
    prop_enc_inputs, _, _ = self.prop_utt_encoder(
        ctx_utts,
        feats=ctx_confs,
        goals=goals_h,
    )
    _, prop_enc_last = self.prop_ctx_encoder(
        enc_inputs if self.config.tie_prop_utt_enc else prop_enc_inputs,
        input_lengths=ctx_lens,
        goals=partner_goals_h if self.config.oracle_context else None,
    )

    # Embed each candidate: columns 0-2 are "my" book/hat/ball counts,
    # columns 3-5 are "your" counts.
    my_state_emb_out = self.res_layer_out(
        th.cat([
            self.book_emb_out(partitions[:, :, 0]),
            self.hat_emb_out(partitions[:, :, 1]),
            self.ball_emb_out(partitions[:, :, 2]),
        ], -1))
    your_state_emb_out = self.res_layer_out(
        th.cat([
            self.book_emb_out(partitions[:, :, 3]),
            self.hat_emb_out(partitions[:, :, 4]),
            self.ball_emb_out(partitions[:, :, 5]),
        ], -1))
    state_emb_out = th.cat([my_state_emb_out, your_state_emb_out], -1)

    # Fuse the goal encoding with every candidate proposal embedding.
    big_goals_h = self.res_goal_mlp(
        th.cat([
            goals_h.unsqueeze(1).expand(
                state_emb_out.shape[0],
                state_emb_out.shape[1],
                goals_h.shape[-1],
            ),
            state_emb_out,
        ], -1))

    z_size = partitions.shape[1]
    # True where a candidate equals the parsed proposal in all columns.
    prop_mask = (partitions == parsed_outputs.unsqueeze(1)).all(-1)
    logits_prop = th.einsum("nsh,nh->ns", state_emb_out, prop_enc_last[-1])
    # Candidate slots at or beyond num_partitions are masked to -inf.
    mask = ~(th.arange(
        z_size, device=num_partitions.device,
        dtype=num_partitions.dtype).repeat(
            partitions.shape[0], 1) < num_partitions.unsqueeze(-1))
    logp_prop = logits_prop.masked_fill(mask, float("-inf")).log_softmax(-1)

    if self.config.semisupervised:
        # re-use params or make new ones? re-using can only hurt
        # TODO: use new parameters
        my_state_emb = self.res_layer(
            th.cat([
                self.book_emb(partitions[:, :, 0]),
                self.hat_emb(partitions[:, :, 1]),
                self.ball_emb(partitions[:, :, 2]),
            ], -1))
        your_state_emb = self.res_layer(
            th.cat([
                self.book_emb(partitions[:, :, 3]),
                self.hat_emb(partitions[:, :, 4]),
                self.ball_emb(partitions[:, :, 5]),
            ], -1))
        noise_state_emb = th.cat([my_state_emb, your_state_emb], -1)
        logp_tprop_prop = th.einsum("nth,nsh->nts", noise_state_emb,
                                    state_emb_out).log_softmax(1)
        nll_prop = -self.config.prop_weight * (
            logp_tprop_prop +
            logp_prop.unsqueeze(-2)).logsumexp(-1)[prop_mask].mean()
    else:
        nll_prop = -self.config.prop_weight * logp_prop[prop_mask].mean()

    # Teacher forcing: the decoder reads tokens [0, T-1) and predicts [1, T).
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    attn_context = enc_outs if self.config.dec_use_attn else None

    dec_init_state = self.connector(
        enc_last) if self.out_backward_size is not None else None

    if mode == GEN:
        N, _, H = big_goals_h.shape
        if gen_type == "sampled":
            sampled_proposal_indices = logp_prop.exp().multinomial(1)
        elif gen_type == "greedy":
            sampled_proposal_indices = logp_prop.argmax(-1)
        else:
            raise ValueError(f"Unknown gen_type: {gen_type}")
        sampled_goals_h = big_goals_h.gather(
            1,
            sampled_proposal_indices.view(N, 1, 1).expand(N, 1, H)).squeeze(1)
        _, _, ret_dict = self.decoder(
            batch_size=batch_size,
            dec_inputs=dec_inputs,
            dec_init_state=dec_init_state,
            attn_context=attn_context,
            mode=mode,
            gen_type=gen_type,
            beam_size=self.config.beam_size,
            goal_hid=sampled_goals_h,
        )
        return ret_dict, labels

    # Training: decode once per (example, candidate proposal) pair.
    N, T = dec_inputs.shape
    dec_outputs, _, _ = self.decoder(
        batch_size=batch_size * z_size,
        dec_inputs=dec_inputs.repeat(1, z_size).view(-1, T),
        dec_init_state=dec_init_state.repeat(1, 1, z_size).view(
            1, z_size * batch_size, -1)
        if dec_init_state is not None else None,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size,
        # Flatten the (example, candidate) axes; the feature size is read
        # from the tensor rather than assumed to be 128.
        goal_hid=big_goals_h.view(-1, big_goals_h.shape[-1]),
    )

    V = dec_outputs.shape[-1]
    T_out = dec_outputs.shape[-2]
    # Add log p(proposal) to the per-candidate word scores, then
    # marginalise over candidates.
    logp_w_prop = (dec_outputs.view(N, z_size, T_out, V) +
                   logp_prop.view(N, z_size, 1, 1))
    logp_w = logp_w_prop.logsumexp(1)

    nll = self.nll(logp_w, labels)
    if return_latent and not get_marginals:
        return Pack(nll=nll, latent_action=dec_init_state)
    return Pack(nll=nll, nll_prop=nll_prop)
# Experiment configuration for the negotiation task with a sequence latent
# variable model. Values in "commented-out alternative" notes are inactive
# settings left in the file.
config = Pack(
    # --- data ---
    train_path='../data/negotiate/train.txt',
    val_path='../data/negotiate/val.txt',
    test_path='../data/negotiate/test.txt',
    last_n_model=4,
    max_utt_len=20,
    backward_size=8,  # commented-out alternatives: 14, 1
    batch_size=4,  # commented-out alternative: 32
    # --- optimisation ---
    grad_clip=10.0,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=0.00001,
    momentum=0.0,
    dropout=0.3,
    max_epoch=100,
    # --- encoders ---
    embed_size=256,
    num_layers=2,  # commented-out alternative: 1
    utt_rnn_cell='gru',
    utt_cell_size=128,
    bi_utt_cell=True,
    enc_use_attn=False,
    ctx_rnn_cell='gru',
    ctx_cell_size=256,
    bi_ctx_cell=False,
    # --- latent variable ---
    z_size=128,
    # commented-out options: beta = 0.01, simple_posterior = False, use_pr = True
    # --- decoder ---
    dec_use_attn=False,
    # The next two must be same as ctx_cell_size due to the passed initial state.
    dec_rnn_cell='gru',
    dec_cell_size=256,
    dec_attn_mode='cat',
    # --- training loop ---
    fix_train_batch=False,
    fix_batch=False,
    beam_size=20,
    avg_type='real_word',
    print_step=100,
    ckpt_step=400,  # commented-out alternative: 2523
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    # --- generation and goal encoding ---
    gen_type='greedy',
    preview_batch_num=1,
    max_dec_len=40,
    k=domain_info.input_length(),
    goal_embed_size=64,
    goal_nhid=64,
    init_range=0.1,
    pretrain_folder='2019-12-06-02-20-58-sl_hmm',
    forward_only=False,  # commented-out alternative: True
    # options for sequence LVMs
    seq=True,
    noisy_proposal_labels=True,
    sup_proposal_labels=False,  # commented-out alternative: True
    label_weight=0.1,  # commented-out alternative: 1
)
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None,
            return_latent=False):
    """Categorical-latent response model conditioned on the goal.

    A latent ``y`` is drawn through ``self.gumbel_connector`` from the prior
    logits in ``GEN`` mode or when ``use_py`` is truthy, and from the
    posterior logits otherwise. With ``self.simple_posterior`` the sample
    always comes from ``c2z`` and the prior is ``self.log_uniform_y``.
    ``clf`` is accepted but not read.

    Returns:
        ``(ret_dict, labels)`` in ``GEN`` mode; otherwise
        ``Pack(nll, mi, pi_kl, pi_h)``, with ``latent_action`` (the decoder
        initial state) added when ``return_latent`` is set.
    """
    context_lengths = data_feed['context_lens']  # (batch_size, )
    context_tokens = self.np2var(data_feed['contexts'], LONG)
    response_tokens = self.np2var(data_feed['outputs'], LONG)
    goal_tokens = self.np2var(data_feed['goals'], LONG)
    num_examples = len(context_lengths)

    # Goal encoding conditions the utterance encoder and the decoder.
    goal_hidden = self.goal_encoder(goal_tokens)
    utt_vectors, _, _ = self.utt_encoder(context_tokens, goals=goal_hidden)
    enc_outs, enc_last = self.ctx_encoder(utt_vectors,
                                          input_lengths=context_lengths,
                                          goals=None)

    # Teacher forcing split: inputs drop the last token, labels drop the first.
    decoder_in = response_tokens[:, :-1]
    labels = response_tokens[:, 1:].contiguous()

    if self.simple_posterior:
        logits_qy, log_qy = self.c2z(enc_last)
        sample_y = self.gumbel_connector(logits_qy)
        log_py = self.log_uniform_y
    else:
        logits_py, log_py = self.c2z(enc_last)
        # Posterior q(z|x, c) reads both the context and the encoded response.
        response_summary, _, _ = self.utt_encoder(
            response_tokens.unsqueeze(1), goals=goal_hidden)
        posterior_in = th.cat(
            [enc_last, response_summary.squeeze(1).unsqueeze(0)], dim=2)
        logits_qy, log_qy = self.xc2z(posterior_in)
        # Prior at inference time, posterior otherwise.
        sampling_logits = logits_py if (mode == GEN or use_py) else logits_qy
        sample_y = self.gumbel_connector(sampling_logits)

    if self.config.dec_use_attn:
        # One embedding slice per latent variable becomes one attention slot.
        embedding_slices = th.t(self.z_embedding.weight).split(
            self.config.k_size, dim=0)
        y_by_latent = sample_y.view(-1, self.config.y_size, self.config.k_size)
        slots = [
            th.mm(y_by_latent[:, z_id], embedding_slices[z_id]).unsqueeze(1)
            for z_id in range(self.config.y_size)
        ]
        attn_context = th.cat(slots, dim=1)
        dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
    else:
        attn_context = None
        flat_y = sample_y.view(1, -1, self.config.y_size * self.config.k_size)
        dec_init_state = self.z_embedding(flat_y)

    dec_outputs, _, ret_dict = self.decoder(
        batch_size=num_examples,
        dec_inputs=decoder_in,
        dec_init_state=dec_init_state,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size,
        goal_hid=goal_hidden)

    if mode == GEN:
        return ret_dict, labels

    # Batch-averaged posterior, used to regularise q(y) towards uniform.
    batch_qy = th.exp(log_qy.view(-1, self.config.y_size, self.config.k_size))
    avg_log_qy = th.log(th.mean(batch_qy, dim=0) + 1e-15)
    mi = (self.entropy_loss(avg_log_qy, unit_average=True)
          - self.entropy_loss(log_qy, unit_average=True))
    pi_kl = self.cat_kl_loss(log_qy, log_py, num_examples, unit_average=True)
    pi_h = self.entropy_loss(log_qy, unit_average=True)

    results = Pack(nll=self.nll(dec_outputs, labels), mi=mi, pi_kl=pi_kl,
                   pi_h=pi_h)
    if return_latent:
        results['latent_action'] = dec_init_state
    return results
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None,
            return_latent=False):
    """Gaussian-latent response model conditioned on the goal.

    A latent ``z`` is sampled from the prior in ``GEN`` mode or when
    ``use_py`` is truthy, and from the posterior otherwise. With
    ``self.simple_posterior`` the sample always comes from ``c2z`` and
    ``self.zero`` serves as both prior parameters.

    The original computed ``self.nll`` twice and overwrote the first result;
    it is now computed once. ``clf`` and ``return_latent`` are accepted but
    not read.

    Returns:
        ``(ret_dict, labels)`` in ``GEN`` mode, with ``ret_dict['sample_z']``
        set; otherwise ``Pack(nll=..., pi_kl=...)``.
    """
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    ctx_utts = self.np2var(data_feed['contexts'], LONG)  # (batch_size, max_ctx_len, max_utt_len)
    out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
    goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
    batch_size = len(ctx_lens)

    # Encode the goal, then the dialog context conditioned on it.
    goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)
    enc_inputs, _, _ = self.utt_encoder(ctx_utts, goals=goals_h)
    enc_outs, enc_last = self.ctx_encoder(enc_inputs, input_lengths=ctx_lens,
                                          goals=None)

    # Teacher forcing: the decoder reads tokens [0, T-1) and predicts [1, T).
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    if self.simple_posterior:
        q_mu, q_logvar = self.c2z(enc_last)
        sample_z = self.gauss_connector(q_mu, q_logvar)
        p_mu, p_logvar = self.zero, self.zero
    else:
        p_mu, p_logvar = self.c2z(enc_last)
        # Encode the response and use it to form the posterior q(z|x, c).
        x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1), goals=goals_h)
        q_mu, q_logvar = self.xc2z(
            th.cat([enc_last, x_h.squeeze(1).unsqueeze(0)], dim=2))
        # Use the prior at inference time, otherwise the posterior.
        if mode == GEN or use_py:
            sample_z = self.gauss_connector(p_mu, p_logvar)
        else:
            sample_z = self.gauss_connector(q_mu, q_logvar)

    # This model never attends over the encoder outputs.
    dec_init_state = self.z_embedding(sample_z)
    attn_context = None

    if self.config.dec_rnn_cell == 'lstm':
        # An LSTM decoder takes an (h, c) pair; the same tensor seeds both.
        dec_init_state = (dec_init_state, dec_init_state)

    dec_outputs, _, ret_dict = self.decoder(
        batch_size=batch_size,
        dec_inputs=dec_inputs,
        dec_init_state=dec_init_state,
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size,
        goal_hid=goals_h)

    if mode == GEN:
        ret_dict['sample_z'] = sample_z
        return ret_dict, labels

    result = Pack(nll=self.nll(dec_outputs, labels))
    result['pi_kl'] = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
    return result
# Experiment configuration for the normalized MultiWOZ data
# (paths under ../data/norm-multi-woz).
config = Pack(
    # --- reproducibility and data ---
    random_seed=10,
    train_path='../data/norm-multi-woz/train_dials.json',
    valid_path='../data/norm-multi-woz/val_dials.json',
    test_path='../data/norm-multi-woz/test_dials.json',
    max_vocab_size=1000,
    max_utt_len=50,
    max_dec_len=50,
    last_n_model=5,
    backward_size=2,
    batch_size=32,
    # --- optimisation ---
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=1e-05,
    momentum=0.0,
    grad_clip=5.0,
    dropout=0.5,
    max_epoch=50,
    # --- encoder ---
    embed_size=100,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=300,
    bi_utt_cell=True,
    enc_use_attn=True,
    # --- decoder ---
    dec_use_attn=False,
    # The next two must be same as ctx_cell_size due to the passed initial state.
    dec_rnn_cell='lstm',
    dec_cell_size=300,
    dec_attn_mode='cat',
    beam_size=20,
    # --- training loop ---
    fix_batch=True,
    fix_train_batch=False,
    avg_type='word',
    print_step=500,
    ckpt_step=1771,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    # --- generation ---
    gen_type='greedy',
    preview_batch_num=None,
    k=domain_info.input_length(),
    init_range=0.1,
    pretrain_folder='2018-11-13-21-27-21-sys_sl_bdu2resp',
    forward_only=False)
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
    """Encode context + belief/db vectors, sample a categorical latent, decode.

    Args:
        data_feed: batch mapping; the keys 'context_lens', 'contexts',
            'outputs', 'bs' and 'db' are read.
        mode: decoder mode; ``GEN`` switches to generation.
        clf: accepted for interface compatibility; not used here.
        gen_type: generation strategy forwarded to the decoder.
        use_py: when exactly ``True``, sample from the prior even outside
            ``GEN`` (only consulted when ``self.simple_posterior`` is false).
        return_latent: accepted for interface compatibility; not used here.

    Returns:
        In ``GEN`` mode, ``(ret_dict, labels)`` where ``ret_dict`` also
        carries ``sample_z`` and ``log_qy``. Otherwise a ``Pack`` with
        ``nll``, ``pi_kl``, ``diversity``, ``b_pr`` and ``mi``.
    """
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
    out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
    # bs/db are concatenated with the utterance summary along dim=1 below,
    # so they are 2-D here — presumably (batch_size, bs_size) and
    # (batch_size, db_size); TODO confirm against the data loader.
    bs_label = self.np2var(data_feed['bs'], FLOAT)
    db_label = self.np2var(data_feed['db'], FLOAT)
    batch_size = len(ctx_lens)

    utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

    # get decoder inputs (teacher forcing: shift the response by one token)
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    # context representation: belief state + db pointer + utterance summary
    enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

    # build the prior p(y|c) and posterior q(y|x, c), then sample y
    if self.simple_posterior:
        logits_qy, log_qy = self.c2z(enc_last)
        sample_y = self.gumbel_connector(logits_qy, hard=mode == GEN)
        log_py = self.log_uniform_y
    else:
        logits_py, log_py = self.c2z(enc_last)
        # encode response and use posterior to find q(z|x, c)
        x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
        if self.contextual_posterior:
            logits_qy, log_qy = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))
        else:
            logits_qy, log_qy = self.xc2z(x_h.squeeze(1))

        # use prior at inference time, otherwise use posterior
        if mode == GEN or use_py is True:
            sample_y = self.gumbel_connector(logits_py, hard=False)
        else:
            sample_y = self.gumbel_connector(logits_qy, hard=True)

    # pack attention context
    if self.config.dec_use_attn:
        # one embedding block per latent variable; attend over the y_size codes
        z_embeddings = th.t(self.z_embedding.weight).split(self.k_size, dim=0)
        temp_sample_y = sample_y.view(-1, self.config.y_size, self.config.k_size)
        attn_context = th.cat(
            [th.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1)
             for z_id in range(self.y_size)],
            dim=1)
        dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
    else:
        dec_init_state = self.z_embedding(
            sample_y.view(1, -1, self.config.y_size * self.config.k_size))
        attn_context = None

    # an LSTM decoder needs an (h, c) pair; reuse the same tensor for both
    if self.config.dec_rnn_cell == 'lstm':
        dec_init_state = tuple([dec_init_state, dec_init_state])

    # decode
    dec_outputs, dec_hidden_state, ret_dict = self.decoder(
        batch_size=batch_size,
        dec_inputs=dec_inputs,  # (batch_size, response_size-1)
        dec_init_state=dec_init_state,  # tuple: (h, c)
        attn_context=attn_context,
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size,
    )

    if mode == GEN:
        ret_dict['sample_z'] = sample_y
        ret_dict['log_qy'] = log_qy
        return ret_dict, labels

    # The loss used to be evaluated twice and exp(log_qy) was recomputed for
    # the diversity term; both are now computed once and reused.
    result = Pack(nll=self.nll(dec_outputs, labels))

    q_y = th.exp(log_qy).view(-1, self.config.y_size, self.config.k_size)

    # regularization qy to be uniform
    avg_log_qy = th.log(th.mean(q_y, dim=0) + 1e-15)
    b_pr = self.cat_kl_loss(avg_log_qy, self.log_uniform_y, batch_size, unit_average=True)
    mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss(log_qy, unit_average=True)
    pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True)

    # squared deviation of q_y q_y^T from self.eye, averaged below
    p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)

    result['pi_kl'] = pi_kl
    result['diversity'] = th.mean(p)
    result['b_pr'] = b_pr
    result['mi'] = mi
    return result
def _prepare_batch_seq(self, selected_index):
    """Flatten the selected dialogues into padded numpy arrays.

    Every row of every selected dialogue becomes one batch element, so the
    effective batch size is the total number of rows. ``dlg_idxs`` maps each
    element back to the position of its dialogue in ``selected_index`` and
    ``dlg_lens`` holds the number of rows per dialogue.

    Side effect: sets ``self.goal_len``. Prints 'FATAL ERROR!' and exits the
    process when goal lengths differ or are not 6.
    """
    dialogs = [self.data[idx] for idx in selected_index]

    # flatten dialogs here, keep pointers back to the source dialogue
    dlg_idxs, dlg_lens, flat_rows = [], [], []
    for dlg_id, dlg_rows in enumerate(dialogs):
        n_rows = 0
        for row in dlg_rows:
            flat_rows.append(row)
            dlg_idxs.append(dlg_id)
            n_rows += 1
        dlg_lens.append(n_rows)

    ctx_utts, out_utts, goals = [], [], []
    partner_goals_list, partitions = [], []
    true_partner_goals, parsed_out_utts, parsed_ctx_utts = [], [], []
    for row in flat_rows:
        context, response = row.context, row.response
        # source context: every turn padded to max_utt_len
        ctx_utts.append([
            self.pad_to(self.max_utt_len, turn.utt, do_pad=True)
            for turn in context
        ])
        # target response (copied, left unpadded until vectorisation)
        out_utts.append(list(response.utt))
        goals.append(row.goal)
        # valid partner goals
        partner_goals_list.append(row.valid_partner_goals)
        # partitions: list of list of tuples, each tuple is a goal
        # and the inner list represents all possible partner goals
        partitions.append(row.partitions)
        # oracle values
        true_partner_goals.append(row.partner_goal)
        parsed_out_utts.append(response.parsed)
        parsed_ctx_utts.append([turn.parsed for turn in row.context])

    ctx_lens = [len(ctx) for ctx in ctx_utts]
    out_lens = [len(utt) for utt in out_utts]
    goal_lens = [len(goal) for goal in goals]
    num_partner_goals = [len(pgs) for pgs in partner_goals_list]
    num_partitions = [len(parts) for parts in partitions]

    effective_batch_size = len(goals)

    vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
    max_ctx_len = np.max(vec_ctx_lens)
    vec_ctx_utts = np.zeros(
        (effective_batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
    # confs is used to add some hand-crafted features
    vec_ctx_confs = np.ones((effective_batch_size, max_ctx_len), dtype=np.float32)

    vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
    max_out_len = np.max(vec_out_lens)
    vec_out_utts = np.zeros((effective_batch_size, max_out_len), dtype=np.int32)

    longest_goal, shortest_goal = max(goal_lens), min(goal_lens)
    goals_are_uniform = longest_goal == shortest_goal and longest_goal == 6
    if not goals_are_uniform:
        print('FATAL ERROR!')
        exit(-1)
    self.goal_len = longest_goal
    vec_goals = np.array(goals, dtype=np.int32)

    vec_partner_goals = np.zeros(
        (effective_batch_size, max(num_partner_goals), self.goal_len),
        dtype=np.int32)
    vec_num_partner_goals = np.array(num_partner_goals)

    # padded to the largest partition count in this batch
    vec_partitions = np.zeros(
        (effective_batch_size, max(num_partitions), 6), dtype=np.int32)
    vec_num_partitions = np.array(num_partitions)

    vec_dlg_idxs = np.array(dlg_idxs, dtype=np.int32)
    vec_dlg_lens = np.array(dlg_lens, dtype=np.int32)

    vec_true_partner_goals = np.array(true_partner_goals, dtype=np.int32)
    vec_parsed_out_utts = np.array(parsed_out_utts, dtype=np.int32)
    # [0,10] is taken for values. no numbers exceed 10, 11 is padding
    vec_parsed_ctx_utts = np.full(
        (effective_batch_size, max_ctx_len, 6), 11, dtype=np.int32)

    for b_id in range(effective_batch_size):
        n_turns = vec_ctx_lens[b_id]
        vec_ctx_utts[b_id, :n_turns, :] = ctx_utts[b_id]
        vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
        for pg_id, partner_goal in enumerate(partner_goals_list[b_id]):
            vec_partner_goals[b_id, pg_id, :] = partner_goal
        for p_id, partition in enumerate(partitions[b_id]):
            vec_partitions[b_id, p_id, :] = partition
        vec_parsed_ctx_utts[b_id, :n_turns, :] = parsed_ctx_utts[b_id]

    return Pack(
        context_lens=vec_ctx_lens,
        contexts=vec_ctx_utts,
        context_confs=vec_ctx_confs,
        output_lens=vec_out_lens,
        outputs=vec_out_utts,
        goals=vec_goals,
        partner_goals=vec_partner_goals,
        num_partner_goals=vec_num_partner_goals,
        dlg_idxs=vec_dlg_idxs,
        dlg_lens=vec_dlg_lens,
        partitions=vec_partitions,
        num_partitions=vec_num_partitions,
        # oracle values
        true_partner_goals=vec_true_partner_goals,
        parsed_contexts=vec_parsed_ctx_utts,
        parsed_outputs=vec_parsed_out_utts,
    )
def _process_dialogue(self, data):
    """Parse raw dialogue strings into Packs of turns, goal and outcome.

    Each raw string is whitespace-tokenised. The tokens between
    ``<dialogue>``/``</dialogue>`` are split into EOS-terminated turns, a
    synthetic ``[speaker, BOD, EOS]`` opening turn is prepended for the
    speaker who did not start, and user/system turns are interleaved. The
    goal is read from ``<partner_input>`` and the outcome from ``<output>``;
    both must contain exactly 6 tokens, otherwise the process exits.

    Returns:
        A list of ``Pack(dlg=..., goal=..., out=...)``, one per input string.
    """
    utt_lengths = []
    dlg_lengths = []

    def split_turns(tokens):
        """Cut ``tokens`` at every EOS and bucket the turns by speaker."""
        usr_turns, sys_turns = [], []
        start = 0
        while start < len(tokens):
            # walk to the EOS that closes the turn beginning at ``start``
            stop = start
            while tokens[stop] != EOS:
                stop += 1
            turn = tokens[start:stop + 1]
            start = stop + 1

            utt_lengths.append(len(turn))
            speaker = turn[0]
            if speaker == USR:
                usr_turns.append(Pack(utt=turn, speaker=USR))
            elif speaker == SYS:
                sys_turns.append(Pack(utt=turn, speaker=SYS))
            else:
                raise ValueError('Invalid speaker')

        dlg_lengths.append(len(usr_turns) + len(sys_turns))
        return usr_turns, sys_turns

    dialogs = []
    for raw_dlg in data:
        tokens = raw_dlg.split()

        # process dialogue text
        words = tokens[tokens.index('<dialogue>') + 1:
                       tokens.index('</dialogue>')]
        words += [EOS]

        # prepend a begin-of-dialogue turn spoken by the other party
        opener = words[0]
        if opener == SYS:
            words = [USR, BOD, EOS] + words
            usr_first = True
        elif opener == USR:
            words = [SYS, BOD, EOS] + words
            usr_first = False
        else:
            print('FATAL ERROR!!! ({})'.format(words))
            exit(-1)

        usr_utts, sys_utts = split_turns(words)

        # interleave the two speakers, starting with whoever goes first
        if usr_first:
            leading, trailing = usr_utts, sys_utts
        else:
            leading, trailing = sys_utts, usr_utts
        turns = []
        for lead_turn, trail_turn in zip(leading, trailing):
            turns.extend((lead_turn, trail_turn))

        # one side may have exactly one extra, unpaired final turn
        surplus = len(usr_utts) - len(sys_utts)
        if surplus == 1:
            turns.append(usr_utts[-1])
        elif surplus == -1:
            turns.append(sys_utts[-1])

        # process goal (6 digits)
        # FIXME FATAL ERROR HERE !!!
        goal = tokens[tokens.index('<partner_input>') + 1:
                      tokens.index('</partner_input>')]
        # goal = tokens[tokens.index('<input>')+1: tokens.index('</input>')]
        if len(goal) != 6:
            print('FATAL ERROR!!! ({})'.format(goal))
            exit(-1)

        # process outcome (6 tokens)
        outcome = tokens[tokens.index('<output>') + 1:
                         tokens.index('</output>')]
        if len(outcome) != 6:
            print('FATAL ERROR!!! ({})'.format(outcome))
            exit(-1)

        dialogs.append(Pack(dlg=turns, goal=goal, out=outcome))

    print('Max utt len = %d, mean utt len = %.2f' % (
        np.max(utt_lengths), float(np.mean(utt_lengths))))
    print('Max dlg len = %d, mean dlg len = %.2f' % (
        np.max(dlg_lengths), float(np.mean(dlg_lengths))))
    return dialogs
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_pz=None, return_latent=False): ctx_lens = data_feed['context_lens'] # (batch_size, ) ctx_utts = self.np2var(data_feed['contexts'], LONG) # (batch_size, max_ctx_len, max_utt_len) out_utts = self.np2var(data_feed['outputs'], LONG) # (batch_size, max_out_len) goals = self.np2var(data_feed['goals'], LONG) # (batch_size, goal_len) partitions = self.np2var(data_feed.partitions, LONG) num_partitions = self.np2var(data_feed.num_partitions, INT) # effective batch size batch_size = len(ctx_lens) true_batch_size = data_feed.dlg_idxs.max() # encode goal info goals_h = self.goal_encoder(goals) # (batch_size, goal_nhid) enc_inputs, _, _ = self.utt_encoder(ctx_utts, goals=goals_h) # (batch_size, max_ctx_len, num_directions*utt_cell_size) # enc_outs: (batch_size, max_ctx_len, ctx_cell_size) # enc_last: tuple, (h_n, c_n) # h_n: (num_layers*num_directions, batch_size, ctx_cell_size) # c_n: (num_layers*num_directions, batch_size, ctx_cell_size) enc_outs, enc_last = self.ctx_encoder(enc_inputs, input_lengths=ctx_lens, goals=None) # get decoder inputs dec_inputs = out_utts[:, :-1] labels = out_utts[:, 1:].contiguous() logits_pz_t, log_pz_t = self.c2z(enc_last) # encode response and use posterior to find q(z|x, c) x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1), goals=goals_h) logits_qz_t, log_qz_t = self.xc2z( th.cat([enc_last, x_h.squeeze(1).unsqueeze(0)], dim=2)) state_emb = self.res_layer( self.item_emb(partitions).view(-1, self.z_size, 3 * 32)) _, psi_zr_zl = self.hmm_potentials(state_emb, lengths=num_partitions) # REMINDER: transpose last two dimensions of HMM for torch_struct import pdb pdb.set_trace() # reshape and run HMM? 
# decode dec_outputs, dec_hidden_state, ret_dict = self.decoder( batch_size=batch_size, dec_inputs=dec_inputs, # (batch_size, response_size-1) dec_init_state=dec_init_state, # tuple: (h, c) attn_context=attn_context, # (batch_size, max_ctx_len, ctx_cell_size) mode=mode, gen_type=gen_type, beam_size=self.config.beam_size, goal_hid=goals_h, # (batch_size, goal_nhid) ) if mode == GEN: return ret_dict, labels else: # regularization qy to be uniform avg_log_qy = th.exp( log_qy.view(-1, self.config.y_size, self.config.k_size)) avg_log_qy = th.log(th.mean(avg_log_qy, dim=0) + 1e-15) mi = self.entropy_loss(avg_log_qy, unit_average=True) - self.entropy_loss( log_qy, unit_average=True) pi_kl = self.cat_kl_loss(log_qy, log_py, batch_size, unit_average=True) pi_h = self.entropy_loss(log_qy, unit_average=True) results = Pack(nll=self.nll(dec_outputs, labels), mi=mi, pi_kl=pi_kl, pi_h=pi_h) if return_latent: results['latent_action'] = dec_init_state return results
def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False, get_marginals=False):
    """Decode the response conditioned on oracle goal/proposal information.

    The decoder is conditioned on ``big_goals_h``, built from the agent's own
    goal plus, depending on ``config.oracle_context`` / ``config.oracle_parse``,
    the true partner goal and/or the parsed proposal of the target utterance.
    A separate proposal encoder scores every candidate partition and yields
    ``nll_label`` for the partition that matches the parsed target.

    Args:
        data_feed: batch object; read by key ('context_lens', 'contexts',
            'context_confs', 'outputs', 'goals') and by attribute
            (``partitions``, ``num_partitions``, ``true_partner_goals``,
            ``parsed_outputs``).
        mode: decoder mode; ``GEN`` returns ``(ret_dict, labels)``.
        clf: ignored (the original body overwrote it with ``False``).
        gen_type: generation strategy forwarded to the decoder.
        use_py: not used here.
        return_latent: when true, return ``nll`` and ``latent_action`` only.
        get_marginals: when true, return raw ``dec_outputs`` and ``labels``.

    Raises:
        ValueError: if neither ``config.oracle_context`` nor
            ``config.oracle_parse`` is enabled (previously an
            ``UnboundLocalError`` at the decoder call).
    """
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    ctx_utts = self.np2var(data_feed['contexts'], LONG)  # (batch_size, max_ctx_len, max_utt_len)
    ctx_confs = self.np2var(data_feed['context_confs'], FLOAT)  # (batch_size, max_ctx_len)
    out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
    goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
    batch_size = len(ctx_lens)

    # encode goal info
    goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)

    # (batch_size, max_ctx_len, num_directions*utt_cell_size)
    enc_inputs, _, _ = self.utt_encoder(ctx_utts, feats=ctx_confs, goals=goals_h)

    # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
    # enc_last: final context-encoder state, handed to self.connector below
    enc_outs, enc_last = self.ctx_encoder(enc_inputs, input_lengths=ctx_lens, goals=None)

    partitions = self.np2var(data_feed.partitions, LONG)
    num_partitions = self.np2var(data_feed.num_partitions, INT)

    # oracle input
    partner_goals = self.np2var(data_feed.true_partner_goals, LONG)
    parsed_outputs = self.np2var(data_feed.parsed_outputs, LONG)

    # true partner item values
    partner_goals_h = self.goal_encoder(partner_goals)

    # true next utterance proposal parse: columns 0-2 are embedded as "my"
    # book/hat/ball counts, columns 3-5 as "your" counts
    my_state_emb = self.res_layer(
        th.cat([
            self.book_emb(parsed_outputs[:, 0]),
            self.hat_emb(parsed_outputs[:, 1]),
            self.ball_emb(parsed_outputs[:, 2]),
        ], -1))
    your_state_emb = self.res_layer(
        th.cat([
            self.book_emb(parsed_outputs[:, 3]),
            self.hat_emb(parsed_outputs[:, 4]),
            self.ball_emb(parsed_outputs[:, 5]),
        ], -1))

    # my goal, your goal, and the proposal -> decoder conditioning vector
    if self.config.oracle_context and self.config.oracle_parse:
        oracle_feats = [goals_h, partner_goals_h, my_state_emb, your_state_emb]
    elif self.config.oracle_context:
        oracle_feats = [goals_h, partner_goals_h]
    elif self.config.oracle_parse:
        oracle_feats = [goals_h, my_state_emb, your_state_emb]
    else:
        raise ValueError(
            'forward() needs config.oracle_context and/or config.oracle_parse '
            'to build the decoder conditioning vector')
    big_goals_h = self.res_goal_mlp(th.cat(oracle_feats, -1))

    # proposal prediction: a separate encoder stack over the same context
    # (batch_size, max_ctx_len, num_directions*utt_cell_size)
    prop_enc_inputs, _, _ = self.prop_utt_encoder(ctx_utts, feats=ctx_confs, goals=goals_h)
    # Feed the proposal context encoder with the proposal utterance features.
    # It used to receive ``enc_inputs``, which silently discarded the output
    # of ``prop_utt_encoder``.
    prop_enc_outs, prop_enc_last = self.prop_ctx_encoder(
        prop_enc_inputs, input_lengths=ctx_lens, goals=None)

    # output-side embeddings of every candidate partition
    my_state_emb_out = self.res_layer_out(
        th.cat([
            self.book_emb_out(partitions[:, :, 0]),
            self.hat_emb_out(partitions[:, :, 1]),
            self.ball_emb_out(partitions[:, :, 2]),
        ], -1))
    your_state_emb_out = self.res_layer_out(
        th.cat([
            self.book_emb_out(partitions[:, :, 3]),
            self.hat_emb_out(partitions[:, :, 4]),
            self.ball_emb_out(partitions[:, :, 5]),
        ], -1))
    state_emb_out = th.cat([my_state_emb_out, your_state_emb_out], -1)

    # candidates equal to the parsed target proposal in every column
    label_mask = (partitions == parsed_outputs.unsqueeze(1)).all(-1)
    logits_label = th.einsum("nsh,nh->ns", state_emb_out, prop_enc_last[-1])
    # True for slots at or beyond each row's real partition count
    mask = ~(th.arange(partitions.shape[1],
                       device=num_partitions.device,
                       dtype=num_partitions.dtype).repeat(
                           partitions.shape[0], 1) < num_partitions.unsqueeze(-1))
    logp_label = logits_label.masked_fill(mask, float("-inf")).log_softmax(-1)
    nll_label = -logp_label[label_mask].mean()

    # get decoder inputs (teacher forcing: shift the response by one token)
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    # pack attention context
    attn_context = enc_outs if self.config.dec_use_attn else None

    # create decoder initial states
    dec_init_state = self.connector(enc_last)

    # decode
    dec_outputs, dec_hidden_state, ret_dict = self.decoder(
        batch_size=batch_size,
        dec_inputs=dec_inputs,  # (batch_size, response_size-1)
        dec_init_state=dec_init_state,  # tuple: (h, c)
        attn_context=attn_context,  # (batch_size, max_ctx_len, ctx_cell_size)
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size,
        goal_hid=big_goals_h,
    )

    if get_marginals:
        return Pack(
            dec_outputs=dec_outputs,
            labels=labels,
        )

    if mode == GEN:
        return ret_dict, labels
    if return_latent:
        return Pack(nll=self.nll(dec_outputs, labels),
                    latent_action=dec_init_state)
    return Pack(nll=self.nll(dec_outputs, labels), nll_label=nll_label)
# Supervised-learning configuration (negotiation data, latent model).
config = Pack(
    # --- data ---
    train_path='../data/negotiate/train.txt',
    val_path='../data/negotiate/val.txt',
    test_path='../data/negotiate/test.txt',
    last_n_model=5,
    max_utt_len=20,
    backward_size=14,
    batch_size=32,
    # --- optimisation ---
    grad_clip=3.0,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=0.00001,
    momentum=0.0,
    dropout=0.5,
    max_epoch=50,
    # --- utterance / context encoders ---
    embed_size=256,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=128,
    bi_utt_cell=True,
    enc_use_attn=False,
    ctx_rnn_cell='gru',
    ctx_cell_size=256,
    bi_ctx_cell=False,
    # --- latent variable ---
    y_size=200,
    beta=1.0,
    simple_posterior=False,
    use_pr=True,
    # --- decoder ---
    dec_use_attn=False,
    # The decoder cell type and size must stay in step with the context
    # encoder (ctx_rnn_cell / ctx_cell_size) due to the passed initial state.
    dec_rnn_cell='gru',
    dec_cell_size=256,
    dec_attn_mode='cat',
    # --- batching / generation / checkpointing ---
    fix_train_batch=False,
    fix_batch=False,
    beam_size=20,
    avg_type='real_word',
    print_step=100,
    ckpt_step=400,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=1,
    max_dec_len=40,
    # --- goal encoder ---
    k=domain_info.input_length(),
    goal_embed_size=64,
    goal_nhid=64,
    init_range=0.1,
    pretrain_folder='2018-11-19-21-28-29-sl_latent',
    forward_only=False,
)
def main():
    """Run latent-action RL fine-tuning for the negotiation task.

    Loads a pretrained system model (``GaussHRED``) and a frozen user
    simulator (``HRED``), plus the FB judger model, lets the two agents
    negotiate via ``Reinforce`` and finally saves the system model weights.
    All paths are hard-coded relative to the working directory.
    """
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[START]', start_time, '=' * 30)

    # RL configuration
    folder = '2019-06-20-10-24-23-sl_gauss'
    epoch_id = '28'
    env = 'gpu'
    sim_epoch_id = '23'
    simulator_folder = '2019-06-20-09-19-39-sl_word'
    exp_dir = os.path.join('config_log_model', folder, 'rl-' + start_time)
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    rl_config = Pack(
        train_path='../data/negotiate/train.txt',
        val_path='../data/negotiate/val.txt',
        test_path='../data/negotiate/test.txt',
        selfplay_path='../data/negotiate/selfplay.txt',
        selfplay_eval_path='../data/negotiate/selfplay_eval.txt',
        sim_config_path=os.path.join('config_log_model', simulator_folder, 'config.json'),
        sim_model_path=os.path.join('config_log_model', simulator_folder,
                                    '{}-model'.format(sim_epoch_id)),
        sv_config_path=os.path.join('config_log_model', folder, 'config.json'),
        sv_model_path=os.path.join('config_log_model', folder,
                                   '{}-model'.format(epoch_id)),
        rl_config_path=os.path.join(exp_dir, 'rl_config.json'),
        rl_model_path=os.path.join(exp_dir, 'rl_model'),
        ppl_best_model_path=os.path.join(exp_dir, 'ppl_best_model'),
        reward_best_model_path=os.path.join(exp_dir, 'reward_best_model'),
        judger_model_path=os.path.join('../FB', 'sv_model.th'),
        judger_config_path=os.path.join('../FB', 'judger_config.json'),
        record_path=exp_dir,
        record_freq=50,
        use_gpu=env == 'gpu',
        nepoch=4,
        nepisode=0,
        sv_train_freq=0,  # TODO pay attention to main.py, cuz it is also controlled there
        eval_freq=0,
        max_words=100,
        rl_lr=0.2,
        momentum=0.1,
        nesterov=True,
        gamma=0.95,
        rl_clip=1.0,
        ref_text='../data/negotiate/train.txt',
        domain='object_division',
        max_nego_turn=50,
        random_seed=0,
        use_latent_rl=True)

    # save configuration
    with open(rl_config.rl_config_path, 'w') as f:
        json.dump(rl_config, f, indent=4)

    # set random seed
    set_seed(rl_config.random_seed)

    # load previous supervised learning configuration and corpus
    # (context managers so the config files are closed right after reading)
    with open(rl_config.sv_config_path) as f:
        sv_config = Pack(json.load(f))
    with open(rl_config.sim_config_path) as f:
        sim_config = Pack(json.load(f))

    # TODO revise the use_gpu in the config
    sv_config['use_gpu'] = rl_config.use_gpu
    sim_config['use_gpu'] = rl_config.use_gpu
    corpus = DealCorpus(sv_config)

    # load models for two agents
    # TARGET AGENT
    sys_model = models_deal.GaussHRED(corpus, sv_config)
    if sv_config.use_gpu:  # TODO gpu -> cpu transfer
        sys_model.cuda()
    sys_model.load_state_dict(
        th.load(rl_config.sv_model_path,
                map_location=lambda storage, location: storage))
    # we don't want to use Dropout during RL
    sys_model.eval()
    # named sys_agent so the local does not shadow the stdlib ``sys`` module
    sys_agent = LatentRlAgent(sys_model, corpus, rl_config, name='System',
                              use_latent_rl=rl_config.use_latent_rl)

    # SIMULATOR we keep usr frozen, i.e. we don't update its parameters
    usr_model = models_deal.HRED(corpus, sim_config)
    if sim_config.use_gpu:  # TODO gpu -> cpu transfer
        usr_model.cuda()
    usr_model.load_state_dict(
        th.load(rl_config.sim_model_path,
                map_location=lambda storage, location: storage))
    usr_model.eval()
    usr_type = LstmAgent
    usr_agent = usr_type(usr_model, corpus, rl_config, name='User')

    # load FB judger model
    with open(rl_config.judger_config_path) as f:
        judger_config = Pack(json.load(f))
    judger_config['cuda'] = rl_config.use_gpu
    judger_config['data'] = '../data/negotiate'
    judger_device_id = FB_use_cuda(judger_config.cuda)
    judger_word_corpus = FbWordCorpus(judger_config.data,
                                      freq_cutoff=judger_config.unk_threshold,
                                      verbose=True)
    judger_model = FbDialogModel(judger_word_corpus.word_dict,
                                 judger_word_corpus.item_dict,
                                 judger_word_corpus.context_dict,
                                 judger_word_corpus.output_length,
                                 judger_config,
                                 judger_device_id)
    if judger_device_id is not None:
        judger_model.cuda(judger_device_id)
    judger_model.load_state_dict(
        th.load(rl_config.judger_model_path,
                map_location=lambda storage, location: storage))
    judger_model.eval()
    judger = Judger(judger_model, judger_device_id)

    # initialize communication dialogue between two agents
    dialog = Dialog([sys_agent, usr_agent], judger, rl_config)
    ctx_gen = ContextGenerator(rl_config.selfplay_path)

    # simulation module
    dialog_eval = DialogEval([sys_agent, usr_agent], judger, rl_config)
    ctx_gen_eval = ContextGeneratorEval(rl_config.selfplay_eval_path)

    # start RL
    reinforce = Reinforce(dialog, ctx_gen, corpus, sv_config, sys_model,
                          usr_model, rl_config, dialog_eval, ctx_gen_eval)
    reinforce.run()

    # save sys model
    th.save(sys_model.state_dict(), rl_config.rl_model_path)

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)
# Supervised-learning configuration (norm-multi-woz data, categorical latent).
config = Pack(
    seed=10,
    # --- data ---
    train_path='../data/norm-multi-woz/train_dials.json',
    valid_path='../data/norm-multi-woz/val_dials.json',
    test_path='../data/norm-multi-woz/test_dials.json',
    max_vocab_size=1000,
    last_n_model=5,
    max_utt_len=50,
    max_dec_len=50,
    backward_size=2,
    batch_size=32,
    # --- optimisation ---
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=1e-05,
    momentum=0.0,
    grad_clip=5.0,
    dropout=0.5,
    max_epoch=100,
    # --- encoder ---
    embed_size=100,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=300,
    bi_utt_cell=True,
    enc_use_attn=True,
    # --- decoder ---
    dec_use_attn=True,
    dec_rnn_cell='lstm',
    dec_cell_size=300,
    dec_attn_mode='cat',
    # --- latent variable and regularisers ---
    y_size=10,
    k_size=20,
    beta=0.001,
    simple_posterior=True,
    contextual_posterior=True,
    use_mi=False,
    use_pr=True,
    use_diversity=False,
    # --- batching / generation / checkpointing ---
    beam_size=20,
    fix_batch=True,
    fix_train_batch=False,
    avg_type='word',
    print_step=300,
    ckpt_step=1416,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=None,
    k=domain_info.input_length(),
    init_range=0.1,
    pretrain_folder='2019-06-20-21-43-06-sl_cat',
    forward_only=False,
)
def _prepare_batch(self, selected_index):
    """Assemble the selected rows into padded numpy arrays for one batch.

    Context turns are padded to ``self.max_utt_len``; responses are padded to
    the longest response in the batch. Goals are split per domain into one
    float32 matrix each.

    Side effect: sets ``self.goal_lens`` (one length per domain). Prints
    'Fatal Error!' and exits the process if, for any domain, goal lengths
    differ across the batch.
    """
    batch_rows = [self.data[idx] for idx in selected_index]

    keys = [row.key for row in batch_rows]
    # source context: every turn padded to max_utt_len
    ctx_utts = [
        [self.pad_to(self.max_utt_len, turn.utt, do_pad=True) for turn in row.context]
        for row in batch_rows
    ]
    ctx_lens = [len(ctx) for ctx in ctx_utts]
    # target response (copied, left unpadded until vectorisation)
    out_utts = [list(row.response.utt) for row in batch_rows]
    out_lens = [len(utt) for utt in out_utts]
    out_bs = [row.response.bs for row in batch_rows]
    out_db = [row.response.db for row in batch_rows]
    # goal: one list of lengths per domain, in self.domains order
    goals = [row.goal for row in batch_rows]
    goal_lens = [[len(goal[d]) for goal in goals] for d in self.domains]

    batch_size = len(ctx_lens)
    vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
    max_ctx_len = np.max(vec_ctx_lens)
    vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
    vec_out_bs = np.array(out_bs)  # (batch_size, 94)
    vec_out_db = np.array(out_db)  # (batch_size, 30)
    vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
    max_out_len = np.max(vec_out_lens)
    vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)

    longest = [max(lens) for lens in goal_lens]
    shortest = [min(lens) for lens in goal_lens]
    if longest != shortest:
        print('Fatal Error!')
        exit(-1)
    self.goal_lens = longest
    vec_goals_list = [
        np.zeros((batch_size, goal_len), dtype=np.float32)
        for goal_len in self.goal_lens
    ]

    for b_id in range(batch_size):
        vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
        vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
        for domain_id, domain in enumerate(self.domains):
            vec_goals_list[domain_id][b_id, :] = goals[b_id][domain]

    return Pack(
        context_lens=vec_ctx_lens,  # (batch_size, )
        contexts=vec_ctx_utts,  # (batch_size, max_ctx_len, max_utt_len)
        output_lens=vec_out_lens,  # (batch_size, )
        outputs=vec_out_utts,  # (batch_size, max_out_len)
        bs=vec_out_bs,  # (batch_size, 94)
        db=vec_out_db,  # (batch_size, 30)
        goals_list=vec_goals_list,  # 7*(batch_size, bow_len), bow_len differs w.r.t. domain
        keys=keys)
def main():
    """Run offline latent-action RL fine-tuning on norm-multi-woz.

    Loads the pretrained ``SysPerfectBD2Cat`` checkpoint named by the
    hard-coded folder/model id, wraps it in an ``OfflineLatentRlAgent`` and
    runs ``OfflineTaskReinforce``. All paths are relative to the working
    directory.
    """
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[START]', start_time, '=' * 30)

    # RL configuration
    env = 'gpu'
    pretrained_folder = '2019-06-20-22-49-55-sl_cat'
    pretrained_model_id = 41
    exp_dir = os.path.join('sys_config_log_model', pretrained_folder, 'rl-' + start_time)
    # create exp folder
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    rl_config = Pack(
        train_path='../data/norm-multi-woz/train_dials.json',
        valid_path='../data/norm-multi-woz/val_dials.json',
        test_path='../data/norm-multi-woz/test_dials.json',
        sv_config_path=os.path.join('sys_config_log_model', pretrained_folder, 'config.json'),
        sv_model_path=os.path.join('sys_config_log_model', pretrained_folder,
                                   '{}-model'.format(pretrained_model_id)),
        rl_config_path=os.path.join(exp_dir, 'rl_config.json'),
        rl_model_path=os.path.join(exp_dir, 'rl_model'),
        ppl_best_model_path=os.path.join(exp_dir, 'ppl_best.model'),
        reward_best_model_path=os.path.join(exp_dir, 'reward_best.model'),
        record_path=exp_dir,
        record_freq=200,
        sv_train_freq=0,  # TODO pay attention to main.py, cuz it is also controlled there
        use_gpu=env == 'gpu',
        nepoch=10,
        nepisode=0,
        tune_pi_only=False,
        max_words=100,
        temperature=1.0,
        episode_repeat=1.0,
        rl_lr=0.01,
        momentum=0.0,
        nesterov=False,
        gamma=0.99,
        rl_clip=5.0,
        random_seed=100,
    )

    # save configuration
    with open(rl_config.rl_config_path, 'w') as f:
        json.dump(rl_config, f, indent=4)

    # set random seed
    set_seed(rl_config.random_seed)

    # load previous supervised learning configuration and corpus
    # (context manager so the config file is closed right after reading)
    with open(rl_config.sv_config_path) as f:
        sv_config = Pack(json.load(f))
    # dropout is switched off for RL; the GPU flag follows the RL config
    sv_config['dropout'] = 0.0
    sv_config['use_gpu'] = rl_config.use_gpu
    corpus = NormMultiWozCorpus(sv_config)

    # TARGET AGENT
    sys_model = SysPerfectBD2Cat(corpus, sv_config)
    if sv_config.use_gpu:
        sys_model.cuda()
    sys_model.load_state_dict(
        th.load(rl_config.sv_model_path,
                map_location=lambda storage, location: storage))
    sys_model.eval()
    # named sys_agent so the local does not shadow the stdlib ``sys`` module
    sys_agent = OfflineLatentRlAgent(sys_model, corpus, rl_config, name='System',
                                     tune_pi_only=rl_config.tune_pi_only)

    # start RL
    reinforce = OfflineTaskReinforce(sys_agent, corpus, sv_config, sys_model,
                                     rl_config, task_generate)
    reinforce.run()

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)
def forward(
    self,
    data_feed,
    mode,
    clf=False,
    gen_type='greedy',
    use_py=None,
    return_latent=False,
    get_marginals=False,
):
    """Score a batch of responses under a latent-state chain (HMM-style) model.

    Every turn is decoded once per latent state (partition); the per-state
    response log-likelihoods are then combined with a distribution over
    latent states that is either restarted (first turn of a dialogue) or
    propagated from the previous turn through a transition matrix.

    Args:
        data_feed: batch container. Read by key: ``context_lens``,
            ``contexts``, ``context_confs``, ``outputs``, ``goals``; read by
            attribute: ``partitions``, ``num_partitions``,
            ``parsed_outputs``, ``true_partner_goals``, ``dlg_idxs``.
            Consecutive rows with equal ``dlg_idxs`` are treated as
            consecutive turns of the same dialogue.
        mode: forwarded to the decoder; when equal to ``GEN`` the decoder's
            ``ret_dict`` is returned instead of losses.
        clf: unused here.
        gen_type: forwarded to the decoder.
        use_py: unused here.
        return_latent: if true, return the total loss together with the
            decoder initial state.
        get_marginals: if true, return raw decoder outputs and per-turn
            marginal log-likelihoods instead of losses.

    Returns:
        * ``Pack(dec_outputs, logp_xt, labels)`` if ``get_marginals``;
        * ``(ret_dict, labels)`` if ``mode == GEN``;
        * ``Pack(nll, latent_action)`` if ``return_latent``;
        * ``Pack(nll_label, nll_word)`` otherwise.

    Raises:
        ValueError: if ``self.nll.avg_type`` is neither ``"real_word"`` nor
            ``"word"``.

    Side effects:
        Sets ``self.z_size`` to ``data_feed.num_partitions.max()``.
    """
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    ctx_utts = self.np2var(data_feed['contexts'], LONG)  # (batch_size, max_ctx_len, max_utt_len)
    ctx_confs = self.np2var(data_feed['context_confs'], FLOAT)  # (batch_size, max_ctx_len)
    out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
    goals = self.np2var(data_feed['goals'], LONG)  # (batch_size, goal_len)
    partitions = self.np2var(data_feed.partitions, LONG)
    num_partitions = self.np2var(data_feed.num_partitions, INT)
    batch_size = len(ctx_lens)
    # Number of latent states used for this batch (largest partition count).
    self.z_size = data_feed.num_partitions.max()

    # oracle
    parsed_outputs = self.np2var(data_feed.parsed_outputs, LONG)
    # NOTE(review): partner_goals is not consumed anywhere below.
    partner_goals = self.np2var(data_feed.true_partner_goals, LONG)

    # encode goal info
    goals_h = self.goal_encoder(goals)  # (batch_size, goal_nhid)
    enc_inputs, _, _ = self.utt_encoder(
        ctx_utts,
        feats=ctx_confs,
        goals=goals_h,
    )  # (batch_size, max_ctx_len, num_directions*utt_cell_size)

    # enc_outs: (batch_size, max_ctx_len, ctx_cell_size)
    # NOTE(review): enc_last is indexed with [-1] and .repeat()-ed below, so it
    # is used as a single tensor here, presumably
    # (num_layers*num_directions, batch_size, ctx_cell_size) -- confirm.
    enc_outs, enc_last = self.ctx_encoder(enc_inputs, input_lengths=ctx_lens, goals=None)

    # get decoder inputs
    dec_inputs = out_utts[:, :-1]
    labels = out_utts[:, 1:].contiguous()

    # pack attention context
    # NOTE(review): not consumed below; the decoder is called with attn_context=None.
    if self.config.dec_use_attn:
        attn_context = enc_outs
    else:
        attn_context = None

    # create decoder initial states
    # NOTE(review): this value is overwritten before decoding (see "decode").
    dec_init_state = self.connector(enc_last)

    # transition matrix
    ctx_input = self.prior_res_layer(enc_last[-1])

    # Input-side state embeddings: columns 0-2 of each partition go through
    # the book/hat/ball embeddings for "my" side, columns 3-5 for "your" side.
    my_state_emb = self.res_layer(
        th.cat([
            self.book_emb(partitions[:, :, 0]),
            self.hat_emb(partitions[:, :, 1]),
            self.ball_emb(partitions[:, :, 2]),
        ], -1))
    your_state_emb = self.res_layer(
        th.cat([
            self.book_emb(partitions[:, :, 3]),
            self.hat_emb(partitions[:, :, 4]),
            self.ball_emb(partitions[:, :, 5]),
        ], -1))
    state_emb = th.cat([my_state_emb, your_state_emb], -1)

    # Output-side state embeddings (separate parameters, same layout).
    my_state_emb_out = self.res_layer_out(
        th.cat([
            self.book_emb_out(partitions[:, :, 0]),
            self.hat_emb_out(partitions[:, :, 1]),
            self.ball_emb_out(partitions[:, :, 2]),
        ], -1))
    your_state_emb_out = self.res_layer_out(
        th.cat([
            self.book_emb_out(partitions[:, :, 3]),
            self.hat_emb_out(partitions[:, :, 4]),
            self.ball_emb_out(partitions[:, :, 5]),
        ], -1))
    state_emb_out = th.cat([my_state_emb_out, your_state_emb_out], -1)

    # Combine the goal encoding with every latent state, then flatten the
    # (batch, state) axes so each row is one (example, state) pair.
    goal_state_h = self.res_goal_mlp(
        th.cat([
            goals_h.unsqueeze(1).expand(
                state_emb.shape[0], state_emb.shape[1], goals_h.shape[-1]),
            state_emb,
        ], -1))
    # Use the tensor's own last dimension instead of a hard-coded 128 so the
    # flatten stays correct for other hidden sizes.
    goals_h = goal_state_h.view(-1, goal_state_h.shape[-1])

    # for noisy labels
    if self.noisy_proposal_labels:
        # transition from state to label
        label_mask = (partitions == parsed_outputs.unsqueeze(1)).all(-1)
        logp_label_z = th.einsum("nsh,nth->nts", state_emb,
                                 state_emb_out).log_softmax(-1)
        # outer dim t should be output label

    phi_zt, psi_zl_zr = self.hmm_potentials(state_emb,
                                            state_emb_out,
                                            ctx_input,
                                            lengths=num_partitions)
    logp_zt = phi_zt.log_softmax(-1)
    logp_zr_zl = psi_zl_zr.log_softmax(-1).transpose(-1, -2)

    # decode
    N, T = dec_inputs.shape
    # Replicate the encoder state and decoder inputs once per latent state so
    # the decoder runs on batch_size * z_size rows.
    dec_init_state = enc_last.repeat(1, 1, self.z_size).view(
        self.config.num_layers, N * self.z_size, -1)
    dec_outputs, dec_hidden_state, ret_dict = self.decoder(
        batch_size=batch_size * self.z_size,
        dec_inputs=dec_inputs.repeat(1, 1, self.z_size).view(-1, T),  # (batch_size, response_size-1)
        dec_init_state=dec_init_state,  # tuple: (h, c)
        attn_context=None,  # (batch_size, max_ctx_len, ctx_cell_size)
        mode=mode,
        gen_type=gen_type,
        beam_size=self.config.beam_size,
        goal_hid=goals_h,  # (batch_size, goal_nhid)
    )

    _, T, V = dec_outputs.shape
    # all word probs, they need to be summed over
    # `log p(xt) = \sum_i \log p(w_ti)`
    logp_wt_zt = dec_outputs.view(N, self.z_size, T, V).gather(
        -1,
        labels.view(N, 1, T, 1).expand(N, self.z_size, T, 1),
    ).squeeze(-1)
    # get rid of padding, mask to 0
    logp_xt_zt = (logp_wt_zt.masked_fill(
        labels.unsqueeze(1) == self.nll.padding_idx, 0).sum(-1))

    # do linear chain stuff
    # a little weird, we're working with a chain graphical model
    # need to normalize over each zt so the lm probs remain normalized
    dlg_idxs = data_feed.dlg_idxs
    ll_label = 0
    prev_zt = None
    logp_xt = []
    for t in range(N):
        if t == 0 or dlg_idxs[t] != dlg_idxs[t - 1]:
            # (re)start hmm: first turn of a dialogue uses the initial
            # state distribution
            prev_zt = logp_zt[t]
        else:
            # continue
            # unsqueeze is unnecessary, broadcasting handles it
            prev_zt = (prev_zt.unsqueeze(-2) + logp_zr_zl[t]).logsumexp(-1)
        # Marginalize the response likelihood over the latent state.
        logp_xt.append((logp_xt_zt[t] + prev_zt).logsumexp(-1))

        if self.training and self.noisy_proposal_labels and label_mask[t].any():
            if not self.config.sup_proposal_labels:
                # predict noisy proposal from hidden state
                ll_label += (logp_label_z[t] + prev_zt.unsqueeze(-1)
                             ).logsumexp(0)[label_mask[t]].logsumexp(0)
            else:
                ll_label += prev_zt[label_mask[t]].logsumexp(0)

    logp_xt = th.stack(logp_xt)
    if self.nll.avg_type == "real_word":
        nll_word = -(logp_xt / (labels.sign().sum(-1).float())).mean()
    elif self.nll.avg_type == "word":
        nll_word = -(logp_xt.sum() / labels.sign().sum())
    else:
        raise ValueError("Unknown reduction type")

    if self.training and self.noisy_proposal_labels and label_mask.any():
        # Normalize by the number of turns that have at least one matching label.
        nll_label = -self.config.label_weight * ll_label / label_mask.any(
            -1).sum().float()
    else:
        nll_label = th.zeros(1).to(nll_word.device)

    if get_marginals:
        return Pack(
            dec_outputs=dec_outputs,
            logp_xt=logp_xt,
            labels=labels,
        )

    if mode == GEN:
        return ret_dict, labels
    if return_latent:
        # The original code returned an undefined name `nll` here, which raised
        # NameError. NOTE(review): the total objective is assumed to be the sum
        # of the word and label terms -- confirm against the training loop.
        nll = nll_word + nll_label
        return Pack(nll=nll, latent_action=dec_init_state)
    else:
        return Pack(nll_label=nll_label, nll_word=nll_word)