Example #1
import numpy as np

# Environment and DeepQLearner are project-local modules (the EASDRL environment
# and its deep Q-network); they are assumed to be importable from this repository.


class Agent(object):
    """
    RL Agent for online GUI
    It's the initial version, and may have some bugs.
    Refer to the guiActiveLearning.py for the latest version.
    """
    def __init__(self, args, sess):
        self.env_act = Environment(args, 'act')
        self.net_act = DeepQLearner(args, 'act', 'channels_first')
        # self.net_act = DeepQLearner(args, sess, 'act') # for tensorflow
        self.env_arg = Environment(args, 'arg')
        self.net_arg = DeepQLearner(args, 'arg', 'channels_first')
        # self.net_arg = DeepQLearner(args, sess, 'arg') # for tensorflow
        self.num_words = args.num_words
        self.context_len = args.context_len


    def predict(self, text):
        # e.g. text = ['Cook the rice the day before.', 'Use leftover rice.']
        self.env_act.init_predict_act_text(text)
        # act_seq = []
        sents = []
        for i in range(len(self.env_act.current_text['sents'])):
            last_sent = self.env_act.current_text['sents'][i - 1] if i > 0 else []
            this_sent = self.env_act.current_text['sents'][i]
            sents.append({'last_sent': last_sent, 'this_sent': this_sent, 'acts': []})
        # ipdb.set_trace()
        for i in range(self.num_words):
            state_act = self.env_act.getState()
            qvalues_act = self.net_act.predict(state_act)
            action_act = np.argmax(qvalues_act[0])
            self.env_act.act_online(action_act, i)
            if action_act == 1:
                last_sent, this_sent = self.env_arg.init_predict_arg_text(i, self.env_act.current_text)
                for j in range(self.context_len):
                    state_arg = self.env_arg.getState()
                    qvalues_arg = self.net_arg.predict(state_arg)
                    action_arg = np.argmax(qvalues_arg[0])
                    self.env_arg.act_online(action_arg, j)
                    if self.env_arg.terminal_flag:
                        break
                # act_name = self.env_act.current_text['tokens'][i]
                # act_arg = [act_name]
                act_idx = i
                obj_idxs = []
                sent_words = self.env_arg.current_text['tokens']
                # scan at most context_len tokens of the argument sentence
                tmp_num = min(self.context_len, len(sent_words))
                for j in range(tmp_num):
                    # tokens whose state tag is 2 are the predicted objects of this action
                    if self.env_arg.state[j, -1] == 2:
                        #act_arg.append(sent_words[j])
                        # the final token index is recorded as -1, matching the
                        # sentinel appended below when no object is tagged at all
                        if j == len(sent_words) - 1:
                            j = -1
                        obj_idxs.append(j)
                if len(obj_idxs) == 0:
                    # act_arg.append(sent_words[-1])
                    obj_idxs.append(-1)
                # ipdb.set_trace()
                si, ai = self.env_act.current_text['word2sent'][i]
                ai += len(sents[si]['last_sent'])
                sents[si]['acts'].append({'act_idx': ai, 'obj_idxs': [obj_idxs, []],
                                            'act_type': 1, 'related_acts': []})
                # act_seq.append(act_arg)
            if self.env_act.terminal_flag:
                break
        return sents
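
For reference, the sents structure returned by predict() has the shape sketched
below. This is a hand-written illustration rather than real model output; only
the field names, the fixed act_type, and the -1 "no object" sentinel come from
the code above, while the tokens and indices are made up.

example_sents = [
    {
        'last_sent': [],
        'this_sent': ['Cook', 'the', 'rice', 'the', 'day', 'before', '.'],
        'acts': [
            {
                'act_idx': 0,           # position of the action word ("Cook")
                                        # within last_sent + this_sent
                'obj_idxs': [[2], []],  # tagged object words ("rice"); the
                                        # second list is always empty here
                'act_type': 1,          # this code only emits type-1 actions
                'related_acts': [],     # likewise always empty here
            },
        ],
    },
]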

Example #2

import random

import numpy as np

# Environment, DeepQLearner, load_pkl and text_classification are project-local
# helpers (environment, deep Q-network and data utilities), assumed importable
# from this repository.


class Agent(object):
    """
    RL Agent for online Active Learning
    """
    def __init__(self, args, sess):
        self.env_act = Environment(args, 'act')
        # self.net_act = DeepQLearner(args, 'act', 'channels_first')
        self.net_act = DeepQLearner(args, sess, 'act')  # for tensorflow
        self.env_arg = Environment(args, 'arg')
        # self.net_arg = DeepQLearner(args, 'arg', 'channels_first')
        self.net_arg = DeepQLearner(args, sess, 'arg')  # for tensorflow
        self.num_words = args.num_words
        self.context_len = args.context_len
        self.gamma = args.gamma
        self.uncertainty_mode = 'cml'  # or 'diff'

    def load_data(self):
        """
        Load all unlabeled texts.
        PS: the file 'home_and_garden_500_words_with_title.pkl' contains more than 15k 
            unlabeled texts from wikihow home and garden category.
        """
        print(
            'Loading texts from data/home_and_garden_500_words_with_title.pkl ...'
        )
        self.texts = load_pkl('data/home_and_garden_500_words_with_title.pkl')
        self.label2text = text_classification()
        self.history_texts = []
        self.sort_ind = 1 if self.uncertainty_mode == 'diff' else 0
        self.category = 0  # category of the currently chosen text
        self.max_category = len(self.label2text) - 1

    def choose_unlabeled_texts(self, num_texts, dialog=None):
        """
        Apply Active Learning. 
        Choose texts from each class and sort them by cumulative reward.
        """
        chosen_texts = []
        while len(chosen_texts) < num_texts:
            # text_ind = np.random.randint(len(self.texts))
            # text = self.texts[text_ind]
            text_ind = random.sample(self.label2text[self.category], 1)[0]
            if text_ind in self.history_texts:  # or len(text['title'].split()) < 2:
                continue

            # print('textID: {:<10}  category: {}'.format(text_ind, self.category))
            # traverse all categories, choose texts from each category
            self.category = self.category + 1 if self.category < self.max_category else 0
            # predict Q-values, compute cumulative reward
            text = self.texts[text_ind]
            sents, word2sent, R_t = self.predict(text['sent'])
            # deduced from R_t = r_t + gamma * R_{t+1}
            r_t = R_t[:-1] - self.gamma * R_t[1:]
            cml_rwd = sum(r_t) + self.gamma * R_t[-1]
            # difference between the predicted and the actual cumulative reward
            delta_r = abs(R_t[0] - cml_rwd)
            text['sents'] = sents
            text['reward'] = (cml_rwd, delta_r)
            text['r_t'] = r_t  # len(r_t) = len(words) - 1
            text['text_ind'] = text_ind
            text['word2sent'] = word2sent
            chosen_texts.append(text)
            if dialog:
                dialog.Update(
                    len(chosen_texts),
                    'Progress: %d/%d' % (len(chosen_texts), num_texts))
        # sort the texts by cumulative reward
        sorted_texts = sorted(chosen_texts,
                              key=lambda x: x['reward'][self.sort_ind])
        # for t in sorted_texts:
        #     print(t['text_ind'], t['reward'][self.sort_ind])
        # print('\n')
        return sorted_texts

    def predict(self, text):
        """
        Call EASDRL model to generate output actions for an input text
        e.g. text = ['Cook the rice the day before.', 'Use leftover rice.']
        """
        self.env_act.init_predict_act_text(text)
        # act_seq = []
        sents = []
        for i in range(len(self.env_act.current_text['sents'])):
            if i > 0:
                last_sent = self.env_act.current_text['sents'][i - 1]
                # last_pos = self.env_act.current_text['sent_pos'][i - 1]
            else:
                last_sent = []
                # last_pos = []
            this_sent = self.env_act.current_text['sents'][i]
            # this_pos = self.env_act.current_text['sent_pos'][i]
            sents.append({
                'last_sent': last_sent,
                'this_sent': this_sent,
                'acts': [],
                # 'last_pos': last_pos, 'this_pos': this_pos,
            })
        word2sent = self.env_act.current_text['word2sent']
        # ipdb.set_trace()
        R_t = []
        for i in range(self.num_words):
            state_act = self.env_act.getState()
            qvalues_act = self.net_act.predict(state_act)
            R_t.append(max(qvalues_act[0]))
            action_act = np.argmax(qvalues_act[0])
            self.env_act.act_online(action_act, i)
            if action_act == 1:
                last_sent, this_sent = self.env_arg.init_predict_arg_text(
                    i, self.env_act.current_text)
                for j in range(self.context_len):
                    state_arg = self.env_arg.getState()
                    qvalues_arg = self.net_arg.predict(state_arg)
                    action_arg = np.argmax(qvalues_arg[0])
                    self.env_arg.act_online(action_arg, j)
                    if self.env_arg.terminal_flag:
                        break
                # act_name = self.env_act.current_text['tokens'][i]
                # act_arg = [act_name]
                act_idx = i
                obj_idxs = []
                sent_words = self.env_arg.current_text['tokens']
                # scan at most context_len tokens of the argument sentence
                tmp_num = min(self.context_len, len(sent_words))
                for j in range(tmp_num):
                    # tokens whose state tag is 2 are the predicted objects of this action
                    if self.env_arg.state[j, -1] == 2:
                        #act_arg.append(sent_words[j])
                        # the final token index is recorded as -1, matching the
                        # sentinel appended below when no object is tagged at all
                        if j == len(sent_words) - 1:
                            j = -1
                        obj_idxs.append(j)
                if len(obj_idxs) == 0:
                    # act_arg.append(sent_words[-1])
                    obj_idxs.append(-1)
                # ipdb.set_trace()
                si, ai = self.env_act.current_text['word2sent'][i]
                ai += len(sents[si]['last_sent'])
                sents[si]['acts'].append({
                    'act_idx': ai,
                    'obj_idxs': [obj_idxs, []],
                    'act_type': 1,
                    'related_acts': []
                })
                # act_seq.append(act_arg)
            if self.env_act.terminal_flag:
                break
        # for k, v in act_seq.iteritems():
        #     print(k, v)
        # ipdb.set_trace()
        return sents, word2sent, np.array(R_t)
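
The text-ranking criterion used by choose_unlabeled_texts() can be reproduced in
isolation. The sketch below uses made-up Q-values; R_t stands for the per-step
max Q-values that predict() collects, and the 'cml'/'diff' sort modes mirror the
uncertainty_mode flag defined above.

import numpy as np

# Illustrative per-step max Q-values, as collected into R_t by predict().
R_t = np.array([4.8, 4.1, 3.5, 2.6, 1.9])
gamma = 0.9

# From R_t = r_t + gamma * R_{t+1}, recover the per-step rewards r_t.
r_t = R_t[:-1] - gamma * R_t[1:]
# Cumulative reward estimate and its gap to the first-step prediction R_t[0].
cml_rwd = sum(r_t) + gamma * R_t[-1]
delta_r = abs(R_t[0] - cml_rwd)

# Texts are ranked by reward[sort_ind]: index 0 ('cml' mode) sorts by the
# cumulative reward, index 1 ('diff' mode) sorts by the prediction gap.
texts = [{'text_ind': 0, 'reward': (cml_rwd, delta_r)}]
sort_ind = 0
sorted_texts = sorted(texts, key=lambda x: x['reward'][sort_ind])
print(sorted_texts)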