import numpy as np
import torch

import data
import domain
# NOTE: MetricsContainer is a project-local helper; the exact module path may
# differ in your checkout.
from metric import MetricsContainer


class Dialog(object):
    """Runs a dialog between two agents and collects metrics."""

    def __init__(self, agents, args):
        # For now we only support dialogs between 2 agents.
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()

    def _register_metrics(self):
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_moving_percentage('moving_agree')
        self.metrics.register_average('advantage')
        self.metrics.register_moving_average('moving_advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        self.metrics.register_average('agree_comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_moving_average('%s_moving_rew' % agent.name)
            self.metrics.register_average('agree_%s_rew' % agent.name)
            self.metrics.register_percentage('%s_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
        # text metrics
        if self.args.ref_text:
            ref_text = ' '.join(data.read_lines(self.args.ref_text))
            self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        """Returns True if the utterance ends the dialog."""
        return len(out) == 1 and out[0] in ('<selection>', '<no_agreement>')

    def show_metrics(self):
        return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def run(self, ctxs, logger, max_words=5000):
        """Runs a single dialog between the two agents on the given contexts."""
        assert len(self.agents) == len(ctxs)
        # Feed each agent its own context and its partner's context.
        for agent, ctx, partner_ctx in zip(self.agents, ctxs, reversed(ctxs)):
            agent.feed_context(ctx)
            agent.feed_partner_context(partner_ctx)
            logger.dump_ctx(agent.name, ctx)
        logger.dump('-' * 80)

        # Choose who goes first at random.
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        conv = []
        self.metrics.reset()

        words_left = max_words
        length = 0
        expired = False
        while True:
            # The writer produces an utterance (capped at 20 tokens in this
            # version), the reader consumes it.
            out = writer.write(max_words=20)
            words_left -= len(out)
            length += len(out)

            self.metrics.record('sent_len', len(out))
            if 'full_match' in self.metrics.metrics:
                self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            conv.append(out)
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)

            # Stop once one of the agents makes a selection, or the word
            # budget is exhausted; otherwise swap roles and continue.
            if self._is_selection(out):
                self.metrics.record('%s_sel' % writer.name, 1)
                self.metrics.record('%s_sel' % reader.name, 0)
                break
            if words_left <= 1:
                break

            writer, reader = reader, writer

        # Both agents make their final selections.
        choices = []
        for agent in self.agents:
            choice = agent.choose()
            choices.append(choice)
            logger.dump_choice(agent.name, choice[:self.domain.selection_length() // 2])

        # Evaluate the selections: check agreement and assign rewards.
        agree, rewards = self.domain.score_choices(choices, ctxs)
        if expired:
            agree = False
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        for i, (agent, reward) in enumerate(zip(self.agents, rewards)):
            logger.dump_reward(agent.name, agree, reward)
            j = 1 if i == 0 else 0
            agent.update(agree, reward, choice=choices[i],
                         partner_choice=choices[j], partner_input=ctxs[j],
                         partner_reward=rewards[j])

        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
            self.metrics.record('moving_advantage', rewards[0] - rewards[1])
            self.metrics.record('agree_comb_rew', np.sum(rewards))
            for agent, reward in zip(self.agents, rewards):
                self.metrics.record('agree_%s_rew' % agent.name, reward)

        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('moving_agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)
            self.metrics.record('%s_moving_rew' % agent.name, reward if agree else 0)

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)
        for ctx, choice in zip(ctxs, choices):
            logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice)))

        return conv, agree, rewards
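# --- Illustrative usage (not part of the original module) -------------------
# A minimal sketch of how Dialog.run might be driven in a self-play loop.
# `ctx_gen` and `logger` are assumed stand-ins for the project's own context
# generator and dialog logger; only the attributes used by Dialog.run are
# required of them.
def selfplay_example(dialog, ctx_gen, logger):
    """Runs Dialog.run over a stream of context pairs and collects outcomes."""
    results = []
    for ctxs in ctx_gen:  # each item: one context per agent
        conv, agree, rewards = dialog.run(ctxs, logger)
        results.append((agree, rewards))
    print(dialog.show_metrics())
    return results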
class Dialog(object):
    """Variant of Dialog that additionally runs a markable detector over the
    self-play conversation and records the predicted referents per scenario."""

    def __init__(self, agents, args, markable_detector, markable_detector_corpus):
        # For now we only support dialogs between 2 agents.
        assert len(agents) == 2
        self.agents = agents
        self.args = args
        self.domain = domain.get_domain(args.domain)
        self.metrics = MetricsContainer()
        self._register_metrics()
        self.markable_detector = markable_detector
        self.markable_detector_corpus = markable_detector_corpus
        self.selfplay_markables = {}
        self.selfplay_referents = {}

    def _register_metrics(self):
        self.metrics.register_average('dialog_len')
        self.metrics.register_average('sent_len')
        self.metrics.register_percentage('agree')
        self.metrics.register_moving_percentage('moving_agree')
        self.metrics.register_average('advantage')
        self.metrics.register_moving_average('moving_advantage')
        self.metrics.register_time('time')
        self.metrics.register_average('comb_rew')
        self.metrics.register_average('agree_comb_rew')
        for agent in self.agents:
            self.metrics.register_average('%s_rew' % agent.name)
            self.metrics.register_moving_average('%s_moving_rew' % agent.name)
            self.metrics.register_average('agree_%s_rew' % agent.name)
            self.metrics.register_percentage('%s_make_sel' % agent.name)
            self.metrics.register_uniqueness('%s_unique' % agent.name)
            if "plot_metrics" in self.args and self.args.plot_metrics:
                self.metrics.register_select_frequency('%s_sel_bias' % agent.name)
        # text metrics
        if self.args.ref_text:
            ref_text = ' '.join(data.read_lines(self.args.ref_text))
            self.metrics.register_ngram('full_match', text=ref_text)

    def _is_selection(self, out):
        return '<selection>' in out

    def show_metrics(self):
        return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()])

    def plot_metrics(self):
        self.metrics.plot()

    def run(self, ctxs, logger, max_words=5000):
        """Runs a single self-play dialog and, if a markable detector is
        available, extracts markables and their referents afterwards."""
        scenario_id = ctxs[0][0]
        # Feed each agent its context, its real entity ids and its agent id.
        for agent, agent_id, ctx, real_ids in zip(self.agents, [0, 1], ctxs[1], ctxs[2]):
            agent.feed_context(ctx)
            agent.real_ids = real_ids
            agent.agent_id = agent_id

        # Choose who goes first at random.
        if np.random.rand() < 0.5:
            writer, reader = self.agents
        else:
            reader, writer = self.agents

        conv = []
        speaker = []
        self.metrics.reset()

        words_left = max_words
        length = 0
        expired = False
        while True:
            out = writer.write(max_words=words_left)
            words_left -= len(out)
            length += len(out)

            self.metrics.record('sent_len', len(out))
            if 'full_match' in self.metrics.metrics:
                self.metrics.record('full_match', out)
            self.metrics.record('%s_unique' % writer.name, out)

            conv.append(out)
            speaker.append(writer.agent_id)
            reader.read(out)
            if not writer.human:
                logger.dump_sent(writer.name, out)

            if logger.scenarios and self.args.log_attention:
                attention = writer.get_attention()
                if attention is not None:
                    logger.dump_attention(writer.name, writer.agent_id, scenario_id, attention)

            # Stop once one of the agents makes a selection, or the word
            # budget is exhausted; otherwise swap roles and continue.
            if self._is_selection(out):
                self.metrics.record('%s_make_sel' % writer.name, 1)
                self.metrics.record('%s_make_sel' % reader.name, 0)
                break
            if words_left <= 1:
                break

            writer, reader = reader, writer

        # Both agents make their final selections.
        choices = []
        for agent in self.agents:
            choice = agent.choose()
            choices.append(choice)

        if logger.scenarios:
            logger.dump_choice(scenario_id, choices)
            if "plot_metrics" in self.args and self.args.plot_metrics:
                for agent in [0, 1]:
                    for obj in logger.scenarios[scenario_id]['kbs'][agent]:
                        if obj['id'] == choices[agent]:
                            self.metrics.record('%s_sel_bias' % writer.name, obj,
                                                logger.scenarios[scenario_id]['kbs'][agent])

        # Evaluate the selections: check agreement and assign rewards.
        agree, rewards = self.domain.score_choices(choices, ctxs)
        if expired:
            agree = False
        logger.dump('-' * 80)
        logger.dump_agreement(agree)
        for i, (agent, reward) in enumerate(zip(self.agents, rewards)):
            j = 1 if i == 0 else 0
            agent.update(agree, reward, choice=choices[i])

        if agree:
            self.metrics.record('advantage', rewards[0] - rewards[1])
            self.metrics.record('moving_advantage', rewards[0] - rewards[1])
            self.metrics.record('agree_comb_rew', np.sum(rewards))
            for agent, reward in zip(self.agents, rewards):
                self.metrics.record('agree_%s_rew' % agent.name, reward)

        self.metrics.record('time')
        self.metrics.record('dialog_len', len(conv))
        self.metrics.record('agree', int(agree))
        self.metrics.record('moving_agree', int(agree))
        self.metrics.record('comb_rew', np.sum(rewards) if agree else 0)
        for agent, reward in zip(self.agents, rewards):
            self.metrics.record('%s_rew' % agent.name, reward if agree else 0)
            self.metrics.record('%s_moving_rew' % agent.name, reward if agree else 0)

        if self.markable_detector is not None and self.markable_detector_corpus is not None:
            # Run the markable detector over the conversation from each
            # agent's perspective and predict referents for every markable.
            markable_list = []
            referents_dict = {}
            markable_starts = []
            for agent in [0, 1]:
                dialog_tokens = []
                dialog_text = ""
                for spkr, uttr in zip(speaker, conv):
                    if spkr == agent:
                        dialog_tokens.append("YOU:")
                    else:
                        dialog_tokens.append("THEM:")
                    dialog_tokens += uttr
                    dialog_text += str(spkr) + ": " + " ".join(uttr[:-1]) + "\n"

                words = self.markable_detector_corpus.word_dict.w2i(dialog_tokens)
                words = torch.Tensor(words).long().cuda()
                score, tag_seq = self.markable_detector(words)

                referent_inpt = []
                markable_ids = []
                my_utterance = None
                current_text = ""
                for i, word in enumerate(words):
                    # Track whose utterance we are currently in.
                    if word.item() == self.markable_detector_corpus.word_dict.word2idx["YOU:"]:
                        my_utterance = True
                        current_speaker = agent
                    elif word.item() == self.markable_detector_corpus.word_dict.word2idx["THEM:"]:
                        my_utterance = False
                        current_speaker = 1 - agent

                    if my_utterance:
                        # A "B" tag starts a markable; it extends over the
                        # following "I" tags, and its utterance ends at the
                        # next <eos> or <selection> token.
                        if tag_seq[i].item() == self.markable_detector_corpus.bio_dict["B"]:
                            start_idx = i
                            for j in range(i + 1, len(tag_seq)):
                                if tag_seq[j].item() != self.markable_detector_corpus.bio_dict["I"]:
                                    end_idx = j - 1
                                    break
                            for j in range(i + 1, len(tag_seq)):
                                if tag_seq[j].item() in self.markable_detector_corpus.word_dict.w2i(["<eos>", "<selection>"]):
                                    end_uttr = j
                                    break

                            markable_start = len(current_text + " ")
                            if markable_start not in markable_starts:
                                referent_inpt.append([start_idx, end_idx, end_uttr])
                                markable_ids.append(len(markable_starts))

                                # add markable
                                markable = {}
                                markable["start"] = markable_start
                                markable["end"] = len(current_text + " " +
                                                      " ".join(dialog_tokens[start_idx:end_idx + 1]))
                                markable["markable_id"] = len(markable_starts)
                                markable["speaker"] = current_speaker
                                markable["text"] = " ".join(dialog_tokens[start_idx:end_idx + 1])
                                markable_starts.append(markable["start"])
                                markable_list.append(markable)

                    # Rebuild the plain-text dialog so that character offsets
                    # stay aligned with dialog_text.
                    if word.item() == self.markable_detector_corpus.word_dict.word2idx["YOU:"]:
                        current_text += "{}:".format(current_speaker)
                    elif word.item() == self.markable_detector_corpus.word_dict.word2idx["THEM:"]:
                        current_text += "{}:".format(current_speaker)
                    elif word.item() in self.markable_detector_corpus.word_dict.w2i(["<eos>", "<selection>"]):
                        current_text += "\n"
                    else:
                        current_text += " " + self.markable_detector_corpus.word_dict.idx2word[word.item()]

                assert len(current_text) == len(dialog_text)

                ref_out = self.agents[agent].predict_referents(referent_inpt)
                if ref_out is not None:
                    for i, markable_id in enumerate(markable_ids):
                        ent_ids = [ent["id"] for ent in logger.scenarios[scenario_id]['kbs'][agent]]
                        referents = []
                        for j, is_referent in enumerate((ref_out[i] > 0).tolist()):
                            if is_referent:
                                referents.append("agent_" + str(agent) + "_" + ent_ids[j])
                        referents_dict[markable_id] = referents

            # Reindex markable ids by their start position in the text.
            markable_id_and_start = [(markable_id, markable_start)
                                     for markable_id, markable_start
                                     in zip(range(len(markable_starts)), markable_starts)]
            reindexed_markable_ids = [markable_id for markable_id, _
                                      in sorted(markable_id_and_start, key=lambda x: x[1])]

            self.selfplay_markables[scenario_id] = {}
            self.selfplay_referents[scenario_id] = {}

            # add markables
            self.selfplay_markables[scenario_id]["markables"] = []
            for new_markable_id, old_markable_id in enumerate(reindexed_markable_ids):
                markable = markable_list[old_markable_id]
                markable["markable_id"] = "M{}".format(new_markable_id + 1)
                self.selfplay_markables[scenario_id]["markables"].append(markable)

            # add dialogue text
            self.selfplay_markables[scenario_id]["text"] = dialog_text

            # add final selections
            self.selfplay_markables[scenario_id]["selections"] = choices

            # add referents
            for new_markable_id, old_markable_id in enumerate(reindexed_markable_ids):
                referents = referents_dict[old_markable_id]
                self.selfplay_referents[scenario_id]["M{}".format(new_markable_id + 1)] = referents

        logger.dump('-' * 80)
        logger.dump(self.show_metrics())
        logger.dump('-' * 80)

        return conv, agree, rewards
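# --- Illustrative usage (not part of the original module) -------------------
# A minimal sketch of how the markable-detecting Dialog variant above might be
# used for a single scenario and how its self-play annotations can be read
# back. The layout of `ctxs` (scenario id, per-agent contexts, per-agent real
# entity ids) is inferred from run(); `logger` is an assumed stand-in that
# provides the dump_* hooks and the `scenarios` mapping run() relies on.
def markable_selfplay_example(dialog, ctxs, logger):
    conv, agree, rewards = dialog.run(ctxs, logger)
    scenario_id = ctxs[0][0]
    # List of markable dicts (start/end offsets, speaker, text) ...
    markables = dialog.selfplay_markables.get(scenario_id, {}).get("markables", [])
    # ... and a mapping from markable id ("M1", "M2", ...) to predicted referents.
    referents = dialog.selfplay_referents.get(scenario_id, {})
    return conv, agree, rewards, markables, referents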