Example #1
def generate_with_name(model, data, config):
    # GEN, DecoderRNN, logger, np, get_detokenize, and get_sent are
    # module-level names imported elsewhere in the project
    model.eval()
    de_tknize = get_detokenize()
    data.epoch_init(config, shuffle=False, verbose=False)
    logger.info('Generation With Name: {} batches.'.format(data.num_batch))

    from collections import defaultdict
    # res maps dialog name -> turn index -> {'pred': ..., 'true': ...}
    res = defaultdict(dict)
    while True:
        batch = data.next_batch()
        if batch is None:
            break
        keys, outputs, labels = model(batch,
                                      mode=GEN,
                                      gen_type=config.gen_type)

        # KEY_SEQUENCE holds one (batch_size, 1) id tensor per decoding step
        pred_labels = [
            t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]
        ]
        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(
            0, 1)  # (batch_size, max_dec_len)
        true_labels = labels.cpu().data.numpy()  # (batch_size, output_seq_len)

        for b_id in range(pred_labels.shape[0]):
            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
            dlg_name, dlg_turn = keys[b_id]  # (dialog name, turn index)
            res[dlg_name][dlg_turn] = {'pred': pred_str, 'true': true_str}

    return res
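
For reference, generate_with_name returns a nested dict keyed by dialog name and then turn index. A minimal usage sketch, assuming model, data, and config are built elsewhere in the project (the output file name is illustrative):

import json

res = generate_with_name(model, data, config)
# res[dlg_name][dlg_turn] == {'pred': '...', 'true': '...'}
with open('generation_with_name.json', 'w') as f:
    json.dump(res, f, ensure_ascii=False, indent=2)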
Example #2
    def model_predict(self, data_feed):
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        short_ctx_utts = self.np2var(
            self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
        # (batch_size, max_utt_len): extract_short_ctx keeps the most
        # recent utterance of each context
        bs_label = self.np2var(data_feed['bs'], FLOAT)
        # (batch_size, belief-state feature dim)
        db_label = self.np2var(data_feed['db'], FLOAT)
        # (batch_size, DB-result feature dim)
        batch_size = len(ctx_lens)

        utt_summary, _, enc_outs = self.model.utt_encoder(
            short_ctx_utts.unsqueeze(1))

        # condition the latent on [belief state; DB features; utterance summary]
        enc_last = torch.cat(
            [bs_label, db_label, utt_summary.squeeze(1)], dim=1)

        mode = GEN

        logits_qy, log_qy = self.model.c2z(enc_last)
        # hard (one-hot) Gumbel-Softmax samples at generation time
        sample_y = self.model.gumbel_connector(logits_qy, hard=mode == GEN)
        log_py = self.model.log_uniform_y

        # pack attention context
        if self.model.config.dec_use_attn:
            # split the transposed embedding table into y_size chunks of
            # k_size rows, one chunk per discrete latent variable
            z_embeddings = torch.t(self.model.z_embedding.weight).split(
                self.model.k_size, dim=0)
            attn_context = []
            temp_sample_y = sample_y.view(-1, self.model.config.y_size,
                                          self.model.config.k_size)
            for z_id in range(self.model.y_size):
                # map each latent's sample to its embedding
                attn_context.append(
                    torch.mm(temp_sample_y[:, z_id],
                             z_embeddings[z_id]).unsqueeze(1))
            attn_context = torch.cat(attn_context, dim=1)
            dec_init_state = torch.sum(attn_context, dim=1).unsqueeze(0)
        else:
            dec_init_state = self.model.z_embedding(
                sample_y.view(
                    1, -1,
                    self.model.config.y_size * self.model.config.k_size))
            attn_context = None

        # decode
        if self.model.config.dec_rnn_cell == 'lstm':
            # an LSTM expects (h, c); reuse the same tensor for both
            dec_init_state = (dec_init_state, dec_init_state)

        dec_outputs, dec_hidden_state, ret_dict = self.model.decoder(
            batch_size=batch_size,
            # dec_inputs would be (batch_size, response_size-1) under
            # teacher forcing; None here because we decode greedily
            dec_inputs=None,
            # tuple (h, c) for an LSTM decoder
            dec_init_state=dec_init_state,
            # packed latent embeddings, or None
            attn_context=attn_context,
            mode=mode,
            gen_type='greedy',
            beam_size=self.model.config.beam_size)

        # ret_dict['sample_z'] = sample_y
        # ret_dict['log_qy'] = log_qy

        pred_labels = [
            t.cpu().data.numpy() for t in ret_dict[DecoderRNN.KEY_SEQUENCE]
        ]
        pred_labels = np.array(pred_labels,
                               dtype=int).squeeze(-1).swapaxes(0, 1)
        # (batch_size, max_dec_len)
        de_tknize = get_detokenize()
        for b_id in range(pred_labels.shape[0]):
            pred_str = get_sent(self.model.vocab, de_tknize, pred_labels, b_id)
            # return only the first decoded response; this method is meant
            # for single-example prediction
            return pred_str
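
The gumbel_connector call above draws the discrete latent with the Gumbel-Softmax trick: hard one-hot samples at generation time, soft differentiable samples during training. Below is a minimal sketch of such a connector built on PyTorch's F.gumbel_softmax; the class name and the explicit k_size argument are assumptions for illustration, not the project's actual implementation:

import torch.nn as nn
import torch.nn.functional as F

class GumbelConnectorSketch(nn.Module):
    def __init__(self, temperature=1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, logits, k_size, hard=False):
        # treat the flat (batch_size, y_size * k_size) logits as y_size
        # independent categoricals of dimension k_size, sample each, and
        # flatten back; hard=True gives one-hot forward passes with
        # straight-through gradients
        batch_size = logits.size(0)
        flat = logits.view(-1, k_size)
        sample = F.gumbel_softmax(flat, tau=self.temperature, hard=hard)
        return sample.view(batch_size, -1)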
Example #3
def generate(model, data, config, evaluator, num_batch, dest_f=None):
    def write(msg):
        if msg is None or msg == '':
            return
        if dest_f is None:
            print(msg)
        else:
            dest_f.write(msg + '\n')

    model.eval()
    de_tknize = get_detokenize()
    data.epoch_init(config, shuffle=num_batch is not None, verbose=False)
    evaluator.initialize()
    logger.info('Generation: {} batches'.format(
        data.num_batch if num_batch is None else num_batch))
    batch_cnt = 0
    print_cnt = 0
    while True:
        batch_cnt += 1
        batch = data.next_batch()
        if batch is None or (num_batch is not None and data.ptr > num_batch):
            break
        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)

        # move from GPU to CPU
        labels = labels.cpu()
        pred_labels = [
            t.cpu().data.numpy() for t in outputs[DecoderRNN.KEY_SEQUENCE]
        ]
        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(
            0, 1)  # (batch_size, max_dec_len)
        true_labels = labels.data.numpy()  # (batch_size, output_seq_len)

        # get attention if possible
        if config.dec_use_attn:
            pred_attns = [
                t.cpu().data.numpy()
                for t in outputs[DecoderRNN.KEY_ATTN_SCORE]
            ]
            pred_attns = np.array(pred_attns, dtype=float).squeeze(2).swapaxes(
                0, 1)  # (batch_size, max_dec_len, max_ctx_len)
        else:
            pred_attns = None
        # get context
        ctx = batch.get('contexts')  # (batch_size, max_ctx_len, max_utt_len)
        ctx_len = batch.get('context_lens')  # (batch_size, )

        for b_id in range(pred_labels.shape[0]):
            # TODO attn
            pred_str = get_sent(model.vocab, de_tknize, pred_labels, b_id)
            true_str = get_sent(model.vocab, de_tknize, true_labels, b_id)
            prev_ctx = ''
            if ctx is not None:
                ctx_str = []
                for t_id in range(ctx_len[b_id]):
                    temp_str = get_sent(model.vocab,
                                        de_tknize,
                                        ctx[:, t_id, :],
                                        b_id,
                                        stop_eos=False)
                    ctx_str.append(temp_str)
                # keep only the last 200 characters of the joined context
                ctx_str = '|'.join(ctx_str)[-200:]
                prev_ctx = 'Source context: {}'.format(ctx_str)

            evaluator.add_example(true_str, pred_str)

            if num_batch is None or batch_cnt < 2:
                print_cnt += 1
                write('prev_ctx = %s' % (prev_ctx, ))
                write('True: {}'.format(true_str, ))
                write('Pred: {}'.format(pred_str, ))
                write('=' * 30)
                if num_batch is not None and print_cnt > 10:
                    break

    write(evaluator.get_report())
    write('Generation Done')
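
All three examples funnel decoded id matrices through get_sent. A minimal sketch of what such a helper might look like; the EOS/PAD symbols and the exact signature are assumptions, and the project's real helper may differ:

def get_sent_sketch(vocab, de_tknize, token_ids, b_id, stop_eos=True,
                    eos='</s>', pad='<pad>'):
    # vocab maps ids to token strings; de_tknize joins a token list back
    # into a string; token_ids is a (batch_size, seq_len) int array
    tokens = []
    for t_id in range(token_ids.shape[1]):
        token = vocab[token_ids[b_id, t_id]]
        if stop_eos and token == eos:
            break
        if token != pad:
            tokens.append(token)
    return de_tknize(tokens)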