def generate_with_name(model, data, config):
    model.eval()
    de_tknize = get_detokenize()
    data.epoch_init(config, shuffle=False, verbose=False)
    logger.info('Generation With Name: {} batches.'.format(data.num_batch))

    from collections import defaultdict
    res = defaultdict(dict)
    while True:
        batch = data.next_batch()
        if batch is None:
            break
        keys, outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
        true_labels = labels.cpu().data.numpy()  # (batch_size, output_seq_len)
        for b_id in range(pred_labels.shape[0]):
            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
            dlg_name, dlg_turn = keys[b_id]
            res[dlg_name][dlg_turn] = {'pred': pred_str, 'true': true_str}
    return res
def model_predict(self, data_feed):
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    short_ctx_utts = self.np2var(
        self.extract_short_ctx(data_feed['contexts'], ctx_lens),
        LONG)  # (batch_size, max_utt_len)
    bs_label = self.np2var(data_feed['bs'], FLOAT)  # belief-state features, (batch_size, bs_size)
    db_label = self.np2var(data_feed['db'], FLOAT)  # database-pointer features, (batch_size, db_size)
    batch_size = len(ctx_lens)

    # encode the most recent context utterance and concatenate it with the
    # belief-state and database features to form the conditioning vector
    utt_summary, _, enc_outs = self.model.utt_encoder(short_ctx_utts.unsqueeze(1))
    enc_last = torch.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

    # sample the discrete latent action from q(z|c)
    mode = GEN
    logits_qy, log_qy = self.model.c2z(enc_last)
    sample_y = self.model.gumbel_connector(logits_qy, hard=mode == GEN)
    log_py = self.model.log_uniform_y

    # pack attention context
    if self.model.config.dec_use_attn:
        z_embeddings = torch.t(self.model.z_embedding.weight).split(
            self.model.k_size, dim=0)
        attn_context = []
        temp_sample_y = sample_y.view(-1, self.model.config.y_size,
                                      self.model.config.k_size)
        for z_id in range(self.model.y_size):
            attn_context.append(
                torch.mm(temp_sample_y[:, z_id], z_embeddings[z_id]).unsqueeze(1))
        attn_context = torch.cat(attn_context, dim=1)
        dec_init_state = torch.sum(attn_context, dim=1).unsqueeze(0)
    else:
        dec_init_state = self.model.z_embedding(
            sample_y.view(1, -1,
                          self.model.config.y_size * self.model.config.k_size))
        attn_context = None

    # decode
    if self.model.config.dec_rnn_cell == 'lstm':
        dec_init_state = tuple([dec_init_state, dec_init_state])
    dec_outputs, dec_hidden_state, ret_dict = self.model.decoder(
        batch_size=batch_size,
        dec_inputs=None,  # (batch_size, response_size-1)
        dec_init_state=dec_init_state,  # tuple: (h, c)
        attn_context=attn_context,  # (batch_size, max_ctx_len, ctx_cell_size)
        mode=mode,
        gen_type='greedy',
        beam_size=self.model.config.beam_size)
    # ret_dict['sample_z'] = sample_y
    # ret_dict['log_qy'] = log_qy

    # convert the generated token ids back into a response string
    pred_labels = [t.cpu().data.numpy() for t in ret_dict[DecoderRNN.KEY_SEQUENCE]]
    pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)
    de_tknize = get_detokenize()
    for b_id in range(pred_labels.shape[0]):
        # only one val for pred_str now
        pred_str = get_sent(self.model.vocab, de_tknize, pred_labels, b_id)
    return pred_str
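# Usage sketch for model_predict (an assumption, not part of the original module):
# it is written as a method, so it must be bound to a predictor/agent object that
# provides self.model, self.np2var and self.extract_short_ctx. The data_feed layout
# below is inferred from how the keys are consumed above; the wrapper name and
# variable names are illustrative only.
#
#   data_feed = {
#       'contexts': ctx_ids,        # (batch_size, max_ctx_len, max_utt_len) token ids
#       'context_lens': ctx_lens,   # (batch_size, )
#       'bs': bs_vec,               # (batch_size, bs_size) belief-state features
#       'db': db_vec,               # (batch_size, db_size) database-pointer features
#   }
#   response_str = predictor.model_predict(data_feed)  # greedy-decoded system response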
def generate(model, data, config, evaluator, num_batch, dest_f=None):
    def write(msg):
        if msg is None or msg == '':
            return
        if dest_f is None:
            print(msg)
        else:
            dest_f.write(msg + '\n')

    model.eval()
    de_tknize = get_detokenize()
    data.epoch_init(config, shuffle=num_batch is not None, verbose=False)
    evaluator.initialize()
    logger.info('Generation: {} batches'.format(
        data.num_batch if num_batch is None else num_batch))

    batch_cnt = 0
    print_cnt = 0
    while True:
        batch_cnt += 1
        batch = data.next_batch()
        if batch is None or (num_batch is not None and data.ptr > num_batch):
            break

        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)

        # move from GPU to CPU
        labels = labels.cpu()
        pred_labels = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]]
        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)  # (batch_size, max_dec_len)
        true_labels = labels.data.numpy()  # (batch_size, output_seq_len)

        # get attention if possible
        if config.dec_use_attn:
            pred_attns = [t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_ATTN_SCORE]]
            pred_attns = np.array(pred_attns, dtype=float).squeeze(2).swapaxes(0, 1)  # (batch_size, max_dec_len, max_ctx_len)
        else:
            pred_attns = None

        # get context
        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
        ctx_len = batch.get('context_lens')  # (batch_size, )

        for b_id in range(pred_labels.shape[0]):
            # TODO attn
            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
            prev_ctx = ''
            if ctx is not None:
                ctx_str = []
                for t_id in range(ctx_len[b_id]):
                    temp_str = get_sent(model.vocab, de_tknize, ctx[:, t_id, :], b_id,
                                        stop_eos=False)
                    # print('temp_str = %s' % (temp_str, ))
                    # print('ctx[:, t_id, :] = %s' % (ctx[:, t_id, :], ))
                    ctx_str.append(temp_str)
                ctx_str = '|'.join(ctx_str)[-200:]
                prev_ctx = 'Source context: {}'.format(ctx_str)

            evaluator.add_example(true_str, pred_str)

            if num_batch is None or batch_cnt < 2:
                print_cnt += 1
                write('prev_ctx = %s' % (prev_ctx, ))
                write('True: {}'.format(true_str))
                write('Pred: {}'.format(pred_str))
                write('=' * 30)
                if num_batch is not None and print_cnt > 10:
                    break

    write(evaluator.get_report())
    write('Generation Done')
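# A minimal evaluation driver, sketched under assumptions (the file name and the
# surrounding setup are illustrative; only the call signatures come from the
# functions defined above):
#
#   with open('generation_log.txt', 'w') as dest_f:
#       # full test-set decoding, with the evaluator report appended at the end
#       generate(model, test_data, config, evaluator, num_batch=None, dest_f=dest_f)
#   # per-dialogue, per-turn predictions keyed by dialogue name and turn index
#   results = generate_with_name(model, test_data, config)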