import random

import numpy as np

# Environment, DeepQLearner, load_pkl and text_classification are assumed to
# be provided by the surrounding project; their import paths are not shown in
# the original source.


class Agent(object):
    """
    RL Agent for the online GUI.
    This is the initial version and may contain bugs; refer to
    guiActiveLearning.py for the latest version.
    """

    def __init__(self, args, sess):
        self.env_act = Environment(args, 'act')
        self.net_act = DeepQLearner(args, 'act', 'channels_first')
        # self.net_act = DeepQLearner(args, sess, 'act')  # for tensorflow
        self.env_arg = Environment(args, 'arg')
        self.net_arg = DeepQLearner(args, 'arg', 'channels_first')
        # self.net_arg = DeepQLearner(args, sess, 'arg')  # for tensorflow
        self.num_words = args.num_words
        self.context_len = args.context_len

    def predict(self, text):
        # e.g. text = ['Cook the rice the day before.', 'Use leftover rice.']
        self.env_act.init_predict_act_text(text)
        sents = []
        for i in range(len(self.env_act.current_text['sents'])):
            last_sent = self.env_act.current_text['sents'][i - 1] if i > 0 else []
            this_sent = self.env_act.current_text['sents'][i]
            sents.append({'last_sent': last_sent, 'this_sent': this_sent, 'acts': []})
        for i in range(self.num_words):
            # select the next word-level action from the act network
            state_act = self.env_act.getState()
            qvalues_act = self.net_act.predict(state_act)
            action_act = np.argmax(qvalues_act[0])
            self.env_act.act_online(action_act, i)
            if action_act == 1:  # word i was tagged as an action name
                last_sent, this_sent = self.env_arg.init_predict_arg_text(
                    i, self.env_act.current_text)
                # let the arg network tag the object words for this action
                for j in range(self.context_len):
                    state_arg = self.env_arg.getState()
                    qvalues_arg = self.net_arg.predict(state_arg)
                    action_arg = np.argmax(qvalues_arg[0])
                    self.env_arg.act_online(action_arg, j)
                    if self.env_arg.terminal_flag:
                        break
                obj_idxs = []
                sent_words = self.env_arg.current_text['tokens']
                tmp_num = min(self.context_len, len(sent_words))
                for j in range(tmp_num):
                    if self.env_arg.state[j, -1] == 2:  # tag 2 marks an object word
                        if j == len(sent_words) - 1:
                            j = -1  # the final token stands for "no real object"
                        obj_idxs.append(j)
                if len(obj_idxs) == 0:
                    obj_idxs.append(-1)
                # convert the global word index into a per-sentence act index
                si, ai = self.env_act.current_text['word2sent'][i]
                ai += len(sents[si]['last_sent'])
                sents[si]['acts'].append({'act_idx': ai,
                                          'obj_idxs': [obj_idxs, []],
                                          'act_type': 1,
                                          'related_acts': []})
            if self.env_act.terminal_flag:
                break
        return sents
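
# Illustrative usage sketch (not part of the original source): how the initial
# Agent above might be driven end to end. The construction of `args` and
# `sess` is assumed to come from the project's own config and session setup.
# Note that a second Agent class (apparently from guiActiveLearning.py)
# follows below; this sketch refers to the class defined directly above.
def _demo_online_predict(args, sess=None):
    """Minimal sketch: run the initial Agent on a two-sentence text."""
    agent = Agent(args, sess)
    sents = agent.predict(['Cook the rice the day before.',
                           'Use leftover rice.'])
    for sent in sents:
        # act_idx indexes into last_sent + this_sent; obj_idxs[0] holds the
        # object positions, with -1 as the "no real object" sentinel
        words = sent['last_sent'] + sent['this_sent']
        for act in sent['acts']:
            print(words[act['act_idx']], act['obj_idxs'][0])
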
class Agent(object):
    """ RL Agent for online Active Learning. """

    def __init__(self, args, sess):
        self.env_act = Environment(args, 'act')
        # self.net_act = DeepQLearner(args, 'act', 'channels_first')
        self.net_act = DeepQLearner(args, sess, 'act')  # for tensorflow
        self.env_arg = Environment(args, 'arg')
        # self.net_arg = DeepQLearner(args, 'arg', 'channels_first')
        self.net_arg = DeepQLearner(args, sess, 'arg')  # for tensorflow
        self.num_words = args.num_words
        self.context_len = args.context_len
        self.gamma = args.gamma
        self.uncertainty_mode = 'cml'  # or 'diff'

    def load_data(self):
        """
        Load all unlabeled texts.
        Note: 'home_and_garden_500_words_with_title.pkl' contains more than
        15k unlabeled texts from the wikiHow Home and Garden category.
        """
        print('Loading texts from data/home_and_garden_500_words_with_title.pkl ...')
        self.texts = load_pkl('data/home_and_garden_500_words_with_title.pkl')
        self.label2text = text_classification()
        self.history_texts = []
        self.sort_ind = 1 if self.uncertainty_mode == 'diff' else 0
        self.category = 0  # category of the currently chosen text
        self.max_category = len(self.label2text) - 1

    def choose_unlabeled_texts(self, num_texts, dialog=None):
        """
        Apply Active Learning: choose texts from each class and sort them by
        cumulative reward.
        """
        chosen_texts = []
        while len(chosen_texts) < num_texts:
            text_ind = random.sample(self.label2text[self.category], 1)[0]
            if text_ind in self.history_texts:
                continue
            # traverse all categories, choosing texts from each in turn
            self.category = self.category + 1 if self.category < self.max_category else 0
            # predict Q-values and compute the cumulative reward
            text = self.texts[text_ind]
            sents, word2sent, R_t = self.predict(text['sent'])
            # per-step rewards, deduced from R_t = r_t + gamma * R_{t+1};
            # note len(r_t) = len(words) - 1
            r_t = R_t[:-1] - self.gamma * R_t[1:]
            cml_rwd = sum(r_t) + self.gamma * R_t[-1]
            # gap between the predicted and the recomputed cumulative reward
            delta_r = abs(R_t[0] - cml_rwd)
            text['sents'] = sents
            text['reward'] = (cml_rwd, delta_r)
            text['text_ind'] = text_ind
            text['word2sent'] = word2sent
            chosen_texts.append(text)
            if dialog:
                dialog.Update(len(chosen_texts),
                              'Progress: %d/%d' % (len(chosen_texts), num_texts))
        # sort the texts by the chosen uncertainty measure
        sorted_texts = sorted(chosen_texts,
                              key=lambda x: x['reward'][self.sort_ind])
        return sorted_texts
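
    # Illustrative numeric check (assumed values, not model output) of the
    # reward recovery used in choose_unlabeled_texts above: from
    # R_t = r_t + gamma * R_{t+1} we get r_t = R_t - gamma * R_{t+1}, and
    # summing the r_t plus the discounted final return telescopes back to
    # R_0 whenever the Q-values are self-consistent; delta_r measures the gap.
    @staticmethod
    def _demo_reward_recovery(gamma=0.9):
        """Minimal sketch of the r_t / cml_rwd / delta_r computation."""
        R_t = np.array([5.0, 4.0, 3.0])        # stand-in for max-Q per step
        r_t = R_t[:-1] - gamma * R_t[1:]       # [1.4, 1.3]
        cml_rwd = r_t.sum() + gamma * R_t[-1]  # 1.4 + 1.3 + 2.7 = 5.4
        delta_r = abs(R_t[0] - cml_rwd)        # 0.4, the uncertainty signal
        return cml_rwd, delta_r
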
    def predict(self, text):
        """
        Call the EASDRL model to generate output actions for an input text,
        e.g. text = ['Cook the rice the day before.', 'Use leftover rice.']
        """
        self.env_act.init_predict_act_text(text)
        sents = []
        for i in range(len(self.env_act.current_text['sents'])):
            last_sent = self.env_act.current_text['sents'][i - 1] if i > 0 else []
            this_sent = self.env_act.current_text['sents'][i]
            sents.append({'last_sent': last_sent,
                          'this_sent': this_sent,
                          'acts': []})
        word2sent = self.env_act.current_text['word2sent']
        R_t = []  # max Q-value at each step, used as the predicted return
        for i in range(self.num_words):
            state_act = self.env_act.getState()
            qvalues_act = self.net_act.predict(state_act)
            R_t.append(max(qvalues_act[0]))
            action_act = np.argmax(qvalues_act[0])
            self.env_act.act_online(action_act, i)
            if action_act == 1:  # word i was tagged as an action name
                last_sent, this_sent = self.env_arg.init_predict_arg_text(
                    i, self.env_act.current_text)
                # let the arg network tag the object words for this action
                for j in range(self.context_len):
                    state_arg = self.env_arg.getState()
                    qvalues_arg = self.net_arg.predict(state_arg)
                    action_arg = np.argmax(qvalues_arg[0])
                    self.env_arg.act_online(action_arg, j)
                    if self.env_arg.terminal_flag:
                        break
                obj_idxs = []
                sent_words = self.env_arg.current_text['tokens']
                tmp_num = min(self.context_len, len(sent_words))
                for j in range(tmp_num):
                    if self.env_arg.state[j, -1] == 2:  # tag 2 marks an object word
                        if j == len(sent_words) - 1:
                            j = -1  # the final token stands for "no real object"
                        obj_idxs.append(j)
                if len(obj_idxs) == 0:
                    obj_idxs.append(-1)
                # convert the global word index into a per-sentence act index
                si, ai = self.env_act.current_text['word2sent'][i]
                ai += len(sents[si]['last_sent'])
                sents[si]['acts'].append({'act_idx': ai,
                                          'obj_idxs': [obj_idxs, []],
                                          'act_type': 1,
                                          'related_acts': []})
            if self.env_act.terminal_flag:
                break
        return sents, word2sent, np.array(R_t)
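
# Illustrative selection-loop sketch (not part of the original source): how
# the Active Learning Agent above might be used to pick a batch of texts for
# annotation. The `args`/`sess` construction and the bookkeeping of
# history_texts are assumptions about the surrounding GUI workflow.
def _demo_active_learning(args, sess, batch_size=10):
    """Minimal sketch: rank unlabeled texts by the chosen uncertainty measure."""
    agent = Agent(args, sess)
    agent.load_data()
    batch = agent.choose_unlabeled_texts(batch_size)
    for text in batch:
        cml_rwd, delta_r = text['reward']
        print(text['text_ind'], cml_rwd, delta_r)
    # remember the chosen texts so they are not sampled again next round
    agent.history_texts.extend(t['text_ind'] for t in batch)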