def setup_episodes(self, fold):
    """
    Parses into TodStructuredEpisode.

    Loads the chunks and ontologies for ``fold`` (restricted to the domains
    configured under ``taskmaster2_domains``) and turns every row into one
    episode made of user/system rounds.

    Returns the list of built ``tod.TodStructuredEpisode`` objects.
    """
    domains = self.opt.get("taskmaster2_domains", DOMAINS)
    chunks, ontologies = self._load_data(fold, domains)
    domains_cnt = Counter()
    episodes = []
    for _, row in chunks.iterrows():
        domains_cnt[row["domain"]] += 1
        # Work on a shallow copy of the row's utterances.
        utterances = row["utterances"][:]
        idx = 0
        rounds = []
        goal_calls = []

        # A conversation opened by the assistant gets a leading round that
        # carries only the system side.
        if len(utterances) > 0 and utterances[0]["speaker"] == "ASSISTANT":
            idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker(
                "ASSISTANT", utterances, idx)
            rounds.append(
                tod.TodStructuredRound(
                    api_resp_machine=api_resp,
                    sys_utt=sys_utt,
                ))

        cum_api_call = {}
        while idx < len(utterances):
            idx, user_utt, api_call = self._get_utterance_and_api_call_for_speaker(
                "USER", utterances, idx)
            idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker(
                "ASSISTANT", utterances, idx)

            if self.opt["use_cumulative_api_calls"]:
                # Fold this turn's call into the running call; the call and
                # response are only surfaced when a response exists.
                cum_api_call = self.process_call_for_cumlative_standalone_api(
                    api_call, cum_api_call)
                has_response = len(api_resp) > 0
                round_call = cum_api_call if has_response else {}
                round_resp = api_resp if has_response else {}
            else:
                round_call = api_call
                round_resp = api_resp

            rounds.append(
                tod.TodStructuredRound(
                    user_utt=user_utt,
                    api_call_machine=round_call,
                    api_resp_machine=round_resp,
                    sys_utt=sys_utt,
                ))
            # Goals always record the raw (non-cumulative) call of the turn.
            if len(api_call) > 0:
                goal_calls.append(api_call)

        episodes.append(
            tod.TodStructuredEpisode(
                domain=tod.SerializationHelpers.inner_list_join(row["domain"]),
                api_schemas_machine=self._get_onto_list(
                    ontologies, row["domain"]),
                goal_calls_machine=goal_calls,
                rounds=rounds,
                delex=self.opt.get("delex", False),
            ))
    return episodes
def setup_episodes(self, fold):
    """
    Parses Google SGD episodes into TodStructuredEpisode.

    Turns are consumed pairwise (user turn at even index, system turn right
    after it); each pair becomes one ``tod.TodStructuredRound``.

    Returns the list of episodes, one per loaded dialogue.
    """
    schema_lookup, dialogues = self._load_data(fold)
    episodes = []
    for dialogue in dialogues:
        # The domain is the part of each service name before the first "_".
        domains = {
            service.split("_")[0].strip() for service in dialogue["services"]
        }
        turns = dialogue["turns"]

        rounds = []
        for user_idx in range(0, len(turns), 2):
            user_turn = turns[user_idx]
            sys_turn = turns[user_idx + 1]
            api_call, api_results = self._get_api_call_and_results(sys_turn)
            rounds.append(
                tod.TodStructuredRound(
                    user_utt=user_turn["utterance"],
                    api_call_machine=api_call,
                    api_resp_machine=api_results,
                    sys_utt=sys_turn["utterance"],
                ))

        # All rounds are collected; assemble the episode around them.
        episodes.append(
            tod.TodStructuredEpisode(
                domain=SerializationHelpers.inner_list_join(domains),
                api_schemas_machine=self._get_intent_groundinging(
                    schema_lookup, set(dialogue["services"])),
                goal_calls_machine=self._get_all_service_calls(turns),
                rounds=rounds,
                delex=self.opt.get("delex"),
                extras={"dialogue_id": dialogue["dialogue_id"]},
            ))
    return episodes
def _get_turns_from_parsed(self, user_utt, api_calls, api_resps, sys_utt):
    """
    Wrap the parsed pieces of one exchange into a single-element list.

    The list holds one ``tod.TodStructuredRound`` built from the user
    utterance, the API call(s), the API response(s) and the system
    utterance, passed through unchanged.
    """
    only_round = tod.TodStructuredRound(
        user_utt=user_utt,
        api_call_machine=api_calls,
        api_resp_machine=api_resps,
        sys_utt=sys_utt,
    )
    return [only_round]
def get_rounds(episode_idx, max_rounds, use_broken_mock_api_calls=False):
    """
    Build ``max_rounds`` synthetic rounds for the episode ``episode_idx``.

    Utterances are the strings ``user_utt_<episode>_<round>`` and
    ``sys_utt_<episode>_<round>``; the API call and response come from
    ``make_api_call_machine`` and ``make_api_resp_machine``.

    Returns a list of ``tod_core.TodStructuredRound`` (empty when
    ``max_rounds`` is not positive).
    """
    rounds = []
    for round_idx in range(max_rounds):
        call = make_api_call_machine(
            round_idx, episode_idx, use_broken_mock_api_calls)
        resp = make_api_resp_machine(round_idx)
        rounds.append(
            tod_core.TodStructuredRound(
                user_utt=f"user_utt_{episode_idx}_{round_idx}",
                api_call_machine=call,
                api_resp_machine=resp,
                sys_utt=f"sys_utt_{episode_idx}_{round_idx}",
            ))
    return rounds
def _process_line(self, line):
    """
    Parse one JSON line into a ``tod.TodStructuredEpisode``.

    The first entry of ``blob["dialog"]`` is the pre-empt round (schemas
    and goals); every following entry is a regular round of four turns.
    Turn texts get their expected prefix removed and cached under
    ``"prefix_stripped_text"``.

    Returns ``None`` when the line has no dialog, when the pre-empt round
    does not have four turns, or when a short round is not the done marker.

    Raises ``RuntimeError`` when a prefix is missing and ``fail_hard`` is
    set, or when a non-empty API call lacks the API name slot.
    """
    blob = json.loads(line)
    if not ("dialog" in blob and len(blob["dialog"]) >= 1):
        return None

    rounds = []
    for raw_round in blob["dialog"][1:]:
        if "prefix_stripped_text" not in raw_round[0]:
            # Not stripped yet: strip every turn of this round in place.
            for position, turn in enumerate(raw_round):
                prefix = PREFIXES[position]
                if prefix not in turn['text'] and self.opt["fail_hard"]:
                    raise RuntimeError(
                        f"Missing prefix `{prefix}` before turn of text: `{turn}`"
                    )
                turn["prefix_stripped_text"] = turn.get(
                    "text", prefix)[len(prefix):]

        if len(raw_round) != 4:
            if raw_round[0]["prefix_stripped_text"] == tod.STANDARD_DONE:
                break  # TodStructuredEpisode will add in [DONE]
            return None  # misformatted convo, don't learn this.

        api_call_machine = tod.SerializationHelpers.str_to_api_dict(
            raw_round[1]["prefix_stripped_text"])
        call_is_nonempty = len(api_call_machine) > 0
        if call_is_nonempty and tod.STANDARD_API_NAME_SLOT not in api_call_machine:
            raise RuntimeError(
                f"Trying to make API call without `{tod.STANDARD_API_NAME_SLOT}`. Call is: `{raw_round[1]['text']}`"
            )
        rounds.append(
            tod.TodStructuredRound(
                user_utt=raw_round[0]["prefix_stripped_text"],
                api_call_machine=api_call_machine,
                api_resp_machine=tod.SerializationHelpers.str_to_api_dict(
                    raw_round[2]["prefix_stripped_text"]),
                sys_utt=raw_round[3]["prefix_stripped_text"],
            ))

    preempt_round = blob["dialog"][0]
    if len(preempt_round) != 4:
        return None
    for position, turn in enumerate(preempt_round):
        if "prefix_stripped_text" in turn:
            continue
        prefix = PREFIXES_PREEMPT[position]
        turn["prefix_stripped_text"] = turn.get("text", prefix)[len(prefix):]

    schema_turn = preempt_round[0]
    goal_turn = preempt_round[3]
    return tod.TodStructuredEpisode(
        domain=schema_turn.get("domain", ""),
        api_schemas_machine=tod.SerializationHelpers.str_to_api_schemas(
            schema_turn.get("prefix_stripped_text", "")),
        goal_calls_machine=tod.SerializationHelpers.str_to_goals(
            goal_turn.get("prefix_stripped_text")),
        rounds=rounds,
    )
def _get_turns_from_parsed(self, user_utt, api_calls, api_resps, sys_utt):
    """
    Expand one exchange into one round per API call/response pair.

    ``api_calls`` and ``api_resps`` must have the same length (asserted).
    With no calls at all, a single round with empty call and response is
    produced. The user utterance goes on the first round and the system
    utterance on the last; every other utterance slot is filled with
    ``SILENCE_TOKEN``.

    Returns the list of ``tod.TodStructuredRound``.
    """
    assert len(api_calls) == len(api_resps)
    if len(api_calls) == 0:
        api_calls, api_resps = [{}], [{}]

    final_idx = len(api_calls) - 1
    return [
        tod.TodStructuredRound(
            user_utt=user_utt if i == 0 else SILENCE_TOKEN,
            api_call_machine=api_calls[i],
            api_resp_machine=api_resps[i],
            sys_utt=sys_utt if i == final_idx else SILENCE_TOKEN,
        )
        for i in range(final_idx + 1)
    ]
def _get_round(self, dialogue_id, raw_episode, turn_id):
    """
    Parse to TodStructuredRound. Assume User turn first.

    Reads the user turn at ``turn_id`` and the system turn right after it,
    derives an API call from the user frames with an active intent and an
    API response from the system dialog acts, and suppresses the pair when
    the call equals ``self.last_call``.

    Returns a ``(call, round)`` tuple; ``call`` is ``{}`` when no new call
    was made. Updates ``self.last_call`` when the call is non-empty.

    Raises ``RuntimeError`` when the speakers are not USER then SYSTEM.
    """
    user_turn = raw_episode[turn_id]
    if user_turn["speaker"] != "USER":
        raise RuntimeError(
            f"Got non-user turn when it should have been in {dialogue_id}; turn id {turn_id}"
        )
    sys_turn = raw_episode[turn_id + 1]
    sys_dialog_act = self.dialog_acts[dialogue_id][str(turn_id + 1)]["dialog_act"]
    if sys_turn["speaker"] != "SYSTEM":
        raise RuntimeError(
            f"Got non-system turn when it should have been in {dialogue_id}; turn id {turn_id}"
        )

    call = {}
    resp = {}
    for frame in user_turn.get("frames", []):
        if frame.get("state", {}).get("active_intent", "NONE") == "NONE":
            continue
        state = frame["state"]
        intent = state["active_intent"]
        # Slot keys are stripped of the leading "<service>" plus one
        # separator character; only the first value of each slot is kept.
        prefix_len = len(frame["service"]) + 1
        raw_slots = copy.deepcopy(state["slot_values"])
        maybe_call = {
            name[prefix_len:]: values[0] for name, values in raw_slots.items()
        }
        maybe_call[tod.STANDARD_API_NAME_SLOT] = intent

        if "find" in intent:
            for act in sys_dialog_act:
                if not ("Inform" in act or "NoOffer" in act):
                    continue
                # Gotta check to make sure if it's inform, that it's about the right topic
                if "Inform" in act:
                    valid = True
                    for slot in [pair[0] for pair in sys_dialog_act[act]]:
                        valid &= self._slot_in_schema(slot, intent) | (
                            slot == "choice"
                        )
                    if not valid:
                        continue
                call = maybe_call
                resp = self._get_find_api_response(
                    intent, state["slot_values"], sys_dialog_act
                )
        elif "book" in intent:
            for act in sys_dialog_act:
                if "Book" in act:
                    resp = {pair[0]: pair[1] for pair in sys_dialog_act[act]}
                    call = maybe_call

    # A call identical to the previous one is treated as no call at all.
    if call == self.last_call:
        call = {}
        resp = {}
    if len(call) > 0:
        self.last_call = call

    parsed_round = tod.TodStructuredRound(
        user_utt=user_turn["utterance"],
        api_call_machine=call,
        api_resp_machine=resp,
        sys_utt=sys_turn["utterance"],
    )
    return (call, parsed_round)
def setup_episodes(self, fold):
    """
    Parses into TodStructuredEpisode.

    Loads the conversations for ``fold`` (restricted to the domains
    configured under ``msre2e_domains``), skips empty ones, and turns each
    remaining conversation into one episode of user/agent rounds.

    Returns the list of built ``tod.TodStructuredEpisode`` objects.
    """
    domains = self.opt.get("msre2e_domains", DOMAINS)
    chunks = self._load_data(fold, domains)
    domains_cnt = Counter()
    episodes = []
    for utterances in chunks:
        if len(utterances) < 1:
            continue
        # The episode's domain is taken from its first utterance.
        domain = utterances[0]["domain"]
        domains_cnt[domain] += 1
        idx = 0
        rounds = []
        goal_calls = []

        # An agent-opened conversation gets a leading round whose user side
        # is the silence constant.
        if utterances[0]["speaker"] == "agent":
            idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker(
                "agent", utterances, idx)
            rounds.append(
                tod.TodStructuredRound(
                    user_utt=tod.CONST_SILENCE,
                    api_resp_machine=api_resp,
                    sys_utt=sys_utt,
                ))

        cum_api_call = {}
        while idx < len(utterances):
            idx, user_utt, api_call = self._get_utterance_and_api_call_for_speaker(
                "user", utterances, idx)
            idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker(
                "agent", utterances, idx)

            if self.opt["use_cumulative_api_calls"]:
                # Accumulate slots across turns; a snapshot of the running
                # call is surfaced only when a response exists.
                cum_api_call.update(api_call)
                has_response = len(api_resp) > 0
                round_call = copy.deepcopy(cum_api_call) if has_response else {}
                round_resp = api_resp if has_response else {}
            else:
                round_call = api_call
                round_resp = api_resp

            rounds.append(
                tod.TodStructuredRound(
                    user_utt=user_utt,
                    api_call_machine=round_call,
                    api_resp_machine=round_resp,
                    sys_utt=sys_utt,
                ))
            # Goals always record the raw (non-cumulative) call of the turn.
            if len(api_call) > 0:
                goal_calls.append(api_call)

        episodes.append(
            tod.TodStructuredEpisode(
                domain=domain,
                api_schemas_machine=SLOT_NAMES[domain],
                goal_calls_machine=goal_calls,
                rounds=rounds,
                delex=self.opt.get("delex", False),
            ))
    return episodes
def setup_episodes(self, fold):
    """
    Build one TodStructuredEpisode per usable conversation.

    Conversations come from ``self._iterate_over_conversations`` for the
    domains configured under ``multidogo_domains`` (a single string is
    accepted and wrapped in a list). Empty conversations and conversations
    with a turn lacking ``"role"`` are skipped. Turns are looked up by the
    string keys ``"0"``, ``"1"``, ... and grouped into rounds: a new round
    starts whenever the speaker switches from agent to customer.

    Within a round, customer slots accumulate into the API call, agent
    slots into the API response, and consecutive utterances by the same
    side are joined with a newline.

    Returns the list of built ``tod.TodStructuredEpisode`` objects.
    """
    result = []
    domains = self.opt.get("multidogo_domains", DOMAINS)
    if isinstance(domains, str):
        domains = [domains]
    intent_type = self.opt.get("intent-type", TURN_INTENT)
    for _conv_id, domain, conversation in self._iterate_over_conversations(
            domains, intent_type):
        if len(conversation) == 0 or not all(
                "role" in turn for turn in conversation.values()):
            continue
        rounds = []
        first_turn = conversation["0"]
        prev_role = first_turn["role"]
        # Copy the opening turn's slots: the dict is extended with
        # ``update`` below, and aliasing it would write every later slot
        # back into the loaded conversation data.
        first_slots = dict(first_turn.get("slots", {}))
        if prev_role == "customer":
            user_utt = [first_turn["text"]]
            all_calls = first_slots
            api_resp = {}
            sys_utt = []
        else:
            user_utt = ["__SILENCE__"]
            all_calls = {}
            api_resp = first_slots
            sys_utt = [first_turn["text"]]
        # NOTE(review): slots of a customer opening turn are tracked in
        # ``all_calls`` (goal) but are not part of the first round's API
        # call; kept as-is -- confirm this is intended.
        api_call = {tod.STANDARD_API_NAME_SLOT: domain}
        for i in range(1, len(conversation)):
            turn = conversation[str(i)]
            if prev_role == "agent" and prev_role != turn["role"]:
                # Agent -> customer switch closes the current round.
                rounds.append(
                    tod.TodStructuredRound(
                        user_utt="\n".join(user_utt),
                        api_call_machine=api_call,
                        api_resp_machine=api_resp,
                        sys_utt="\n".join(sys_utt),
                    ))
                user_utt = []
                api_call = {tod.STANDARD_API_NAME_SLOT: domain}
                api_resp = {}
                sys_utt = []
            prev_role = turn["role"]
            slot = turn.get("slots", {})
            if prev_role == "customer":
                user_utt.append(turn["text"])
                api_call.update(slot)
                all_calls.update(slot)
            else:
                api_resp.update(slot)
                sys_utt.append(turn["text"])
        # Flush the trailing round with the same newline separator as the
        # rounds above, so consecutive utterances are not glued together.
        rounds.append(
            tod.TodStructuredRound(
                user_utt="\n".join(user_utt),
                api_call_machine=api_call,
                api_resp_machine=api_resp,
                sys_utt="\n".join(sys_utt),
            ))
        # The episode goal is every customer slot seen, plus the API name.
        goal_calls = copy.deepcopy(all_calls)
        goal_calls[tod.STANDARD_API_NAME_SLOT] = domain
        result.append(
            tod.TodStructuredEpisode(
                domain=domain,
                api_schemas_machine=[{
                    tod.STANDARD_API_NAME_SLOT: domain,
                    tod.STANDARD_OPTIONAL_KEY: all_calls.keys(),
                }],
                goal_calls_machine=[goal_calls],
                rounds=rounds,
            ))
    return result