コード例 #1
0
def get_reference_checksum(src_url, timeout=60):
    """Return the published checksum for the file at ``src_url``.

    The filename is the last ``/``-separated component of ``src_url``.
    Entries in ``RedditConfig().missing_checksums`` take precedence. Otherwise
    the submissions listing (for names starting with ``RS``) or the comments
    listing is downloaded and scanned for a line whose file-name column equals
    the filename exactly.

    Args:
        src_url: URL of the raw dump file whose checksum is needed.
        timeout: seconds passed to ``requests.get`` so a stalled server cannot
            hang the task indefinitely.

    Returns:
        The first whitespace-separated field of the matching listing line.

    Raises:
        RuntimeError: if the listing cannot be fetched (non-200 status) or it
            has no entry for the filename.
    """
    cfg = RedditConfig()
    filename = src_url.rsplit('/', 1)[-1]

    if filename in cfg.missing_checksums:
        return cfg.missing_checksums[filename]

    checksums_url = (cfg.submissions_checksum_url_template
                     if filename.startswith('RS')
                     else cfg.comments_checksum_url_template)

    r = requests.get(checksums_url, timeout=timeout)
    if r.status_code != 200:
        raise RuntimeError(
            f"Couldn't get checksums from {checksums_url}, status={r.status_code}"
        )

    for line in r.content.decode('utf-8').splitlines():
        checksum, *names = line.split() or ['']
        # Compare the name column exactly instead of a substring test, so a
        # line for a different file whose name merely contains ``filename``
        # cannot match. A leading '*' is sha256sum's binary-mode marker, and
        # any directory prefix on the listed name is ignored.
        if any(name.lstrip('*').rsplit('/', 1)[-1] == filename
               for name in names):
            return checksum

    raise RuntimeError(f"Couldn't get checksum for {filename}")
コード例 #2
0
 def output(self):
     """Return this date's two submission targets.

     Index 0 is the filtered-submissions file and index 1 is the ids file.
     Both paths come from the ``RedditConfig()`` path helpers.
     """
     cfg = RedditConfig()
     filtered_path = cfg.make_filtered_filepath(self.date, 'submissions')
     ids_path = cfg.make_ids_filepath(self.date, 'submissions')
     return [luigi.LocalTarget(filtered_path), luigi.LocalTarget(ids_path)]
コード例 #3
0
    def run(self):
        """Download the raw dump for ``self.date``/``self.filetype`` and verify it.

        The response is streamed into luigi's temporary path while a SHA-256
        digest is computed over the same chunks. The digest is then compared
        with the reference from ``get_reference_checksum``. A mismatch raises
        RuntimeError inside the ``temporary_path`` block, so luigi does not
        move the file into place.

        Raises:
            RuntimeError: on a non-200 download status or a checksum mismatch.
        """
        cfg = RedditConfig()
        prefix = 'RC' if self.filetype == 'comments' else 'RS'

        with self.output().temporary_path() as tmp_path:
            src_url = cfg.make_source_url(self.date, self.filetype)
            ref_checksum = get_reference_checksum(src_url)

            # Both context managers release their resources (HTTP connection,
            # file handle) even if a chunk read or write fails midway. The
            # timeout bounds connecting and each wait for data, not the total
            # download time.
            with requests.get(src_url, stream=True, timeout=60) as r:
                if r.status_code != 200:
                    raise RuntimeError(
                        f"Error downloading {src_url}, status={r.status_code}")

                m = sha256()
                with open(tmp_path, 'wb') as f:
                    for chunk in r.iter_content(
                            chunk_size=cfg.download_chunk_size):
                        if chunk:  # skip empty chunks
                            f.write(chunk)
                            m.update(chunk)

            if m.hexdigest() != ref_checksum:
                raise RuntimeError(
                    f"Checksums don't match for {prefix}_{self.date}!"
                )
コード例 #4
0
    def run(self):
        """Package the dialogue files and task lists into the output zip.

        Each non-empty input file is stored as
        ``dialogues/<parent dir>/<file name>``, using its last two path
        components. Three task-list files are then added. Each holds one JSON
        object per line, one per subreddit in sorted order:
        ``tasks.txt`` (all subreddits), ``tasks_train.txt`` (all minus held-out)
        and ``tasks_held_out.txt`` (held-out only).
        """
        cfg = RedditConfig()

        def make_json_for_subreddit(subreddit):
            return json.dumps({
                'domain': subreddit,
                'task_id': subreddit,
                'bot_prompt': '',
                'user_prompt': '',
                'bot_role': '',
                'user_role': '',
            })

        # Convert explicitly to sets so the difference also works when the
        # config exposes these collections as lists.
        tasks_to_write = [
            ('tasks.txt', cfg.all_subreddits),
            ('tasks_train.txt',
             set(cfg.all_subreddits) - set(cfg.held_out_subreddits)),
            ('tasks_held_out.txt', cfg.held_out_subreddits),
        ]

        with self.output().temporary_path() as zip_path:
            # The ZipFile context manager writes the central directory and
            # closes the file even if one of the writes below fails.
            with ZipFile(zip_path, 'w', compression=ZIP_DEFLATED) as archive:
                for src in self.input():
                    if os.stat(src.path).st_size == 0:
                        continue  # skip empty (0-byte) inputs
                    dest = os.path.join('dialogues',
                                        *src.path.split(os.sep)[-2:])
                    archive.write(src.path, arcname=dest)

                for fp, tasks in tasks_to_write:
                    with TextIOWrapper(archive.open(fp, 'w'),
                                       encoding='utf-8') as f:
                        f.write('\n'.join(
                            [make_json_for_subreddit(t)
                             for t in sorted(tasks)]) + '\n')
コード例 #5
0
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        cfg = RedditConfig()

        # Seed the ``random`` module from the configured seed so that later
        # random draws are reproducible across runs.
        random.seed(cfg.random_seed)
        self._date_in_train = cfg.is_date_in_train(self.date)

        # Every (Subset member, subreddit) pair, Subset-major, in the same
        # order itertools.product would produce.
        subreddits = list(cfg.all_subreddits)
        self._subset_subreddit_combos = [
            (subset, subreddit)
            for subset in Subset
            for subreddit in subreddits
        ]
コード例 #6
0
 def run(self):
     """Filter this date's raw submissions file.

     Delegates to ``process_file_linewise``. Filtered JSON records go to the
     first output's temporary path and ids go to the second's.
     """
     cfg = RedditConfig()
     filtered_target, ids_target = self.output()
     with filtered_target.temporary_path() as tmp_path, \
             ids_target.temporary_path() as tmp_ids_path:
         process_file_linewise(
             in_filepath=cfg.make_raw_filepath(self.date, 'submissions'),
             out_filepath=tmp_path,
             out_ids_filepath=tmp_ids_path,
             parser=json.loads,
             filterer=RawSubmissionFilterer(),
             outputter=SubmissionJsonOutputter(),
             buffer_size=cfg.dump_interval,
         )
コード例 #7
0
 def run(self):
     """Concatenate this split/subreddit's per-date files into one text file.

     For every date from ``make_all_dates()``, in that order, the gzipped
     per-date file is decompressed as UTF-8 text and appended to the plain
     (uncompressed) UTF-8 output.
     """
     cfg = RedditConfig()
     sources = [
         cfg.make_split_date_domain_path(d, self.split, self.subreddit)
         for d in cfg.make_all_dates()
     ]
     # The output file is closed before temporary_path exits (contexts unwind
     # in reverse order), and it is also closed if a source fails to read.
     with self.output().temporary_path() as outpath, \
             open(outpath, 'wt', encoding='utf-8') as outfile:
         for src in sources:
             with gzip.open(src, 'rt', encoding='utf-8') as infile:
                 shutil.copyfileobj(infile, outfile, cfg.cat_buffer_size)
コード例 #8
0
    def run(self):
        """Route each sampled dialogue for ``self.date`` into its split file.

        The split depends on the dialogue's ``domain`` and on whether this date
        is a training date (``self._date_in_train``):

        * held-out domain, training date     -> VALIDATION_DATE_IN_DOMAIN_OUT
        * held-out domain, other date        -> VALIDATION_DATE_OUT_DOMAIN_OUT
        * other known domain, other date     -> VALIDATION_DATE_OUT_DOMAIN_IN
        * other known domain, training date  -> TRAINING with probability
          ``train_sampling_pr``, else VALIDATION_DATE_IN_DOMAIN_IN
        * domain not in ``all_subreddits``   -> dropped

        Raw lines are buffered per (subset, domain) and appended to that gzip
        output every ``dump_interval`` lines. Leftover lines are flushed at the
        end.
        """
        cfg = RedditConfig()
        held_out = cfg.held_out_subreddits
        known = cfg.all_subreddits
        dump_interval = cfg.dump_interval

        # One temporary path per output; ExitStack exits them all together.
        with ExitStack() as stack:
            outfiles = {(s, d): stack.enter_context(t.temporary_path())
                        for s, d, t in self._make_outputs()}

            buffers = defaultdict(list)
            with gzip.open(cfg.make_sampled_dialogues_filepath(self.date),
                           'rt',
                           encoding='utf-8') as infile:
                for line in infile:
                    try:
                        dlg = json.loads(line)
                    except json.JSONDecodeError as e:
                        logging.debug(
                            f"[split] Error parsing line ({str(e)})\n - {line}")
                        continue

                    domain = dlg['domain']
                    if domain in held_out:
                        subset = (Subset.VALIDATION_DATE_IN_DOMAIN_OUT
                                  if self._date_in_train
                                  else Subset.VALIDATION_DATE_OUT_DOMAIN_OUT)
                    elif domain not in known:
                        continue
                    elif not self._date_in_train:
                        subset = Subset.VALIDATION_DATE_OUT_DOMAIN_IN
                    # A random draw happens only for known, non-held-out
                    # domains on training dates, so the draw sequence is
                    # unchanged. random.random() exists in both the stdlib
                    # and numpy.random; random.rand() is numpy-only.
                    elif random.random() < cfg.train_sampling_pr:
                        subset = Subset.TRAINING
                    else:
                        subset = Subset.VALIDATION_DATE_IN_DOMAIN_IN

                    key = (subset, domain)
                    # Store the raw lines to avoid re-encoding the JSON.
                    buffers[key].append(line)

                    if len(buffers[key]) >= dump_interval:
                        with gzip.open(outfiles[key], 'at',
                                       encoding='utf-8') as outfile:
                            outfile.write(''.join(buffers[key]))
                        buffers[key] = []

            # Open every output, even ones that received no lines, so each
            # temporary file exists; luigi's temp path complains otherwise.
            for key, outfile_path in outfiles.items():
                with gzip.open(outfile_path, 'at',
                               encoding='utf-8') as outfile:
                    if buffers.get(key):
                        outfile.write(''.join(buffers[key]))
コード例 #9
0
    def __init__(self, submission_ids_file=None):
        """Capture the configured subreddits and, optionally, submission ids.

        Args:
            submission_ids_file: optional path to a UTF-8 text file with one id
                per line. Each stripped line is stored with
                ``SUBMISSION_ID_PREFIX`` prepended. When falsy, the id set is
                empty.
        """
        self.subreddits = set(RedditConfig().all_subreddits)

        if not submission_ids_file:
            self.submission_ids = set()
            return

        with open(submission_ids_file, 'r', encoding='utf-8') as id_file:
            self.submission_ids = {
                SUBMISSION_ID_PREFIX + raw_id.strip() for raw_id in id_file
            }
コード例 #10
0
    def on_success(self):
        """Clean up intermediate data after this task succeeds.

        Does nothing unless ``RedditConfig().delete_intermediate_data`` is set.
        In that case it passes this task's requirements to
        ``delete_requires``, then does the same for each requirement's own
        requirements.

        The grandparent cleanup used to run even when the flag was off. It is
        now gated like the direct cleanup (and like the sibling on_success),
        so setting the flag to False keeps all intermediate data.
        """
        if not RedditConfig().delete_intermediate_data:
            return

        parent_reqs = self.requires()
        delete_requires(parent_reqs)

        # requires() may return a single task or a list of tasks.
        if not isinstance(parent_reqs, list):
            parent_reqs = [parent_reqs]

        for req in parent_reqs:
            delete_requires(req.requires())
コード例 #11
0
def download(workers, config, log_level):
    """Run a DownloadRawFile task for every configured (date, filetype).

    Initializes and prints RedditConfig and copies ``max_concurrent_downloads``
    into luigi's ``resources`` config. It then builds the tasks with a local
    scheduler and prints luigi's summary.
    """
    RedditConfig.initialize(config)
    cfg = RedditConfig()
    print(cfg)

    luigi_cfg = luigi.configuration.get_config()
    luigi_cfg.set('resources', 'max_concurrent_downloads',
                  str(cfg.max_concurrent_downloads))

    tasks = [DownloadRawFile(date, filetype)
             for date, filetype in cfg.make_all_dates_filetypes()]
    result = luigi.interface.build(
        tasks,
        workers=workers,
        local_scheduler=True,
        log_level=log_level,
        detailed_summary=True,
    )
    print(result.summary_text)
コード例 #12
0
def generate(workers, config, log_level):
    """Build the ZipDataset task with a local luigi scheduler.

    Initializes and prints RedditConfig and copies its concurrency limits into
    luigi's ``resources`` config. It then builds ``ZipDataset()`` and prints
    luigi's summary.
    """
    RedditConfig.initialize(config)
    cfg = RedditConfig()
    print(cfg)

    # Each luigi resource has the same name as the RedditConfig attribute
    # holding its limit.
    luigi_cfg = luigi.configuration.get_config()
    for resource_name in ('max_concurrent_downloads',
                          'max_concurrent_build',
                          'max_concurrent_sample'):
        luigi_cfg.set('resources', resource_name,
                      str(getattr(cfg, resource_name)))

    result = luigi.interface.build(
        [ZipDataset()],
        workers=workers,
        local_scheduler=True,
        log_level=log_level,
        detailed_summary=True,
    )
    print(result.summary_text)
コード例 #13
0
 def output(self):
     """Return the target for this date's sampled-dialogues file."""
     sampled_path = RedditConfig().make_sampled_dialogues_filepath(self.date)
     return luigi.LocalTarget(sampled_path)
コード例 #14
0
 def turns_all_have_chars(cls, v):
     """Validate a sequence of turns and return it unchanged.

     Raises:
         ValueError: if there are fewer than
             ``RedditConfig().min_dialogue_length`` turns, or if any turn is
             empty or whitespace-only.
     """
     if len(v) < RedditConfig().min_dialogue_length:
         raise ValueError('Not enough turns!')
     # Strip every turn before deciding, as the original list-based check did.
     blank_turns = [turn for turn in v if not turn.strip()]
     if blank_turns:
         raise ValueError('Zero-length strings in turns!')
     return v
コード例 #15
0
 def __init__(self):
     # Snapshot the configured subreddits as a set.
     configured_subreddits = RedditConfig().all_subreddits
     self.subreddits = set(configured_subreddits)
コード例 #16
0
 def _make_outputs(self):
     """Return ``(subset, subreddit, LocalTarget)`` for every combination.

     Some of these targets may end up as 0-byte files; that case is handled
     later.
     """
     cfg = RedditConfig()
     outputs = []
     for subset, subreddit in self._subset_subreddit_combos:
         path = cfg.make_split_date_domain_path(self.date, subset, subreddit)
         outputs.append((subset, subreddit, luigi.LocalTarget(path)))
     return outputs
コード例 #17
0
 def output(self):
     """Return the target for this date's filtered comments file."""
     comments_path = RedditConfig().make_filtered_filepath(self.date, 'comments')
     return luigi.LocalTarget(comments_path)
コード例 #18
0
 def requires(self):
     """Require one MergeDialoguesOverDates per (Subset, subreddit) pair."""
     subreddits = list(RedditConfig().all_subreddits)
     return [MergeDialoguesOverDates(subset, subreddit)
             for subset in Subset
             for subreddit in subreddits]
コード例 #19
0
 def output(self):
     """Return the target for the dataset zip archive."""
     zip_path = RedditConfig().make_zip_path()
     return luigi.LocalTarget(zip_path)
コード例 #20
0
 def output(self):
     """Return the target for the raw file of ``self.date``/``self.filetype``."""
     return luigi.LocalTarget(
         RedditConfig().make_raw_filepath(self.date, self.filetype))
コード例 #21
0
 def requires(self):
     """Require a SplitDialogues task for every configured date."""
     all_dates = RedditConfig().make_all_dates()
     return [SplitDialogues(date) for date in all_dates]
コード例 #22
0
 def output(self):
     """Return the target for this split/subreddit's file."""
     domain_path = RedditConfig().make_split_domain_path(
         split=self.split, domain=self.subreddit)
     return luigi.LocalTarget(domain_path)
コード例 #23
0
    def run(self):
        """Filter, group-sample and write the dialogues built for ``self.date``.

        Steps:

        1. Seed ``random`` from the config. Configure a SingleDialogueFilterer
           with the turn token limit if positive, else the char limit if
           positive, else no limit (-1).
        2. Read the raw dialogues file. Skip blank or unparseable lines and
           dialogues for which the filterer returns None.
        3. Sample with DialogueGrouperSampler: first grouped by post (level 0),
           then by the configured group level and limits.
        4. Validate each sampled dialogue through SessionItem and write it as a
           JSON line, in batches of ``dump_interval``.

        A date whose dialogues are all filtered out now produces an empty
        output. Previously the length statistics crashed on an empty list.
        """
        cfg = RedditConfig()
        random.seed(cfg.random_seed)

        turn_limit, limit_chars, limit_tokens = -1, False, False
        if cfg.turn_token_limit > 0:
            turn_limit, limit_tokens = cfg.turn_token_limit, True
        elif cfg.turn_char_limit > 0:
            turn_limit, limit_chars = cfg.turn_char_limit, True

        filterer = SingleDialogueFilterer(
            turn_limit=turn_limit,
            tokens=limit_tokens,
            chars=limit_chars,
            min_turns=cfg.min_dialogue_length)

        dlgs = []
        submission2subreddit = {}

        # All dialogues for the date are held in memory. Grouping by subreddit
        # or post id may be needed in the future to prevent memory errors.
        with gzip.open(cfg.make_raw_dialogues_filepath(self.date),
                       'rt',
                       encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue

                try:
                    dlg_obj = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.debug(
                        f"[sample] Error parsing line ({str(e)})\n - {line}")
                    continue

                # NOTE(review): only the None / not-None result of the filterer
                # is used. If it can return a modified dialogue, that version is
                # discarded here; confirm this is intended.
                if filterer(dlg_obj) is None:
                    continue

                turns_with_ids = dlg_obj['turns_with_ids']
                dlgs.append(turns_with_ids)
                # Key the subreddit lookup by the first turn's id (the
                # submission ID).
                submission2subreddit[turns_with_ids[0][0]] = dlg_obj['domain']

        sampler = DialogueGrouperSampler()
        grouping_configs = [
            # First group by post; shuffling and group limits don't matter here.
            GrouperSamplerCfg(group_level=0),
            # Then group at a comment level below the top. With the default
            # parameters: group by top comment, choose one dialogue per top
            # comment group, keep only 2 top comment groups, with shuffling.
            GrouperSamplerCfg(
                group_level=cfg.sampling_group_level,
                n_groups=cfg.sampling_n_groups,
                n_per_group=cfg.sampling_n_per_group,
                shuffle_groups=cfg.sampling_shuffle_groups,
                shuffle_within_groups=cfg.sampling_shuffle_within_groups)
        ]
        sampled_dlgs = sampler(dlgs, grouping_configs)

        def to_json(dlg):
            # ``dlg`` is a sequence of (turn_id, turn_text) pairs.
            domain = submission2subreddit[dlg[0][0]]
            record = {
                'domain': domain,
                'task_id': md5(domain.encode('utf-8')).hexdigest()[:8],
                'turns': [turn for _, turn in dlg],
                # id is the md5 of the turn ids joined with '_'
                'id': md5('_'.join(tid for tid, _ in dlg).encode(
                    'utf-8')).hexdigest(),
                'bot_id': '',
                'user_id': '',
            }
            # Built only to validate the record; any error it raises propagates.
            SessionItem(**record)
            return json.dumps(record)

        dump_interval = cfg.dump_interval
        with self.output().temporary_path() as tmp_path, \
                gzip.open(tmp_path, 'wt', encoding='utf-8') as f:
            # Serialize in batches; building every JSON string at once caused
            # a MemoryError.
            for start in range(0, len(sampled_dlgs), dump_interval):
                batch = sampled_dlgs[start:start + dump_interval]
                f.write('\n'.join([to_json(d) for d in batch]) + '\n')

        logging.debug(
            f" > [{self.date}] # DLGS: before sample={len(dlgs)}, after sample={len(sampled_dlgs)}"
        )
        lens = [len(d) for d in sampled_dlgs]
        if lens:  # max()/min() and the average are undefined for no dialogues
            logging.debug(
                f" > [{self.date}] DLG LENGTHS: max={max(lens)}, min={min(lens)}, avg={sum(lens) / len(lens):2.2f}"
            )
コード例 #24
0
 def on_success(self):
     """Pass this task's requirements to ``delete_requires`` if configured.

     Acts only when ``RedditConfig().delete_intermediate_data`` is truthy.
     """
     if not RedditConfig().delete_intermediate_data:
         return
     delete_requires(self.requires())
コード例 #25
0
    def run(self):
        """Reconstruct root-to-leaf dialogues from this date's filtered data.

        Submissions and comments are loaded into ``turns``, which maps an id to
        ``(content, parent_id, has_child)``. A submission is recorded as its own
        parent, which marks the top of a chain. From every leaf (non-empty
        content, no child) the parent links are followed upward. A chain is
        written as a JSON line
        ``{"domain": <lowercased subreddit>, "turns_with_ids": [[id, text], ...]}``,
        ordered from the top turn down to the leaf, if:

        * it has at least ``min_dialogue_length`` non-blank turns, and
        * its topmost id is a known submission.
        """
        cfg = RedditConfig()
        turns = {}
        submission2subreddit = {}

        with gzip.open(cfg.make_filtered_filepath(self.date, 'submissions'),
                       'rt',
                       encoding='utf-8') as f:
            for line in f:
                # Malformed lines are skipped, as in the comments loop below.
                try:
                    parsed = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.debug(
                        f"[build] Error parsing line ({str(e)})\n - {line}")
                    continue
                rid, content, subreddit = (parsed['id'], parsed['body'],
                                           parsed['subreddit'])
                submission2subreddit[rid] = subreddit
                turns[rid] = (content, rid, False)

        with gzip.open(cfg.make_filtered_filepath(self.date, 'comments'),
                       'rt',
                       encoding='utf-8') as f:
            for line in f:
                try:
                    parsed = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.debug(
                        f"[build] Error parsing line ({str(e)})\n - {line}")
                    continue

                rid, content, parent_id = (parsed['id'], parsed['body'],
                                           parsed['parent_id'])

                # Keep a has_child flag set earlier by one of this turn's
                # children.
                has_child = turns[rid][2] if rid in turns else False
                turns[rid] = (content, parent_id, has_child)

                # Mark the parent as having a child. An unseen parent gets an
                # empty placeholder that its own line may fill in later.
                parent_content, parent_parent, _ = turns.get(
                    parent_id, ('', '', True))
                turns[parent_id] = (parent_content, parent_parent, True)

        def walk_to_root(leaf_id):
            # Follow parent links from leaf_id and return (texts, ids) in
            # leaf-to-top order. Both lists are empty if the chain reaches a
            # missing or blank turn, or loops back on itself.
            rid = leaf_id
            content, parent, has_child = turns[rid]
            texts, ids = [], []
            visited = {rid}
            while True:
                texts.append(content)
                ids.append(rid)
                rid = parent
                if rid not in turns:
                    return [], []
                content, parent, has_child = turns[rid]
                if not content or not content.strip():
                    return [], []
                if not has_child:
                    break
                if rid == parent:
                    # A turn that is its own parent is a submission: include
                    # it and stop.
                    texts.append(content)
                    ids.append(rid)
                    break
                if rid in visited:
                    return [], []  # parent-id cycle; would never terminate
                visited.add(rid)
            return texts, ids

        min_len = cfg.min_dialogue_length
        dump_interval = cfg.dump_interval

        with self.output().temporary_path() as tmp_path, \
                gzip.open(tmp_path, 'wt', encoding='utf-8') as f:
            dlgs_to_write = []

            for leaf_id, (content, _, has_child) in turns.items():
                # Start only from leaves; empty placeholders are skipped too.
                if not content or has_child:
                    continue

                dlg, ids = walk_to_root(leaf_id)

                # Some validation
                if not dlg or len(dlg) < min_len or not all(
                        t.strip() for t in dlg):
                    continue
                if not ids or len(ids) != len(dlg) or not all(
                        i.strip() for i in ids):
                    continue

                # ids[-1] is the topmost turn reached. Only chains ending at a
                # known submission have a subreddit, which is lowercased.
                subreddit = submission2subreddit.get(ids[-1], '').strip().lower()
                if not subreddit:
                    continue

                dlg_obj = {
                    'domain': subreddit,
                    'turns_with_ids': list(zip(ids, dlg))[::-1],
                }
                dlgs_to_write.append(json.dumps(dlg_obj) + '\n')

                if len(dlgs_to_write) >= dump_interval:
                    f.write(''.join(dlgs_to_write))
                    dlgs_to_write = []

            # Flush
            if dlgs_to_write:
                f.write(''.join(dlgs_to_write))