Beispiel #1
0
    def setup_episodes(self, fold):
        """
        Parses into TodStructuredEpisode.

        Loads ``fold`` for the domains in ``self.opt["taskmaster2_domains"]``
        (default ``DOMAINS``) and turns each loaded row of utterances into one
        ``tod.TodStructuredEpisode`` built from USER/ASSISTANT rounds.

        :param fold: data split passed through to ``self._load_data``.
        :return: list of ``tod.TodStructuredEpisode``, one per loaded row.
        """
        domains = self.opt.get("taskmaster2_domains", DOMAINS)
        chunks, ontologies = self._load_data(fold, domains)
        episodes = []
        for _, row in chunks.iterrows():
            utterances = row["utterances"][:]

            idx = 0
            rounds = []
            goal_calls = []
            # A conversation that opens with the assistant gets a leading
            # round with only the assistant's response and utterance.
            if len(utterances) > 0 and utterances[0]["speaker"] == "ASSISTANT":
                idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker(
                    "ASSISTANT", utterances, idx)
                r = tod.TodStructuredRound(api_resp_machine=api_resp,
                                           sys_utt=sys_utt)
                rounds.append(r)

            cum_api_call = {}
            while idx < len(utterances):
                idx, user_utt, api_call = self._get_utterance_and_api_call_for_speaker(
                    "USER", utterances, idx)
                idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker(
                    "ASSISTANT", utterances, idx)
                if not self.opt["use_cumulative_api_calls"]:
                    r = tod.TodStructuredRound(
                        user_utt=user_utt,
                        api_call_machine=api_call,
                        api_resp_machine=api_resp,
                        sys_utt=sys_utt,
                    )
                else:
                    cum_api_call = self.process_call_for_cumlative_standalone_api(
                        api_call, cum_api_call)
                    # The running dict is handed back to the helper on the next
                    # turn, which may update and return that same object. Store
                    # a snapshot so later turns cannot rewrite the API call
                    # recorded for this round.
                    r = tod.TodStructuredRound(
                        user_utt=user_utt,
                        api_call_machine=dict(cum_api_call)
                        if len(api_resp) > 0 else {},
                        api_resp_machine=api_resp if len(api_resp) > 0 else {},
                        sys_utt=sys_utt,
                    )

                rounds.append(r)
                if len(api_call) > 0:
                    goal_calls.append(api_call)

            episode = tod.TodStructuredEpisode(
                domain=tod.SerializationHelpers.inner_list_join(row["domain"]),
                api_schemas_machine=self._get_onto_list(
                    ontologies, row["domain"]),
                goal_calls_machine=goal_calls,
                rounds=rounds,
                delex=self.opt.get("delex", False),
            )
            episodes.append(episode)
        return episodes
Beispiel #2
0
 def setup_episodes(self, fold):
     """
     Parses Google SGD episodes into TodStructuredEpisode.

     Each dialogue returned by ``self._load_data(fold)`` becomes one episode.
     Its turns are consumed as (user, system) pairs; the API call and results
     for each round are extracted from the system turn. The episode domain is
     built from the service-name prefixes (text before the first "_").

     :param fold: data split passed through to ``self._load_data``.
     :return: list of ``tod.TodStructuredEpisode``, one per dialogue, each
         carrying its ``dialogue_id`` in ``extras``.
     """
     schema_lookup, dialogues = self._load_data(fold)
     result = []
     for dialogue in dialogues:
         domains = {s.split("_")[0].strip() for s in dialogue["services"]}
         turns = dialogue["turns"]
         rounds = []
         # Turns alternate user/system. zip() stops at the shorter slice, so a
         # trailing user turn with no system reply is skipped instead of
         # raising IndexError.
         for user_turn, sys_turn in zip(turns[0::2], turns[1::2]):
             api_call, api_results = self._get_api_call_and_results(sys_turn)
             rounds.append(
                 tod.TodStructuredRound(
                     user_utt=user_turn["utterance"],
                     api_call_machine=api_call,
                     api_resp_machine=api_results,
                     sys_utt=sys_turn["utterance"],
                 ))
         episode = tod.TodStructuredEpisode(
             domain=SerializationHelpers.inner_list_join(domains),
             api_schemas_machine=self._get_intent_groundinging(
                 schema_lookup, set(dialogue["services"])),
             goal_calls_machine=self._get_all_service_calls(turns),
             rounds=rounds,
             delex=self.opt.get("delex", False),
             extras={"dialogue_id": dialogue["dialogue_id"]},
         )
         result.append(episode)
     return result
Beispiel #3
0
 def _get_turns_from_parsed(self, user_utt, api_calls, api_resps, sys_utt):
     """
     Wrap one parsed exchange in a single-element list of rounds.

     :return: ``[tod.TodStructuredRound]`` holding the given user utterance,
         API call, API response and system utterance unchanged.
     """
     only_round = tod.TodStructuredRound(
         user_utt=user_utt,
         api_call_machine=api_calls,
         api_resp_machine=api_resps,
         sys_utt=sys_utt,
     )
     return [only_round]
Beispiel #4
0
def get_rounds(episode_idx, max_rounds, use_broken_mock_api_calls=False):
    """
    Build ``max_rounds`` synthetic rounds for episode ``episode_idx``.

    Utterances are placeholder strings tagged with the episode and round
    indices; API call/response dicts come from ``make_api_call_machine`` and
    ``make_api_resp_machine``.
    """
    built = []
    for round_idx in range(max_rounds):
        call = make_api_call_machine(round_idx, episode_idx,
                                     use_broken_mock_api_calls)
        resp = make_api_resp_machine(round_idx)
        built.append(
            tod_core.TodStructuredRound(
                user_utt=f"user_utt_{episode_idx}_{round_idx}",
                api_call_machine=call,
                api_resp_machine=resp,
                sys_utt=f"sys_utt_{episode_idx}_{round_idx}",
            ))
    return built
Beispiel #5
0
    def _process_line(self, line):
        """
        Parse one JSON line into a TodStructuredEpisode.

        ``blob["dialog"][0]`` is the "preempt" round: entry 0 carries the API
        schemas (and an optional ``"domain"``) and entry 3 the goal calls, each
        after a ``PREFIXES_PREEMPT`` prefix. Every later element is a round
        whose four entries are the user utterance, API call, API response and
        system utterance, each after the matching ``PREFIXES`` prefix. A round
        of any other length ends the dialogue if its first entry is
        ``tod.STANDARD_DONE`` and otherwise rejects the line.

        Prefix-stripped text is cached on each entry dict under
        ``"prefix_stripped_text"``, i.e. the parsed input is modified in place.

        :param line: one line of JSON text.
        :return: the parsed episode, or ``None`` for a malformed dialogue.
        :raises RuntimeError: if ``self.opt["fail_hard"]`` is set and an entry
            does not start with its expected prefix, or if a non-empty API call
            lacks ``tod.STANDARD_API_NAME_SLOT``.
        """
        blob = json.loads(line)
        if "dialog" not in blob or len(blob["dialog"]) < 1:
            return None
        rounds = []
        for raw_round in blob["dialog"][1:]:
            if len(raw_round) < 1:
                return None  # empty round: misformatted convo, don't learn this.
            if "prefix_stripped_text" not in raw_round[0]:
                # zip() stops at the shorter sequence, so a round with more
                # entries than PREFIXES no longer raises IndexError here; it is
                # handled by the length check below.
                for turn, prefix in zip(raw_round, PREFIXES):
                    # A missing "text" is treated as a bare prefix (-> "").
                    text = turn.get("text", prefix)
                    # The prefix is removed by slicing its length off the front,
                    # so it must be a true prefix, not merely a substring.
                    if not text.startswith(prefix) and self.opt["fail_hard"]:
                        raise RuntimeError(
                            f"Missing prefix `{prefix}` before turn of text: `{turn}`"
                        )
                    turn["prefix_stripped_text"] = text[len(prefix):]
            if len(raw_round) != 4:
                if raw_round[0]["prefix_stripped_text"] != tod.STANDARD_DONE:
                    return None  # misformatted convo, don't learn this.
                break  # TodStructuredEpisode will add in [DONE]
            api_call_machine = tod.SerializationHelpers.str_to_api_dict(
                raw_round[1]["prefix_stripped_text"])
            if (len(api_call_machine) > 0
                    and tod.STANDARD_API_NAME_SLOT not in api_call_machine):
                raise RuntimeError(
                    f"Trying to make API call without `{tod.STANDARD_API_NAME_SLOT}`. Call is: `{raw_round[1].get('text')}`"
                )
            rounds.append(
                tod.TodStructuredRound(
                    user_utt=raw_round[0]["prefix_stripped_text"],
                    api_call_machine=api_call_machine,
                    api_resp_machine=tod.SerializationHelpers.str_to_api_dict(
                        raw_round[2]["prefix_stripped_text"]),
                    sys_utt=raw_round[3]["prefix_stripped_text"],
                ))
        preempt_round = blob["dialog"][0]
        if len(preempt_round) != 4:
            return None
        for turn, prefix in zip(preempt_round, PREFIXES_PREEMPT):
            if "prefix_stripped_text" not in turn:
                turn["prefix_stripped_text"] = turn.get("text",
                                                        prefix)[len(prefix):]

        episode = tod.TodStructuredEpisode(
            domain=preempt_round[0].get("domain", ""),
            api_schemas_machine=tod.SerializationHelpers.str_to_api_schemas(
                preempt_round[0].get("prefix_stripped_text", "")),
            goal_calls_machine=tod.SerializationHelpers.str_to_goals(
                preempt_round[3].get("prefix_stripped_text")),
            rounds=rounds,
        )
        return episode
Beispiel #6
0
    def _get_turns_from_parsed(self, user_utt, api_calls, api_resps, sys_utt):
        """
        Spread one user/system exchange over one round per API call.

        The user utterance goes on the first round and the system utterance on
        the last; the remaining slots are filled with ``SILENCE_TOKEN``. With no
        API calls, a single round with empty call/response dicts is produced.

        :param api_calls: list of API call dicts, paired index-wise with
            ``api_resps``.
        :param api_resps: list of API response dicts, same length as
            ``api_calls``.
        :return: list of ``tod.TodStructuredRound``.
        :raises ValueError: if ``api_calls`` and ``api_resps`` differ in length.
        """
        # Explicit check instead of ``assert`` so it still runs under
        # ``python -O``; otherwise a short ``api_resps`` would surface as an
        # IndexError and a long one would be silently truncated.
        if len(api_calls) != len(api_resps):
            raise ValueError(
                f"api_calls ({len(api_calls)}) and api_resps "
                f"({len(api_resps)}) must have the same length")
        if len(api_calls) == 0:
            api_calls = [{}]
            api_resps = [{}]
        turns = len(api_calls)
        user_utts = [SILENCE_TOKEN] * turns
        user_utts[0] = user_utt
        sys_utts = [SILENCE_TOKEN] * turns
        sys_utts[-1] = sys_utt

        return [
            tod.TodStructuredRound(
                sys_utt=sys_u,
                api_call_machine=call,
                api_resp_machine=resp,
                user_utt=user_u,
            )
            for user_u, call, resp, sys_u in zip(user_utts, api_calls,
                                                 api_resps, sys_utts)
        ]
Beispiel #7
0
    def _get_round(self, dialogue_id, raw_episode, turn_id):
        """
        Parse to TodStructuredRound.

        Assume User turn first.

        Reads the user turn at ``raw_episode[turn_id]`` and the system turn at
        ``raw_episode[turn_id + 1]``, and derives at most one API call/response
        pair for the round. The pair comes from the user turn's frames and the
        system turn's dialog acts (``self.dialog_acts[dialogue_id][str(turn_id + 1)]``).

        A frame with an active intent contributes a call only if the system
        turn acknowledges it:
        - for intents containing "find": a "NoOffer" act, or an "Inform" act
          whose slots all pass ``self._slot_in_schema`` (or are "choice");
        - for intents containing "book": any act whose name contains "Book".
        When several frames/acts match, the last match wins.

        A call equal to ``self.last_call`` is dropped together with its response.
        Any remaining non-empty call becomes the new ``self.last_call``.

        :param dialogue_id: key into ``self.dialog_acts``; also used in errors.
        :param raw_episode: sequence of turn dicts with "speaker", "utterance",
            and optionally "frames".
        :param turn_id: index of the user turn in ``raw_episode``.
        :return: tuple ``(call, round)``. ``call`` is the (possibly empty) API
            call dict; ``round`` is the built ``tod.TodStructuredRound``.
        :raises RuntimeError: if the turn at ``turn_id`` is not a USER turn or
            the following turn is not a SYSTEM turn.
        """
        user_turn = raw_episode[turn_id]
        if user_turn["speaker"] != "USER":
            raise RuntimeError(
                f"Got non-user turn when it should have been in {dialogue_id}; turn id {turn_id}"
            )
        sys_turn = raw_episode[turn_id + 1]
        sys_dialog_act = self.dialog_acts[dialogue_id][str(turn_id + 1)]["dialog_act"]
        if sys_turn["speaker"] != "SYSTEM":
            raise RuntimeError(
                f"Got non-system turn when it should have been in {dialogue_id}; turn id {turn_id}"
            )
        frames = user_turn.get("frames", [])
        call = {}
        resp = {}
        for frame in frames:
            # Frames without an active intent cannot produce an API call.
            if frame.get("state", {}).get("active_intent", "NONE") != "NONE":
                intent = frame["state"]["active_intent"]
                domain = frame["service"]
                maybe_call_raw = copy.deepcopy(frame["state"]["slot_values"])
                maybe_call = {}
                # Slot keys appear to carry a "<service><sep>" prefix of
                # len(domain) + 1 chars, which is cut off here (TODO confirm the
                # separator).
                truncate_length = len(domain) + 1
                for key in maybe_call_raw:
                    # Slot values are indexed with [0]; only the first value is kept.
                    maybe_call[key[truncate_length:]] = maybe_call_raw[key][0]
                maybe_call[tod.STANDARD_API_NAME_SLOT] = intent
                if "find" in intent:
                    for key in sys_dialog_act:
                        if "Inform" in key or "NoOffer" in key:
                            # For Inform acts, only accept the call if every
                            # informed slot belongs to this intent's schema (or
                            # is the generic "choice" slot); NoOffer is accepted
                            # as-is. `&=` / `|` here are non-short-circuit bool ops.
                            if "Inform" in key:
                                valid = True
                                slots = [x[0] for x in sys_dialog_act[key]]
                                for slot in slots:
                                    valid &= self._slot_in_schema(slot, intent) | (
                                        slot == "choice"
                                    )
                                if not valid:
                                    continue
                            call = maybe_call
                            resp = self._get_find_api_response(
                                intent, frame["state"]["slot_values"], sys_dialog_act
                            )
                elif "book" in intent:
                    for key in sys_dialog_act:
                        # Matches any act name containing "Book"; a stricter
                        # `and "Inform" not in key` filter is left disabled.
                        if "Book" in key:
                            # Each act entry is indexed as a (slot, value) pair.
                            resp = {x[0]: x[1] for x in sys_dialog_act[key]}
                            call = maybe_call
        # Drop a call identical to the previous round's so it is not repeated.
        if call == self.last_call:
            call = {}
            resp = {}
        if len(call) > 0:
            self.last_call = call
        return (
            call,
            tod.TodStructuredRound(
                user_utt=user_turn["utterance"],
                api_call_machine=call,
                api_resp_machine=resp,
                sys_utt=sys_turn["utterance"],
            ),
        )
Beispiel #8
0
    def setup_episodes(self, fold):
        """
        Parses into TodStructuredEpisode.

        Loads ``fold`` for the domains in ``self.opt["msre2e_domains"]``
        (default ``DOMAINS``). Each non-empty chunk of utterances becomes one
        ``tod.TodStructuredEpisode``. The episode domain is read from the
        chunk's first utterance, and its API schemas come from
        ``SLOT_NAMES[domain]``.

        When ``self.opt["use_cumulative_api_calls"]`` is set, each round's API
        call is the union of all user API calls so far (later values win). It
        is attached only to rounds that have a non-empty API response.

        :param fold: data split passed through to ``self._load_data``.
        :return: list of ``tod.TodStructuredEpisode``.
        """
        domains = self.opt.get("msre2e_domains", DOMAINS)
        chunks = self._load_data(fold, domains)
        episodes = []
        for utterances in chunks:
            if len(utterances) < 1:
                continue
            domain = utterances[0]["domain"]
            idx = 0
            rounds = []
            goal_calls = []
            # Conversation opens with the agent: emit a round with a silent
            # user turn (``utterances`` is known to be non-empty here).
            if utterances[0]["speaker"] == "agent":
                idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker(
                    "agent", utterances, idx)
                r = tod.TodStructuredRound(
                    user_utt=tod.CONST_SILENCE,
                    api_resp_machine=api_resp,
                    sys_utt=sys_utt,
                )
                rounds.append(r)

            cum_api_call = {}
            while idx < len(utterances):
                idx, user_utt, api_call = self._get_utterance_and_api_call_for_speaker(
                    "user", utterances, idx)
                idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker(
                    "agent", utterances, idx)
                if not self.opt["use_cumulative_api_calls"]:
                    r = tod.TodStructuredRound(
                        user_utt=user_utt,
                        api_call_machine=api_call,
                        api_resp_machine=api_resp,
                        sys_utt=sys_utt,
                    )
                else:
                    cum_api_call.update(api_call)
                    # Snapshot: cum_api_call keeps growing on later turns.
                    r = tod.TodStructuredRound(
                        user_utt=user_utt,
                        api_call_machine=copy.deepcopy(cum_api_call)
                        if len(api_resp) > 0 else {},
                        api_resp_machine=api_resp if len(api_resp) > 0 else {},
                        sys_utt=sys_utt,
                    )

                rounds.append(r)
                if len(api_call) > 0:
                    goal_calls.append(api_call)

            episode = tod.TodStructuredEpisode(
                domain=domain,
                api_schemas_machine=SLOT_NAMES[domain],
                goal_calls_machine=goal_calls,
                rounds=rounds,
                delex=self.opt.get("delex", False),
            )
            episodes.append(episode)
        return episodes
Beispiel #9
0
    def setup_episodes(self, fold):
        """
        Build TodStructuredEpisodes from MultiDoGo conversations.

        Consecutive turns by the same role are merged into one side of a round:
        - customer turns feed the user utterance and the API call slots;
        - agent turns feed the system utterance and the API response slots.
        A new round starts whenever the role switches away from "agent". Every
        round's API call carries the conversation's domain as the API name.
        Conversations that are empty, or that have a turn without a "role",
        are skipped.

        :param fold: NOTE(review): not used in this body; presumably the split
            is chosen inside ``_iterate_over_conversations`` — confirm.
        :return: list of ``tod.TodStructuredEpisode``, one per kept conversation.
        """
        result = []
        domains = self.opt.get("multidogo_domains", DOMAINS)
        if isinstance(domains, str):
            domains = [domains]
        intent_type = self.opt.get("intent-type", TURN_INTENT)
        for _conv_id, domain, conversation in self._iterate_over_conversations(
                domains, intent_type):
            if len(conversation) == 0 or not all(
                    "role" in turn for turn in conversation.values()):
                continue
            rounds = []
            first_turn = conversation["0"]
            prev_role = first_turn["role"]
            # Copy the opening slots: they are merged into with ``update`` below,
            # which would otherwise modify the loaded conversation data in place.
            first_slots = dict(first_turn.get("slots", {}))
            if prev_role == "customer":
                user_utt = [first_turn["text"]]
                all_calls = first_slots
                api_resp = {}
                sys_utt = []
            else:
                user_utt = ["__SILENCE__"]
                all_calls = {}
                api_resp = first_slots
                sys_utt = [first_turn["text"]]
            # NOTE(review): an opening customer turn's slots reach ``all_calls``
            # (the goal) but not the first round's API call, which starts with
            # only the API name — confirm this is intended.
            api_call = {tod.STANDARD_API_NAME_SLOT: domain}
            for i in range(1, len(conversation)):
                turn = conversation[str(i)]
                # Leaving an agent stretch closes the current round.
                if prev_role == "agent" and turn["role"] != prev_role:
                    rounds.append(
                        tod.TodStructuredRound(
                            user_utt="\n".join(user_utt),
                            api_call_machine=api_call,
                            api_resp_machine=api_resp,
                            sys_utt="\n".join(sys_utt),
                        ))
                    user_utt = []
                    api_call = {tod.STANDARD_API_NAME_SLOT: domain}
                    api_resp = {}
                    sys_utt = []
                prev_role = turn["role"]
                slot = turn.get("slots", {})
                if prev_role == "customer":
                    user_utt.append(turn["text"])
                    api_call.update(slot)
                    all_calls.update(slot)
                else:
                    api_resp.update(slot)
                    sys_utt.append(turn["text"])

            # Close the trailing round, joined the same way as the rounds above
            # (it previously used "", gluing multiple utterances together).
            rounds.append(
                tod.TodStructuredRound(
                    user_utt="\n".join(user_utt),
                    api_call_machine=api_call,
                    api_resp_machine=api_resp,
                    sys_utt="\n".join(sys_utt),
                ))
            goal_calls = copy.deepcopy(all_calls)
            goal_calls[tod.STANDARD_API_NAME_SLOT] = domain
            result.append(
                tod.TodStructuredEpisode(
                    domain=domain,
                    api_schemas_machine=[{
                        tod.STANDARD_API_NAME_SLOT:
                        domain,
                        # NOTE(review): a dict_keys view rather than a list;
                        # confirm the schema serializer expects this type.
                        tod.STANDARD_OPTIONAL_KEY:
                        all_calls.keys(),
                    }],
                    goal_calls_machine=[goal_calls],
                    rounds=rounds,
                ))
        return result