コード例 #1
0
    def _process_dialogue(self, data):
        """Tokenize raw dialogues into alternating user/system turn ``Pack``s.

        Every dialogue becomes a list of turns that:

        - starts with a ``[BOS, BOD, EOS]`` user turn;
        - continues with one user turn and one system turn per entry of
          ``raw_dlg['db']``;
        - ends with a ``[BOS, EOD, EOS]`` user turn.

        The two bookend turns carry all-zero ``bs``/``db`` vectors of length
        ``self.bs_size``/``self.db_size``.

        Args:
            data: mapping of dialogue key -> raw dialogue. The ``'usr'``,
                ``'sys'``, ``'db'`` and ``'bs'`` entries are indexed by the
                same turn position. The ``'goal'`` entry is passed to
                ``self._process_goal``.

        Returns:
            List of ``Pack(dlg=..., goal=..., key=...)``, one per dialogue.
        """
        new_dlgs = []
        all_sent_lens = []
        all_dlg_lens = []

        for key, raw_dlg in data.items():
            norm_dlg = [Pack(speaker=USR, utt=[BOS, BOD, EOS], bs=[0.0] * self.bs_size, db=[0.0] * self.db_size)]
            # 'db' defines the number of turns. The other per-turn lists are
            # indexed at the same position, so a shorter one raises IndexError.
            for t_id, turn_db in enumerate(raw_dlg['db']):
                usr_utt = [BOS] + self.tokenize(raw_dlg['usr'][t_id]) + [EOS]
                sys_utt = [BOS] + self.tokenize(raw_dlg['sys'][t_id]) + [EOS]
                turn_bs = raw_dlg['bs'][t_id]
                norm_dlg.append(Pack(speaker=USR, utt=usr_utt, db=turn_db, bs=turn_bs))
                norm_dlg.append(Pack(speaker=SYS, utt=sys_utt, db=turn_db, bs=turn_bs))
                all_sent_lens.extend([len(usr_utt), len(sys_utt)])
            # Closing user turn that marks the end of the dialogue.
            # NOTE: a commented-out alternative used to add this turn only when
            # self.config.to_learn == 'usr'; it is always added here.
            norm_dlg.append(Pack(speaker=USR, utt=[BOS, EOD, EOS], bs=[0.0] * self.bs_size, db=[0.0] * self.db_size))
            all_dlg_lens.append(len(raw_dlg['db']))
            processed_goal = self._process_goal(raw_dlg['goal'])
            new_dlgs.append(Pack(dlg=norm_dlg, goal=processed_goal, key=key))

        # np.max raises on an empty sequence (e.g. empty ``data``), so only
        # report statistics when there is something to summarize.
        if all_sent_lens:
            self.logger.info('Max utt len = %d, mean utt len = %.2f' % (
                np.max(all_sent_lens), float(np.mean(all_sent_lens))))
        if all_dlg_lens:
            self.logger.info('Max dlg len = %d, mean dlg len = %.2f' % (
                np.max(all_dlg_lens), float(np.mean(all_dlg_lens))))
        return new_dlgs
コード例 #2
0
    def forward(self,
                data_feed,
                mode,
                clf=False,
                gen_type='greedy',
                return_latent=False):
        """Encode the last context turn, run the policy and decode the response.

        The encoded utterance summary is concatenated with the ``bs`` and
        ``db`` vectors. ``self.policy`` maps the result to the decoder's
        initial hidden state.

        Args:
            data_feed: batch dict with ``'context_lens'``, ``'contexts'``,
                ``'outputs'``, ``'bs'`` and ``'db'`` entries.
            mode: decoding mode; ``GEN`` returns generation output.
            clf: unused.
            gen_type: decoding strategy forwarded to ``self.decoder``.
            return_latent: outside ``GEN`` mode, also return the decoder
                initial state as ``latent_action``.

        Returns:
            ``(ret_dict, labels)`` in ``GEN`` mode. Otherwise a ``Pack`` with
            ``nll``, plus ``latent_action`` when requested.
        """
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        batch_size = len(ctx_lens)
        last_ctx = self.extract_short_ctx(data_feed['contexts'], ctx_lens)
        short_ctx_utts = self.np2var(last_ctx, LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
        # bs/db are concatenated along dim 1 with the squeezed utterance
        # summary below, so they must share its rank. Presumably
        # (batch_size, feature_dim) -- confirm.
        bs_label = self.np2var(data_feed['bs'], FLOAT)
        db_label = self.np2var(data_feed['db'], FLOAT)

        utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

        # Teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        attn_context = enc_outs if self.config.dec_use_attn else None

        policy_input = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)
        dec_init_state = self.policy(policy_input).unsqueeze(0)
        if self.config.dec_rnn_cell == 'lstm':
            # An LSTM decoder needs (h, c); both start from the policy output.
            dec_init_state = (dec_init_state, dec_init_state)

        dec_outputs, _, ret_dict = self.decoder(batch_size=batch_size,
                                                dec_inputs=dec_inputs,
                                                dec_init_state=dec_init_state,
                                                attn_context=attn_context,
                                                mode=mode,
                                                gen_type=gen_type,
                                                beam_size=self.config.beam_size)
        if mode == GEN:
            return ret_dict, labels

        nll = self.nll(dec_outputs, labels)
        if return_latent:
            return Pack(nll=nll, latent_action=dec_init_state)
        return Pack(nll=nll)
コード例 #3
0
 def _to_id_corpus(self, name, data):
     """Convert every tokenized dialogue in ``data`` to vocabulary ids.

     Dialogues without turns are skipped. Each turn keeps its ``speaker``,
     ``db`` and ``bs`` fields. Its utterance goes through
     ``self._sent2id`` and the goal through ``self._goal2id``.

     Args:
         name: not used by this method.
         data: iterable of dialogues exposing ``.dlg``, ``.goal`` and ``.key``.

     Returns:
         List of ``Pack(dlg=..., goal=..., key=...)`` with id-encoded fields.
     """
     id_corpus = []
     for dialog in data:
         if len(dialog.dlg) == 0:
             continue
         id_turns = [Pack(utt=self._sent2id(turn.utt),
                          speaker=turn.speaker,
                          db=turn.db, bs=turn.bs)
                     for turn in dialog.dlg]
         id_corpus.append(Pack(dlg=id_turns,
                               goal=self._goal2id(dialog.goal),
                               key=dialog.key))
     return id_corpus
コード例 #4
0
 def _to_id_corpus(self, name, data):
     """Convert tokenized dialogues (with outcomes) to vocabulary ids.

     Dialogues without turns are skipped. Only each turn's utterance
     (via ``self._sent2id``) and speaker are kept. The goal and outcome go
     through ``self._goal2id`` and ``self._outcome2id``.

     Args:
         name: not used by this method.
         data: iterable of dialogues exposing ``.dlg``, ``.goal`` and ``.out``.

     Returns:
         List of ``Pack(dlg=..., goal=..., out=...)`` with id-encoded fields.
     """
     id_corpus = []
     for dialog in data:
         if len(dialog.dlg) == 0:
             continue
         id_turns = [Pack(utt=self._sent2id(turn.utt), speaker=turn.speaker)
                     for turn in dialog.dlg]
         id_corpus.append(Pack(dlg=id_turns,
                               goal=self._goal2id(dialog.goal),
                               out=self._outcome2id(dialog.out)))
     return id_corpus
コード例 #5
0
 def flatten_dialog(self, data, backward_size):
     """Split dialogues into (context, system response) examples.

     One example is built for every non-USR turn at position ``i >= 1``.
     Its context is the up-to-``backward_size`` turns immediately before
     it. All utterances are passed through
     ``self.pad_to(self.max_utt_len, ..., do_pad=False)``. The response and
     context turns are copies, so the turns in ``data`` are not modified.

     Args:
         data: iterable of dialogues exposing ``.dlg``, ``.goal`` and ``.key``.
         backward_size: maximum number of preceding turns in a context.

     Returns:
         A tuple ``(results, indexes, batch_indexes)``:

         - ``results``: list of ``Pack(context, response, goal, key)``.
         - ``indexes``: ``list(range(len(results)))``.
         - ``batch_indexes``: for each dialogue that produced at least one
           example, the list of its positions in ``results``.
     """
     results = []
     batch_indexes = []
     resp_set = set()
     for dlg in data:
         goal = dlg.goal
         key = dlg.key
         batch_index = []
         for i in range(1, len(dlg.dlg)):
             if dlg.dlg[i].speaker == USR:
                 continue
             s_idx = max(0, i - backward_size)
             response = dlg.dlg[i].copy()
             response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False)
             resp_set.add(json.dumps(response.utt))
             context = []
             for turn in dlg.dlg[s_idx:i]:
                 # Copy like the response does: padding in place would rewrite
                 # the caller's turns and re-pad turns shared by overlapping
                 # context windows.
                 ctx_turn = turn.copy()
                 ctx_turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False)
                 context.append(ctx_turn)
             batch_index.append(len(results))
             results.append(Pack(context=context, response=response, goal=goal, key=key))
         if batch_index:
             batch_indexes.append(batch_index)
     indexes = list(range(len(results)))
     print("Unique resp {}".format(len(resp_set)))
     return results, indexes, batch_indexes
コード例 #6
0
    def forward(self, data_feed, mode, clf=False, gen_type='greedy', use_py=None, return_latent=False):
        """Encode the context, sample a Gaussian latent ``z`` and decode the response.

        The squeezed utterance summary of the last context turn is
        concatenated with the ``bs`` and ``db`` vectors (``enc_last``).

        - With ``self.simple_posterior``, ``self.c2z(enc_last)`` gives the
          posterior (mu, logvar) and the prior parameters are ``self.zero``.
        - Otherwise ``self.c2z`` gives the prior, and ``self.xc2z`` builds the
          posterior from ``enc_last`` plus the encoded response.

        Args:
            data_feed: batch dict with ``'context_lens'``, ``'contexts'``,
                ``'outputs'``, ``'bs'`` and ``'db'``.
            mode: ``GEN`` for generation, anything else for training losses.
            clf: unused.
            gen_type: decoding strategy forwarded to ``self.decoder``.
            use_py: non-simple posterior only. If truthy, sample from the
                prior even outside ``GEN`` mode.
            return_latent: unused.

        Returns:
            In ``GEN`` mode ``(ret_dict, labels)``, with
            ``ret_dict['sample_z']`` set. Otherwise ``Pack(nll=..., pi_kl=...)``.
        """
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
        out_utts = self.np2var(data_feed['outputs'], LONG)  # (batch_size, max_out_len)
        # bs/db are concatenated on dim 1 with a squeezed summary below, so
        # they share its rank. Presumably (batch_size, feature_dim) -- confirm.
        bs_label = self.np2var(data_feed['bs'], FLOAT)
        db_label = self.np2var(data_feed['db'], FLOAT)
        batch_size = len(ctx_lens)

        utt_summary, _, _ = self.utt_encoder(short_ctx_utts.unsqueeze(1))

        # Teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

        # Latent variable: posterior q and prior p parameters, then a sample.
        if self.simple_posterior:
            q_mu, q_logvar = self.c2z(enc_last)
            sample_z = self.gauss_connector(q_mu, q_logvar)
            p_mu, p_logvar = self.zero, self.zero
        else:
            p_mu, p_logvar = self.c2z(enc_last)
            # encode response and use posterior to find q(z|x, c)
            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
            q_mu, q_logvar = self.xc2z(th.cat([enc_last, x_h.squeeze(1)], dim=1))

            # use prior at inference time, otherwise use posterior
            if mode == GEN or use_py:
                sample_z = self.gauss_connector(p_mu, p_logvar)
            else:
                sample_z = self.gauss_connector(q_mu, q_logvar)

        # create decoder initial state; this model decodes without attention
        dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
        attn_context = None

        if self.config.dec_rnn_cell == 'lstm':
            dec_init_state = tuple([dec_init_state, dec_init_state])

        dec_outputs, _, ret_dict = self.decoder(batch_size=batch_size,
                                                dec_inputs=dec_inputs,
                                                dec_init_state=dec_init_state,  # tuple: (h, c)
                                                attn_context=attn_context,
                                                mode=mode,
                                                gen_type=gen_type,
                                                beam_size=self.config.beam_size)
        if mode == GEN:
            ret_dict['sample_z'] = sample_z
            return ret_dict, labels

        # The NLL is computed once; it used to be recomputed and reassigned
        # with the identical value.
        result = Pack(nll=self.nll(dec_outputs, labels))
        result['pi_kl'] = self.gauss_kl(q_mu, q_logvar, p_mu, p_logvar)
        return result
コード例 #7
0
    def _prepare_batch(self, selected_index):
        """Build padded numpy arrays for the rows at ``selected_index``.

        Args:
            selected_index: indices into ``self.data``.

        Returns:
            Pack with these fields:

            - ``context_lens``: turns per row.
            - ``contexts``: rows x max turns x ``self.max_utt_len``.
            - ``context_confs``: all ones.
            - ``output_lens`` and ``outputs``: rows x longest response.
            - ``goals``: rows x goal length.

            Also sets ``self.goal_len``.

        Raises:
            ValueError: if the goals in the batch are not all of length 6.
        """
        rows = [self.data[idx] for idx in selected_index]
        # Size every array by the rows actually selected (not
        # self.batch_size). Otherwise a short final batch indexes past the
        # per-row lists, and a longer one is silently truncated.
        batch_size = len(rows)

        ctx_utts, ctx_lens = [], []
        out_utts, out_lens = [], []
        goals, goal_lens = [], []

        for row in rows:
            in_row, out_row, goal_row = row.context, row.response, row.goal

            # source context: one padded utterance per turn
            batch_ctx = [self.pad_to(self.max_utt_len, turn.utt, do_pad=True)
                         for turn in in_row]
            ctx_utts.append(batch_ctx)
            ctx_lens.append(len(batch_ctx))

            # target response
            out_utt = list(out_row.utt)
            out_utts.append(out_utt)
            out_lens.append(len(out_utt))

            # goal
            goals.append(goal_row)
            goal_lens.append(len(goal_row))

        vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
        max_ctx_len = np.max(vec_ctx_lens)
        vec_ctx_utts = np.zeros(
            (batch_size, max_ctx_len, self.max_utt_len), dtype=np.int32)
        # confs is used to add some hand-crafted features
        vec_ctx_confs = np.ones((batch_size, max_ctx_len), dtype=np.float32)
        vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
        max_out_len = np.max(vec_out_lens)
        vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)

        max_goal_len, min_goal_len = max(goal_lens), min(goal_lens)
        if max_goal_len != min_goal_len or max_goal_len != 6:
            raise ValueError(
                'FATAL ERROR! every goal must have length 6, got min {} max {}'
                .format(min_goal_len, max_goal_len))
        self.goal_len = max_goal_len
        vec_goals = np.zeros((batch_size, self.goal_len), dtype=np.int32)

        for b_id in range(batch_size):
            vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
            vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
            vec_goals[b_id, :] = goals[b_id]

        return Pack(context_lens=vec_ctx_lens,
                    contexts=vec_ctx_utts,
                    context_confs=vec_ctx_confs,
                    output_lens=vec_out_lens,
                    outputs=vec_out_utts,
                    goals=vec_goals)
コード例 #8
0
        def transform(token_list):
            """Split an EOS-delimited token stream into user and system turns.

            Each turn runs up to and including the next ``EOS`` token. It is
            assigned to a speaker by its first token (``USR`` or ``SYS``).
            Appends every turn length to the enclosing ``all_sent_lens`` and
            the total number of turns to ``all_dlg_lens``.

            Raises:
                ValueError: if a turn starts with neither ``USR`` nor ``SYS``.
                IndexError: if the stream does not end with ``EOS``.
            """
            usr_turns, sys_turns = [], []
            start = 0
            while start < len(token_list):
                end = start
                while token_list[end] != EOS:
                    end += 1
                turn = list(token_list[start:end + 1])
                start = end + 1
                all_sent_lens.append(len(turn))
                speaker = turn[0]
                if speaker == USR:
                    usr_turns.append(Pack(utt=turn, speaker=USR))
                elif speaker == SYS:
                    sys_turns.append(Pack(utt=turn, speaker=SYS))
                else:
                    raise ValueError('Invalid speaker')
            all_dlg_lens.append(len(usr_turns) + len(sys_turns))
            return usr_turns, sys_turns
コード例 #9
0
 def flatten_dialog(self, data, backward_size):
     """Split dialogues into (context, system response) examples.

     One example is built for every non-USR turn at position ``i >= 1``.
     Its context is the up-to-``backward_size`` turns immediately before
     it. All utterances are passed through
     ``self.pad_to(self.max_utt_len, ..., do_pad=False)``. The response and
     context turns are copies, so the turns in ``data`` are not modified.

     Args:
         data: iterable of dialogues exposing ``.dlg`` and ``.goal``.
         backward_size: maximum number of preceding turns in a context.

     Returns:
         List of ``Pack(context=..., response=..., goal=...)``.
     """
     results = []
     for dlg in data:
         goal = dlg.goal
         for i in range(1, len(dlg.dlg)):
             if dlg.dlg[i].speaker == USR:
                 continue
             s_idx = max(0, i - backward_size)
             response = dlg.dlg[i].copy()
             response['utt'] = self.pad_to(self.max_utt_len, response.utt, do_pad=False)
             context = []
             for turn in dlg.dlg[s_idx:i]:
                 # Copy like the response does: padding in place would rewrite
                 # the caller's turns and re-pad turns shared by overlapping
                 # context windows.
                 ctx_turn = turn.copy()
                 ctx_turn['utt'] = self.pad_to(self.max_utt_len, turn.utt, do_pad=False)
                 context.append(ctx_turn)
             results.append(Pack(context=context, response=response, goal=goal))
     return results
コード例 #10
0
def prepare_batch_gen(rows, config):
    """Build a generation-time batch holding only contexts and bs/db vectors.

    Args:
        rows: sequence of dict-like rows. Each has ``'context'`` (a list of
            turns with an ``'utt'`` token list) and ``'response'`` (with
            ``'bs'`` and ``'db'`` vectors).
        config: object providing ``max_utt_len``.

    Returns:
        Pack with ``context_lens``, ``contexts`` (zero-padded to the longest
        context), ``bs`` and ``db``.
    """
    ctx_utts, ctx_lens = [], []
    out_bs, out_db = [], []

    for row in rows:
        in_row, out_row = row['context'], row['response']

        # source context: one padded utterance per turn
        batch_ctx = [pad_to(config.max_utt_len, turn['utt'], do_pad=True)
                     for turn in in_row]
        ctx_utts.append(batch_ctx)
        ctx_lens.append(len(batch_ctx))

        out_bs.append(out_row['bs'])
        out_db.append(out_row['db'])

    batch_size = len(ctx_lens)
    vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
    max_ctx_len = np.max(vec_ctx_lens)
    vec_ctx_utts = np.zeros((batch_size, max_ctx_len, config.max_utt_len),
                            dtype=np.int32)
    vec_out_bs = np.array(out_bs)  # (batch_size, 94)
    vec_out_db = np.array(out_db)  # (batch_size, 30)

    for b_id in range(batch_size):
        vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]

    return Pack(
        context_lens=vec_ctx_lens,  # (batch_size, )
        # (batch_size, max_ctx_len, max_utt_len)
        contexts=vec_ctx_utts,
        bs=vec_out_bs,  # (batch_size, 94)
        db=vec_out_db  # (batch_size, 30)
    )
コード例 #11
0
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 cuda_device=DEFAULT_CUDA_DEVICE,
                 model_file=None):
        """Load the LaRL system policy and its vocabularies.

        Args:
            archive_file: path to a zip archive with ``larl_model/best-model``
                and ``larl_model/svdic.pkl``.
            cuda_device: not used by this constructor.
            model_file: fallback location passed to ``cached_path`` when
                ``archive_file`` is not an existing file.

        Raises:
            Exception: if ``archive_file`` does not exist and no
                ``model_file`` is given.
        """
        SysPolicy.__init__(self)

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for LaRL is specified!")
            archive_file = cached_path(model_file)

        temp_path = tempfile.mkdtemp()
        with zipfile.ZipFile(archive_file, 'r') as zip_ref:
            zip_ref.extractall(temp_path)

        self.prev_state = init_state()
        self.prev_active_domain = None

        domain_name = 'object_division'
        domain_info = domain.get_domain(domain_name)

        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
        train_data_path = os.path.join(data_path, 'norm-multi-woz', 'train_dials.json')
        if not os.path.exists(train_data_path):
            # First use: unpack the bundled corpus next to this module.
            zipped_file = os.path.join(data_path, 'norm-multi-woz.zip')
            with zipfile.ZipFile(zipped_file, 'r') as archive:
                archive.extractall(data_path)

        norm_multiwoz_path = os.path.join(data_path, 'norm-multi-woz')
        with open(os.path.join(norm_multiwoz_path, 'input_lang.index2word.json'), encoding='utf-8') as f:
            self.input_lang_index2word = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'input_lang.word2index.json'), encoding='utf-8') as f:
            self.input_lang_word2index = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'output_lang.index2word.json'), encoding='utf-8') as f:
            self.output_lang_index2word = json.load(f)
        with open(os.path.join(norm_multiwoz_path, 'output_lang.word2index.json'), encoding='utf-8') as f:
            self.output_lang_word2index = json.load(f)

        config = Pack(
            seed=10,
            train_path=train_data_path,
            max_vocab_size=1000,
            last_n_model=5,
            max_utt_len=50,
            max_dec_len=50,
            backward_size=2,
            batch_size=1,
            use_gpu=True,
            op='adam',
            init_lr=0.001,
            l2_norm=1e-05,
            momentum=0.0,
            grad_clip=5.0,
            dropout=0.5,
            max_epoch=100,
            embed_size=100,
            num_layers=1,
            utt_rnn_cell='gru',
            utt_cell_size=300,
            bi_utt_cell=True,
            enc_use_attn=True,
            dec_use_attn=True,
            dec_rnn_cell='lstm',
            dec_cell_size=300,
            dec_attn_mode='cat',
            y_size=10,
            k_size=20,
            beta=0.001,
            simple_posterior=True,
            contextual_posterior=True,
            use_mi=False,
            use_pr=True,
            use_diversity=False,
            # decoding / batching / checkpointing settings
            beam_size=20,
            fix_batch=True,
            fix_train_batch=False,
            avg_type='word',
            print_step=300,
            ckpt_step=1416,
            improve_threshold=0.996,
            patient_increase=2.0,
            save_model=True,
            early_stop=False,
            gen_type='greedy',
            preview_batch_num=None,
            k=domain_info.input_length(),
            init_range=0.1,
            pretrain_folder='2019-09-20-21-43-06-sl_cat',
            forward_only=False
        )

        config.use_gpu = config.use_gpu and torch.cuda.is_available()
        self.corpus = corpora_inference.NormMultiWozCorpus(config)
        self.model = SysPerfectBD2Cat(self.corpus, config)
        self.config = config
        model_path = os.path.join(temp_path, 'larl_model/best-model')
        if config.use_gpu:
            self.model.load_state_dict(torch.load(model_path))
            self.model.cuda()
        else:
            # Map every stored tensor to CPU memory.
            self.model.load_state_dict(torch.load(
                model_path, map_location=lambda storage, loc: storage))
        self.model.eval()
        # NOTE(review): pickle.load can execute arbitrary code, and this
        # archive may come from model_file via cached_path. Only load
        # archives from trusted sources.
        with open(os.path.join(temp_path, 'larl_model/svdic.pkl'), 'rb') as f:
            self.dic = pickle.load(f)
コード例 #12
0
    def _prepare_batch(self, selected_index):
        """Build padded numpy arrays plus bs/db, goals and keys for a batch.

        Args:
            selected_index: indices into ``self.data``.

        Returns:
            Pack with ``context_lens``, ``contexts``, ``output_lens``,
            ``outputs``, ``bs``, ``db``, ``goals_list`` (one float32 array per
            entry of ``self.domains``) and ``keys``. Also sets
            ``self.goal_lens``.

        Raises:
            ValueError: if, for some domain, the goal vectors in the batch
                do not all have the same length.
        """
        rows = [self.data[idx] for idx in selected_index]

        ctx_utts, ctx_lens = [], []
        out_utts, out_lens = [], []

        out_bs, out_db = [], []
        # goal_lens[i] collects the goal-vector lengths for self.domains[i]
        goals, goal_lens = [], [[] for _ in range(len(self.domains))]
        keys = []

        for row in rows:
            in_row, out_row, goal_row = row.context, row.response, row.goal

            keys.append(row.key)

            # source context: one padded utterance per turn
            batch_ctx = [self.pad_to(self.max_utt_len, turn.utt, do_pad=True)
                         for turn in in_row]
            ctx_utts.append(batch_ctx)
            ctx_lens.append(len(batch_ctx))

            # target response
            out_utt = list(out_row.utt)
            out_utts.append(out_utt)
            out_lens.append(len(out_utt))

            out_bs.append(out_row.bs)
            out_db.append(out_row.db)

            # goal
            goals.append(goal_row)
            for i, d in enumerate(self.domains):
                goal_lens[i].append(len(goal_row[d]))

        batch_size = len(ctx_lens)
        vec_ctx_lens = np.array(ctx_lens)  # (batch_size, ), number of turns
        max_ctx_len = np.max(vec_ctx_lens)
        vec_ctx_utts = np.zeros((batch_size, max_ctx_len, self.max_utt_len),
                                dtype=np.int32)
        vec_out_bs = np.array(out_bs)  # (batch_size, 94)
        vec_out_db = np.array(out_db)  # (batch_size, 30)
        vec_out_lens = np.array(out_lens)  # (batch_size, ), number of tokens
        max_out_len = np.max(vec_out_lens)
        vec_out_utts = np.zeros((batch_size, max_out_len), dtype=np.int32)

        max_goal_lens = [max(ls) for ls in goal_lens]
        min_goal_lens = [min(ls) for ls in goal_lens]
        if max_goal_lens != min_goal_lens:
            raise ValueError(
                'Fatal Error! goal lengths differ within the batch '
                '(per-domain max {} vs min {})'.format(max_goal_lens, min_goal_lens))
        self.goal_lens = max_goal_lens
        vec_goals_list = [
            np.zeros((batch_size, l), dtype=np.float32) for l in self.goal_lens
        ]

        for b_id in range(batch_size):
            vec_ctx_utts[b_id, :vec_ctx_lens[b_id], :] = ctx_utts[b_id]
            vec_out_utts[b_id, :vec_out_lens[b_id]] = out_utts[b_id]
            for i, d in enumerate(self.domains):
                vec_goals_list[i][b_id, :] = goals[b_id][d]

        return Pack(
            context_lens=vec_ctx_lens,  # (batch_size, )
            contexts=vec_ctx_utts,  # (batch_size, max_ctx_len, max_utt_len)
            output_lens=vec_out_lens,  # (batch_size, )
            outputs=vec_out_utts,  # (batch_size, max_out_len)
            bs=vec_out_bs,  # (batch_size, 94)
            db=vec_out_db,  # (batch_size, 30)
            # one (batch_size, goal_len) array per domain; lengths differ by domain
            goals_list=vec_goals_list,
            keys=keys)
コード例 #13
0
    def _process_dialogue(self, data):
        """Parse raw dialogue strings into interleaved turns, goals and outcomes.

        Each entry of ``data`` is split on whitespace. It must contain these
        sections:

        - ``<dialogue> ... </dialogue>``
        - ``<partner_input> ... </partner_input>``
        - ``<output> ... </output>``

        A ``BOD`` turn from the other speaker is prepended, so the dialogue
        starts with that bookend turn. User and system turns are then
        interleaved in speaking order.

        Returns:
            List of ``Pack(dlg=..., goal=..., out=...)``.

        Raises:
            ValueError: in any of these cases:

            - the dialogue does not start with a ``USR`` or ``SYS`` token;
            - the goal or outcome section is not exactly 6 tokens;
            - a turn starts with an unknown speaker;
            - a section tag is missing.
        """
        def transform(token_list):
            # Split the EOS-terminated token stream into user and system
            # turns, recording every turn length and the dialogue's turn count.
            usr_turns, sys_turns = [], []
            ptr = 0
            while ptr < len(token_list):
                turn_ptr = ptr
                turn_list = []
                while True:
                    cur_token = token_list[turn_ptr]
                    turn_list.append(cur_token)
                    turn_ptr += 1
                    if cur_token == EOS:
                        ptr = turn_ptr
                        break
                all_sent_lens.append(len(turn_list))
                if turn_list[0] == USR:
                    usr_turns.append(Pack(utt=turn_list, speaker=USR))
                elif turn_list[0] == SYS:
                    sys_turns.append(Pack(utt=turn_list, speaker=SYS))
                else:
                    raise ValueError('Invalid speaker')

            all_dlg_lens.append(len(usr_turns) + len(sys_turns))
            return usr_turns, sys_turns

        new_dlg = []
        all_sent_lens = []
        all_dlg_lens = []
        for raw_dlg in data:
            raw_words = raw_dlg.split()

            # process dialogue text; the final EOS terminates the last turn
            cur_dlg = []
            words = raw_words[raw_words.index('<dialogue>') +
                              1:raw_words.index('</dialogue>')]
            words += [EOS]
            if words[0] == SYS:
                words = [USR, BOD, EOS] + words
                usr_first = True
            elif words[0] == USR:
                words = [SYS, BOD, EOS] + words
                usr_first = False
            else:
                raise ValueError('FATAL ERROR!!! ({})'.format(words))
            usr_utts, sys_utts = transform(words)
            for usr_turn, sys_turn in zip(usr_utts, sys_utts):
                if usr_first:
                    cur_dlg.append(usr_turn)
                    cur_dlg.append(sys_turn)
                else:
                    cur_dlg.append(sys_turn)
                    cur_dlg.append(usr_turn)
            # zip stops at the shorter list; keep a single trailing turn
            if len(usr_utts) - len(sys_utts) == 1:
                cur_dlg.append(usr_utts[-1])
            elif len(sys_utts) - len(usr_utts) == 1:
                cur_dlg.append(sys_utts[-1])

            # process goal (6 digits)
            # FIXME (original note: "FATAL ERROR HERE !!!"): the goal is read
            # from <partner_input>, not <input> (alternative kept below).
            # Confirm which section is intended.
            cur_goal = raw_words[raw_words.index('<partner_input>') +
                                 1:raw_words.index('</partner_input>')]
            # cur_goal = raw_words[raw_words.index('<input>')+1: raw_words.index('</input>')]
            if len(cur_goal) != 6:
                raise ValueError('FATAL ERROR!!! ({})'.format(cur_goal))

            # process outcome (6 tokens)
            cur_out = raw_words[raw_words.index('<output>') +
                                1:raw_words.index('</output>')]
            if len(cur_out) != 6:
                raise ValueError('FATAL ERROR!!! ({})'.format(cur_out))

            new_dlg.append(Pack(dlg=cur_dlg, goal=cur_goal, out=cur_out))

        # np.max raises on an empty sequence (e.g. empty ``data``).
        if all_sent_lens:
            print('Max utt len = %d, mean utt len = %.2f' %
                  (np.max(all_sent_lens), float(np.mean(all_sent_lens))))
        if all_dlg_lens:
            print('Max dlg len = %d, mean dlg len = %.2f' %
                  (np.max(all_dlg_lens), float(np.mean(all_dlg_lens))))
        return new_dlg
コード例 #14
0
def main():
    """Fine-tune a pretrained supervised system model with offline RL.

    Steps:

    1. Create a timestamped ``rl-*`` folder under the pretrained model's
       folder.
    2. Write the RL configuration there.
    3. Load the supervised config, corpus and weights.
    4. Run ``OfflineTaskReinforce``.
    5. Save the resulting state dict to ``rl_config.rl_model_path``.
    """
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                               time.localtime(time.time()))
    print('[START]', start_time, '=' * 30)

    env = 'gpu'
    pretrained_folder = '2018-11-13-21-27-21-sys_sl_bdu2resp'
    pretrained_model_id = 61
    exp_dir = os.path.join('sys_config_log_model', pretrained_folder,
                           'rl-' + start_time)
    # create exp folder
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    # RL configuration
    rl_config = Pack(
        train_path='../data/norm-multi-woz/train_dials.json',
        valid_path='../data/norm-multi-woz/val_dials.json',
        test_path='../data/norm-multi-woz/test_dials.json',
        sv_config_path=os.path.join('sys_config_log_model', pretrained_folder,
                                    'config.json'),
        sv_model_path=os.path.join('sys_config_log_model', pretrained_folder,
                                   '{}-model'.format(pretrained_model_id)),
        rl_config_path=os.path.join(exp_dir, 'rl_config.json'),
        rl_model_path=os.path.join(exp_dir, 'rl_model'),
        ppl_best_model_path=os.path.join(exp_dir, 'ppl_best.model'),
        reward_best_model_path=os.path.join(exp_dir, 'reward_best.model'),
        record_path=exp_dir,
        record_freq=200,
        # TODO pay attention to main.py, cuz it is also controlled there
        sv_train_freq=1000,
        use_gpu=env == 'gpu',
        nepoch=10,
        nepisode=0,
        max_words=100,
        episode_repeat=1.0,
        temperature=1.0,
        rl_lr=0.01,
        momentum=0.0,
        nesterov=False,
        gamma=0.99,
        rl_clip=5.0,
        random_seed=10,
    )

    # save configuration
    with open(rl_config.rl_config_path, 'w') as f:
        json.dump(rl_config, f, indent=4)

    # set random seed
    set_seed(rl_config.random_seed)

    # load previous supervised learning configuration and corpus
    with open(rl_config.sv_config_path) as f:
        sv_config = Pack(json.load(f))

    sv_config['use_gpu'] = rl_config.use_gpu
    corpus = NormMultiWozCorpus(sv_config)

    # TARGET AGENT
    sys_model = SysPerfectBD2Word(corpus, sv_config)
    if sv_config.use_gpu:
        sys_model.cuda()
    # Weights are mapped to CPU storage first; load_state_dict copies them
    # into the model's (possibly CUDA) parameters.
    sys_model.load_state_dict(
        th.load(rl_config.sv_model_path,
                map_location=lambda storage, location: storage))
    sys_model.eval()
    # Named sys_agent rather than ``sys`` so the stdlib module is not shadowed.
    sys_agent = OfflineRlAgent(sys_model,
                               corpus,
                               rl_config,
                               name='System',
                               tune_pi_only=False)

    # start RL
    reinforce = OfflineTaskReinforce(sys_agent, corpus, sv_config, sys_model,
                                     rl_config, task_generate)
    reinforce.run()

    # save sys model
    th.save(sys_model.state_dict(), rl_config.rl_model_path)

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)
コード例 #15
0
    def forward(self,
                data_feed,
                mode,
                clf=False,
                gen_type='greedy',
                use_py=None,
                return_latent=False):
        """Encode the context, sample a categorical latent ``y`` and decode.

        The squeezed utterance summary of the last context turn is
        concatenated with the ``bs`` and ``db`` vectors (``enc_last``). The
        latent is ``self.config.y_size`` groups of ``self.config.k_size``
        classes, sampled through ``self.gumbel_connector``.

        - With ``self.simple_posterior``, ``self.c2z`` gives q(y|c) and the
          prior is ``self.log_uniform_y``.
        - Otherwise ``self.c2z`` gives the prior, and ``self.xc2z`` builds the
          posterior from the encoded response. The posterior also uses the
          context when ``self.contextual_posterior`` is set.

        Args:
            data_feed: batch dict with ``'context_lens'``, ``'contexts'``,
                ``'outputs'``, ``'bs'`` and ``'db'``.
            mode: ``GEN`` for generation, anything else for training losses.
            clf: unused.
            gen_type: decoding strategy forwarded to ``self.decoder``.
            use_py: non-simple posterior only. If exactly ``True``, sample
                from the prior even outside ``GEN`` mode.
            return_latent: unused.

        Returns:
            In ``GEN`` mode ``(ret_dict, labels)``, with ``'sample_z'`` and
            ``'log_qy'`` added to ``ret_dict``. Otherwise a ``Pack`` with
            ``nll``, ``pi_kl``, ``diversity``, ``b_pr`` and ``mi``.
        """
        ctx_lens = data_feed['context_lens']  # (batch_size, )
        short_ctx_utts = self.np2var(
            self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
        out_utts = self.np2var(data_feed['outputs'],
                               LONG)  # (batch_size, max_out_len)
        # bs/db are concatenated on dim 1 with a squeezed summary below, so
        # they share its rank. Presumably (batch_size, feature_dim) -- confirm.
        bs_label = self.np2var(data_feed['bs'], FLOAT)
        db_label = self.np2var(data_feed['db'], FLOAT)
        batch_size = len(ctx_lens)

        utt_summary, _, _ = self.utt_encoder(short_ctx_utts.unsqueeze(1))

        # Teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
        dec_inputs = out_utts[:, :-1]
        labels = out_utts[:, 1:].contiguous()

        enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

        # Latent variable: posterior/prior log-probabilities, then a sample.
        if self.simple_posterior:
            logits_qy, log_qy = self.c2z(enc_last)
            sample_y = self.gumbel_connector(logits_qy, hard=mode == GEN)
            log_py = self.log_uniform_y
        else:
            logits_py, log_py = self.c2z(enc_last)
            # encode response and use posterior to find q(z|x, c)
            x_h, _, _ = self.utt_encoder(out_utts.unsqueeze(1))
            if self.contextual_posterior:
                logits_qy, log_qy = self.xc2z(
                    th.cat([enc_last, x_h.squeeze(1)], dim=1))
            else:
                logits_qy, log_qy = self.xc2z(x_h.squeeze(1))

            # use prior at inference time, otherwise use posterior
            if mode == GEN or use_py is True:
                sample_y = self.gumbel_connector(logits_py, hard=False)
            else:
                sample_y = self.gumbel_connector(logits_qy, hard=True)

        # Decoder initial state (and attention context, if enabled) from the
        # latent sample.
        if self.config.dec_use_attn:
            # One embedding per latent group; the initial state is their sum.
            z_embeddings = th.t(self.z_embedding.weight).split(self.k_size,
                                                               dim=0)
            attn_context = []
            temp_sample_y = sample_y.view(-1, self.config.y_size,
                                          self.config.k_size)
            for z_id in range(self.y_size):
                attn_context.append(
                    th.mm(temp_sample_y[:, z_id],
                          z_embeddings[z_id]).unsqueeze(1))
            attn_context = th.cat(attn_context, dim=1)
            dec_init_state = th.sum(attn_context, dim=1).unsqueeze(0)
        else:
            dec_init_state = self.z_embedding(
                sample_y.view(1, -1, self.config.y_size * self.config.k_size))
            attn_context = None

        if self.config.dec_rnn_cell == 'lstm':
            dec_init_state = tuple([dec_init_state, dec_init_state])

        dec_outputs, _, ret_dict = self.decoder(
            batch_size=batch_size,
            dec_inputs=dec_inputs,
            dec_init_state=dec_init_state,  # tuple: (h, c) for LSTM
            attn_context=attn_context,
            mode=mode,
            gen_type=gen_type,
            beam_size=self.config.beam_size)
        if mode == GEN:
            ret_dict['sample_z'] = sample_y
            ret_dict['log_qy'] = log_qy
            return ret_dict, labels

        # The NLL is computed once; it used to be recomputed and reassigned
        # with the identical value.
        result = Pack(nll=self.nll(dec_outputs, labels))
        # regularization qy to be uniform: batch-averaged q(y), then its log
        qy_probs = th.exp(
            log_qy.view(-1, self.config.y_size, self.config.k_size))
        avg_log_qy = th.log(th.mean(qy_probs, dim=0) + 1e-15)
        b_pr = self.cat_kl_loss(avg_log_qy,
                                self.log_uniform_y,
                                batch_size,
                                unit_average=True)
        mi = self.entropy_loss(avg_log_qy,
                               unit_average=True) - self.entropy_loss(
                                   log_qy, unit_average=True)
        pi_kl = self.cat_kl_loss(log_qy,
                                 log_py,
                                 batch_size,
                                 unit_average=True)
        q_y = th.exp(log_qy).view(-1, self.config.y_size,
                                  self.config.k_size)
        # Penalize overlap between latent groups: (q q^T - I)^2.
        p = th.pow(th.bmm(q_y, th.transpose(q_y, 1, 2)) - self.eye, 2)

        result['pi_kl'] = pi_kl
        result['diversity'] = th.mean(p)
        result['b_pr'] = b_pr
        result['mi'] = mi
        return result
コード例 #16
0
# Experiment configuration: a Pack of data paths, vocabulary/length limits,
# optimisation settings, encoder/decoder architecture options and
# training-loop / checkpointing options.
# NOTE(review): `Pack` and `domain_info` must already be in scope (they are
# defined or imported elsewhere in the original file).
config = Pack(
    random_seed=10,
    train_path='../data/norm-multi-woz/train_dials.json',
    valid_path='../data/norm-multi-woz/val_dials.json',
    test_path='../data/norm-multi-woz/test_dials.json',
    max_vocab_size=1000,
    max_utt_len=50,
    max_dec_len=50,
    last_n_model=5,
    backward_size=2,
    batch_size=32,
    use_gpu=True,
    op='adam',
    init_lr=0.001,
    l2_norm=1e-05,
    momentum=0.0,
    grad_clip=5.0,
    dropout=0.5,
    max_epoch=100,
    embed_size=100,
    num_layers=1,
    utt_rnn_cell='gru',
    utt_cell_size=300,
    bi_utt_cell=True,
    enc_use_attn=True,
    dec_use_attn=False,
    dec_rnn_cell='lstm',
    # must be same as ctx_cell_size due to the passed initial state
    dec_cell_size=300,
    # decoder attention mode; presumably only consulted when dec_use_attn is
    # True (it is False above) -- confirm
    dec_attn_mode='cat',
    # decoding, batching, logging and checkpoint / early-stopping settings
    beam_size=20,
    fix_batch=True,
    fix_train_batch=False,
    avg_type='word',
    print_step=500,
    ckpt_step=1771,
    improve_threshold=0.996,
    patient_increase=2.0,
    save_model=True,
    early_stop=False,
    gen_type='greedy',
    preview_batch_num=None,
    # presumably the input feature length reported by domain_info -- confirm
    k=domain_info.input_length(),
    init_range=0.1,
    pretrain_folder='2018-11-13-21-27-21-sys_sl_bdu2resp',
    forward_only=False)
コード例 #17
0
 # Experiment configuration: a Pack of data path, vocabulary/length limits,
 # optimisation settings, encoder/decoder architecture options, latent-variable
 # options and training-loop / checkpointing options.
 # NOTE(review): `Pack`, `train_data_path` and `domain_info` must already be in
 # scope (defined or imported elsewhere in the original file).
 config = Pack(
     # NOTE(review): the other config in this file uses the key `random_seed`;
     # confirm which key the consumer of this config actually reads.
     seed=10,
     train_path=train_data_path,
     max_vocab_size=1000,
     last_n_model=5,
     max_utt_len=50,
     max_dec_len=50,
     backward_size=2,
     batch_size=1,
     use_gpu=True,
     op='adam',
     init_lr=0.001,
     l2_norm=1e-05,
     momentum=0.0,
     grad_clip=5.0,
     dropout=0.5,
     max_epoch=100,
     embed_size=100,
     num_layers=1,
     utt_rnn_cell='gru',
     utt_cell_size=300,
     bi_utt_cell=True,
     enc_use_attn=True,
     dec_use_attn=True,
     dec_rnn_cell='lstm',
     dec_cell_size=300,
     dec_attn_mode='cat',
     # latent-variable settings; presumably y_size latent categorical
     # variables with k_size classes each -- confirm against the model
     y_size=10,
     k_size=20,
     beta=0.001,
     simple_posterior=True,
     contextual_posterior=True,
     use_mi=False,
     use_pr=True,
     use_diversity=False,
     # decoding, batching, logging and checkpoint / early-stopping settings
     beam_size=20,
     fix_batch=True,
     fix_train_batch=False,
     avg_type='word',
     print_step=300,
     ckpt_step=1416,
     improve_threshold=0.996,
     patient_increase=2.0,
     save_model=True,
     early_stop=False,
     gen_type='greedy',
     preview_batch_num=None,
     # presumably the input feature length reported by domain_info -- confirm
     k=domain_info.input_length(),
     init_range=0.1,
     pretrain_folder='2019-09-20-21-43-06-sl_cat',
     forward_only=False
 )
コード例 #18
0
    def predict_response(self, state):
        """Generate a system response for the current dialogue state.

        Flattens ``state['history']`` one level and keeps the last
        ``self.config.backward_size`` entries as context. Each context entry
        is delexicalised against the current belief state, with digit runs
        replaced by ``[value_count]``. The method then attaches belief-summary
        and DB/booking pointer features, runs ``self.model_predict`` and fills
        the predicted template via ``self.populate_template``.

        Args:
            state: dialogue state dict. This method reads ``state['history']``
                (a list of sequences, flattened one level) and
                ``state['belief_state']``. Neither is mutated; deep copies are
                used.

        Returns:
            tuple: ``(response, active_domain)``. ``response`` is whatever
            ``self.populate_template`` returns. ``active_domain`` is the value
            returned by ``self.get_active_domain``.

        Raises:
            IndexError: if the flattened history is empty, because the
                response features are taken from the last context entry.
        """
        # Flatten the nested history one level, then keep the trailing window.
        history = [utt for turn in state['history'] for utt in turn]
        e_idx = len(history)
        s_idx = max(0, e_idx - self.config.backward_size)
        context = history[s_idx:e_idx]

        # A single history entry means a fresh dialogue: reset the remembered
        # previous state before diffing belief states below.
        if len(state['history']) == 1:
            self.prev_state = init_state()

        prev_bstate = deepcopy(self.prev_state['belief_state'])
        state_history = state['history']
        bstate = deepcopy(state['belief_state'])

        active_domain = self.get_active_domain(
            self.prev_active_domain, prev_bstate, bstate)
        domain_mark_not_mentioned(bstate, active_domain)

        # Raw string: '\d' in a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on 3.12+). Compiled once, not per
        # context entry.
        digit_pat = re.compile(r'\d+')

        packed_context = []
        top_results, num_results = None, None
        for usr in context:
            # Normalise whitespace, then delexicalise slot values.
            usr = delexicalize.delexicalise(' '.join(usr.split()), self.dic)
            # Parse reference numbers GIVEN the belief state.
            usr = delexicaliseReferenceNumber(usr, bstate)
            # Any remaining run of digits becomes a count placeholder.
            usr = digit_pat.sub('[value_count]', usr)
            # Database pointer first, then the booking pointer appended to it.
            pointer_vector, top_results, num_results = addDBPointer(bstate)
            pointer_vector = addBookingPointer(bstate, pointer_vector)

            usr_utt = [BOS] + usr.split() + [EOS]
            packed_context.append({
                'bs': get_summary_bstate(bstate),
                'db': pointer_vector,
                'utt': self.corpus._sent2id(usr_utt),
            })

        # The response shares the belief/DB features of the latest context
        # entry (IndexError here if the history was empty).
        last = packed_context[-1]
        response_feats = {'bs': last['bs'], 'db': last['db']}
        results = [Pack(context=packed_context, response=response_feats)]

        data_feed = prepare_batch_gen(results, self.config)

        outputs = self.model_predict(data_feed)

        # Keep only the DB results of the active domain (if any).
        if active_domain is not None and active_domain in num_results:
            num_results = num_results[active_domain]
        else:
            num_results = 0

        if active_domain is not None and active_domain in top_results:
            top_results = {active_domain: top_results[active_domain]}
        else:
            top_results = {}

        state_with_history = deepcopy(bstate)
        state_with_history['history'] = deepcopy(state_history)
        response = self.populate_template(
            outputs, top_results, num_results, state_with_history)

        # Debug trace of the turn.
        # NOTE(review): in this copy the user-turn print was corrupted into
        # invalid syntax; it is reconstructed here as printing the latest raw
        # context entry -- confirm against upstream.
        import pprint
        pprint.pprint("============")
        pprint.pprint('usr:')
        pprint.pprint(context[-1])
        pprint.pprint('agent:')
        pprint.pprint(response)
        pprint.pprint("============")

        return response, active_domain