示例#1
0
    def load_data(self, datapath):
        """
        Load the dialogues of every selected domain, joined with their tasks.

        The directory containing ``datapath`` must hold a JSON-lines
        ``tasks.txt`` file and a ``dialogues`` folder with one JSON-lines file
        per domain. Only the directory part of ``datapath`` is used; the fold
        is read from ``self.fold``.

        :param datapath:
            path whose parent directory contains ``tasks.txt`` and
            ``dialogues``.

        :return:
            if ``"test"`` is in ``self.fold``, a flat list with the record
            dicts of all loaded domains; otherwise whatever
            ``DatatypeHelper.split_subset_data_by_fold`` returns for the list
            of per-domain record lists.

        :raises AssertionError:
            if ``self.opt["metalwoz_domains"]`` is non-empty and the number of
            domain files loaded differs from the number of requested domains.
        """
        folder = os.path.dirname(datapath)
        with PathManager.open(os.path.join(folder, "tasks.txt")) as taskf:
            tasks_table = pd.read_json(taskf, lines=True)

        dfolder = os.path.join(folder, "dialogues")
        wanted_domains = self.opt["metalwoz_domains"]

        data = []  # one list of record dicts per loaded domain
        loaded_domains = []  # domain names, kept only for the error message

        for filename in PathManager.ls(dfolder):
            domain = filename.replace(".txt", "")
            # An empty/unset domain option means "load every domain".
            if wanted_domains and domain not in wanted_domains:
                continue
            with PathManager.open(os.path.join(dfolder, filename)) as dataf:
                dialogues = pd.read_json(dataf, lines=True)
            # The file is fully parsed, so the merge runs after it is closed.
            dialogues = dialogues.merge(tasks_table, on="task_id")
            data.append(dialogues.to_dict("records"))
            loaded_domains.append(domain)

        # Quick check to make sure we didn't fat-finger the spelling of some
        # domain. Raised explicitly (rather than via ``assert``) so the check
        # still runs when Python is started with -O.
        if wanted_domains and len(data) != len(wanted_domains):
            raise AssertionError(
                "metalwoz_domains requested {!r} but only these domain files "
                "were found: {!r}".format(wanted_domains, loaded_domains)
            )

        if "test" in self.fold:
            # Test folds are not split: return every record of every domain.
            return [record for records in data for record in records]

        # NOTE(review): 0.8 / 0.1 / 0.1 are presumably the train / valid /
        # test fractions -- confirm against DatatypeHelper.
        return DatatypeHelper.split_subset_data_by_fold(
            self.fold, data, 0.8, 0.1, 0.1)
示例#2
0
 def load_chunks(self, fold):
     """
     Yield every conversation stored in the JSON files of one fold.

     Each entry listed under ``<self.dpath>/<fold>`` is parsed with
     ``json.load`` and its elements are yielded one at a time.

     :param fold:
         name of the fold directory to read; ``"valid"`` is read from the
         ``dev`` directory.

     :return:
         a generator over the elements of each parsed file, in the order
         ``PathManager.ls`` lists the files.
     """
     if fold == "valid":
         fold = "dev"  # change name to match file structure
     fold_dir = os.path.join(self.dpath, fold)
     for path in PathManager.ls(fold_dir):
         with PathManager.open(os.path.join(fold_dir, path)) as f:
             blob = json.load(f)
         # The file is closed before yielding, so no handle stays open while
         # the consumer works on (or abandons) the generator.
         yield from blob
示例#3
0
    def setup_data(self, datapath):
        """
        Yield ``(message, new_episode)`` pairs built from the dialogue files.

        The directory part of ``datapath`` must contain a JSON-lines
        ``tasks.txt`` and a ``dialogues`` folder of JSON-lines files; the last
        path component is the fold name. All dialogues are concatenated,
        shuffled with a fixed seed, merged with the tasks on ``task_id`` and
        tagged with ``self._hash(domain_x)`` in a ``fold`` column.

        Fold filtering: ``valid`` keeps only rows whose ``fold`` is 9,
        ``train`` keeps only rows whose ``fold`` is not 9, and any other fold
        name keeps every row.

        For each kept dialogue the sequence ``bot_role`` + ``turns`` is split
        into alternating prompts (even positions) and labels (odd positions).

        :param datapath:
            ``<folder>/<fold>`` path as described above.

        :return:
            a generator of ``(dict, bool)`` tuples; the bool is True for the
            first prompt/label pair of a dialogue.
        """
        folder, fold = os.path.split(datapath)
        dialogues_dir = os.path.join(folder, 'dialogues')

        with PathManager.open(os.path.join(folder, 'tasks.txt')) as tasks_file:
            tasks = pd.read_json(tasks_file, lines=True)

        frames = []
        for name in PathManager.ls(dialogues_dir):
            with PathManager.open(os.path.join(dialogues_dir, name)) as handle:
                frames.append(pd.read_json(handle, lines=True))

        # Fixed seed (83741 spells "metal" in leetspeak) keeps the shuffle
        # reproducible from run to run.
        table = pd.concat(frames, axis=0).sample(frac=1.0, random_state=83741)
        table = table.merge(tasks, on='task_id')
        # NOTE(review): _hash presumably maps a domain name to a small bucket
        # id, with bucket 9 held out for validation -- confirm.
        table['fold'] = table['domain_x'].apply(self._hash)

        for _, dialogue in table.iterrows():
            bucket = dialogue['fold']
            if fold == 'valid':
                if bucket != 9:
                    continue
            elif fold == 'train':
                if bucket == 9:
                    continue

            # The bot role description acts as the very first "prompt".
            utterances = [dialogue['bot_role'], *dialogue['turns']]
            exchanges = zip(utterances[::2], utterances[1::2])
            for turn_no, (prompt, label) in enumerate(exchanges):
                message = {
                    'text': prompt,
                    'label': label,
                    'bot_role': dialogue['bot_role'],
                    'bot_prompt': dialogue['bot_prompt'],
                    'user_role': dialogue['user_role'],
                    'user_prompt': dialogue['user_prompt'],
                    'utterance_id': dialogue['id'],
                    'domain': dialogue['domain_x'],
                    'task_id': dialogue['task_id'],
                }
                starts_episode = turn_no == 0
                yield message, starts_episode