def get_reference_checksum(src_url):
    *_, filename = src_url.split('/')
    if filename in RedditConfig().missing_checksums:
        return RedditConfig().missing_checksums[filename]
    checksums_url = RedditConfig().submissions_checksum_url_template \
        if filename.startswith('RS') \
        else RedditConfig().comments_checksum_url_template
    r = requests.get(checksums_url)
    if r.status_code != 200:
        raise RuntimeError(
            f"Couldn't get checksums from {checksums_url}, "
            f"status={r.status_code}")
    checksum = None
    for line in r.content.decode('utf-8').split('\n'):
        if filename in line:
            checksum, *_ = line.split()
            break
    if not checksum:
        raise RuntimeError(f"Couldn't get checksum for {filename}")
    return checksum
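# Note on the parsing above: it assumes the reference checksum files follow the
# usual `sha256sum` layout of "<hex digest>  <filename>" per line, e.g.
# (hypothetical values):
#
#   9f2c...e41a  RS_<date>.<ext>
#
# The first whitespace-separated token of the matching line is taken as the
# checksum; if the upstream file format changes, `checksum, *_ = line.split()`
# needs to be revisited.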
def output(self):
    return [
        luigi.LocalTarget(RedditConfig().make_filtered_filepath(
            self.date, 'submissions')),
        luigi.LocalTarget(RedditConfig().make_ids_filepath(
            self.date, 'submissions'))
    ]
def run(self):
    with self.output().temporary_path() as tmp_path:
        src_url = RedditConfig().make_source_url(self.date, self.filetype)
        ref_checksum = get_reference_checksum(src_url)
        r = requests.get(src_url, stream=True)
        if r.status_code != 200:
            raise RuntimeError(
                f"Error downloading {src_url}, status={r.status_code}")
        m = sha256()
        with open(tmp_path, 'wb') as f:
            for chunk in r.iter_content(
                    chunk_size=RedditConfig().download_chunk_size):
                if chunk:
                    f.write(chunk)
                    m.update(chunk)
        checksum = m.hexdigest()
        if checksum != ref_checksum:
            raise RuntimeError(
                f"Checksums don't match for "
                f"{'RC' if self.filetype == 'comments' else 'RS'}_{self.date}!")
def run(self):
    with self.output().temporary_path() as zip_path:
        archive = ZipFile(zip_path, 'w', compression=ZIP_DEFLATED)
        for src in self.input():
            if os.stat(src.path).st_size == 0:
                continue
            dest = os.path.join('dialogues', *src.path.split(os.sep)[-2:])
            archive.write(src.path, arcname=dest)

        tasks_to_write = [
            ('tasks.txt', RedditConfig().all_subreddits),
            ('tasks_train.txt',
             RedditConfig().all_subreddits - RedditConfig().held_out_subreddits),
            ('tasks_held_out.txt', RedditConfig().held_out_subreddits)
        ]

        def make_json_for_subreddit(subreddit):
            return json.dumps({
                'domain': subreddit,
                'task_id': subreddit,
                'bot_prompt': '',
                'user_prompt': '',
                'bot_role': '',
                'user_role': '',
            })

        for fp, tasks in tasks_to_write:
            with TextIOWrapper(archive.open(fp, 'w'), encoding='utf-8') as f:
                f.write('\n'.join(
                    [make_json_for_subreddit(t) for t in sorted(tasks)]) + '\n')

        archive.close()
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    random.seed(RedditConfig().random_seed)
    self._date_in_train = RedditConfig().is_date_in_train(self.date)
    self._subset_subreddit_combos = list(
        itertools.product(list(Subset), RedditConfig().all_subreddits))
def run(self):
    with self.output()[0].temporary_path() as tmp_path, \
            self.output()[1].temporary_path() as tmp_ids_path:
        process_file_linewise(
            in_filepath=RedditConfig().make_raw_filepath(
                self.date, 'submissions'),
            out_filepath=tmp_path,
            out_ids_filepath=tmp_ids_path,
            parser=json.loads,
            filterer=RawSubmissionFilterer(),
            outputter=SubmissionJsonOutputter(),
            buffer_size=RedditConfig().dump_interval)
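# `process_file_linewise` is defined elsewhere in the repo. A minimal sketch of
# what it could look like, inferred only from this call site: `parser` turns a
# raw line into a dict, `filterer` returns the dict or None, `outputter`
# serializes it back to a line. The compression choices and the 'id' key are
# assumptions based on how the outputs are read later in the pipeline; the real
# helper may differ.
import gzip


def process_file_linewise(in_filepath, out_filepath, out_ids_filepath,
                          parser, filterer, outputter, buffer_size=10000):
    out_buffer, id_buffer = [], []
    with gzip.open(in_filepath, 'rt', encoding='utf-8') as fin, \
            gzip.open(out_filepath, 'wt', encoding='utf-8') as fout, \
            open(out_ids_filepath, 'w', encoding='utf-8') as fids:
        for line in fin:
            try:
                obj = parser(line)
            except ValueError:
                continue
            kept = filterer(obj)
            if kept is None:
                continue
            out_buffer.append(outputter(kept))
            id_buffer.append(kept['id'] + '\n')
            # Flush periodically instead of writing per line
            if len(out_buffer) >= buffer_size:
                fout.write(''.join(out_buffer))
                fids.write(''.join(id_buffer))
                out_buffer, id_buffer = [], []
        if out_buffer:
            fout.write(''.join(out_buffer))
            fids.write(''.join(id_buffer))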
def run(self):
    with self.output().temporary_path() as outpath:
        with open(outpath, 'wt', encoding='utf-8') as outfile:
            sources = [
                RedditConfig().make_split_date_domain_path(
                    d, self.split, self.subreddit)
                for d in RedditConfig().make_all_dates()
            ]
            for src in sources:
                with gzip.open(src, 'rt', encoding='utf-8') as infile:
                    shutil.copyfileobj(infile, outfile,
                                       RedditConfig().cat_buffer_size)
def run(self):
    # Open a lot of files
    with ExitStack() as stack:
        outfiles = {(s, d): stack.enter_context(t.temporary_path())
                    for s, d, t in self._make_outputs()}
        buffers = defaultdict(list)
        with gzip.open(
                RedditConfig().make_sampled_dialogues_filepath(self.date),
                'rt', encoding='utf-8') as infile:
            for line in infile:
                try:
                    dlg = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.debug(
                        f"[split] Error parsing line ({str(e)})\n - {line}")
                    continue
                domain = dlg['domain']
                key = None
                if self._date_in_train and \
                        domain in RedditConfig().held_out_subreddits:
                    key = (Subset.VALIDATION_DATE_IN_DOMAIN_OUT, domain)
                elif not self._date_in_train and \
                        domain in RedditConfig().held_out_subreddits:
                    key = (Subset.VALIDATION_DATE_OUT_DOMAIN_OUT, domain)
                elif not self._date_in_train and \
                        domain in RedditConfig().all_subreddits:
                    key = (Subset.VALIDATION_DATE_OUT_DOMAIN_IN, domain)
                elif domain in RedditConfig().all_subreddits and \
                        random.random() < RedditConfig().train_sampling_pr:
                    key = (Subset.TRAINING, domain)
                elif domain in RedditConfig().all_subreddits:
                    key = (Subset.VALIDATION_DATE_IN_DOMAIN_IN, domain)
                else:
                    continue
                # Avoid an extra encode by storing the raw lines
                buffers[key].append(line)
                if len(buffers[key]) >= RedditConfig().dump_interval:
                    with gzip.open(outfiles[key], 'at',
                                   encoding='utf-8') as outfile:
                        outfile.write(''.join(buffers[key]))
                    buffers[key] = []
        # Iterate over outfiles to touch/create even empty files,
        # else luigi's temporary path complains
        for key, outfile_path in outfiles.items():
            with gzip.open(outfile_path, 'at', encoding='utf-8') as outfile:
                if key in buffers and len(buffers[key]) > 0:
                    outfile.write(''.join(buffers[key]))
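# For reference, the branching above implies the `Subset` enum (defined
# elsewhere) has at least the members below; the values here are illustrative
# placeholders, not the actual ones:
from enum import Enum


class Subset(Enum):
    TRAINING = 'training'
    VALIDATION_DATE_IN_DOMAIN_IN = 'validation_date_in_domain_in'
    VALIDATION_DATE_IN_DOMAIN_OUT = 'validation_date_in_domain_out'
    VALIDATION_DATE_OUT_DOMAIN_IN = 'validation_date_out_domain_in'
    VALIDATION_DATE_OUT_DOMAIN_OUT = 'validation_date_out_domain_out'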
def __init__(self, submission_ids_file=None):
    self.subreddits = set(RedditConfig().all_subreddits)
    if submission_ids_file:
        with open(submission_ids_file, 'r', encoding='utf-8') as f:
            self.submission_ids = set(
                [SUBMISSION_ID_PREFIX + x.strip() for x in f.readlines()])
    else:
        self.submission_ids = set()
def on_success(self):
    if RedditConfig().delete_intermediate_data:
        delete_requires(self.requires())
        parent_reqs = self.requires()
        if not isinstance(parent_reqs, list):
            parent_reqs = [parent_reqs]
        for r in parent_reqs:
            delete_requires(r.requires())
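# `delete_requires` is not part of this excerpt. A plausible sketch, assuming it
# simply removes the local targets produced by the given upstream task(s); the
# actual helper may handle more cases:
import os


def delete_requires(requires):
    tasks = requires if isinstance(requires, list) else [requires]
    for task in tasks:
        outputs = task.output()
        if not isinstance(outputs, list):
            outputs = [outputs]
        for target in outputs:
            if hasattr(target, 'path') and os.path.exists(target.path):
                os.remove(target.path)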
def download(workers, config, log_level):
    RedditConfig.initialize(config)
    print(RedditConfig())
    luigi.configuration.get_config().set(
        'resources', 'max_concurrent_downloads',
        str(RedditConfig().max_concurrent_downloads))
    result = luigi.interface.build(
        [
            DownloadRawFile(d, ft)
            for d, ft in RedditConfig().make_all_dates_filetypes()
        ],
        workers=workers,
        local_scheduler=True,
        log_level=log_level,
        detailed_summary=True,
    )
    print(result.summary_text)
def generate(workers, config, log_level):
    RedditConfig.initialize(config)
    print(RedditConfig())
    luigi.configuration.get_config().set(
        'resources', 'max_concurrent_downloads',
        str(RedditConfig().max_concurrent_downloads))
    luigi.configuration.get_config().set(
        'resources', 'max_concurrent_build',
        str(RedditConfig().max_concurrent_build))
    luigi.configuration.get_config().set(
        'resources', 'max_concurrent_sample',
        str(RedditConfig().max_concurrent_sample))
    result = luigi.interface.build(
        [ZipDataset()],
        workers=workers,
        local_scheduler=True,
        log_level=log_level,
        detailed_summary=True,
    )
    print(result.summary_text)
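# These two functions are the pipeline's entry points (typically exposed via a
# CLI wrapper). A hypothetical programmatic invocation, assuming a config file
# that populates RedditConfig:
#
#   download(workers=4, config='config.yaml', log_level='INFO')   # fetch raw dumps only
#   generate(workers=8, config='config.yaml', log_level='INFO')   # build the full ZipDataset graph
#
# `generate` presumably pulls the download tasks in as upstream requirements of
# ZipDataset, which is why it also sets the download concurrency resource.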
def output(self):
    return luigi.LocalTarget(
        RedditConfig().make_sampled_dialogues_filepath(self.date))
def turns_all_have_chars(cls, v):
    if len(v) < RedditConfig().min_dialogue_length:
        raise ValueError('Not enough turns!')
    if any(not t.strip() for t in v):
        raise ValueError('Zero-length strings in turns!')
    return v
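# This validator presumably lives on the `SessionItem` model that is used to
# validate dialogues before they are written out (`SessionItem(**dlg_obj)`).
# A pydantic-style sketch with fields inferred from that dict; illustrative
# only, the real model may differ:
from typing import List

from pydantic import BaseModel, validator


class SessionItem(BaseModel):
    domain: str
    task_id: str
    turns: List[str]
    id: str
    bot_id: str
    user_id: str

    @validator('turns')
    def turns_all_have_chars(cls, v):
        # Body as defined above
        return v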
def __init__(self):
    self.subreddits = set(RedditConfig().all_subreddits)
def _make_outputs(self):
    # We take care of 0B files later
    return [(s, d,
             luigi.LocalTarget(RedditConfig().make_split_date_domain_path(
                 self.date, s, d)))
            for s, d in self._subset_subreddit_combos]
def output(self):
    return luigi.LocalTarget(RedditConfig().make_filtered_filepath(
        self.date, 'comments'))
def requires(self):
    return [
        MergeDialoguesOverDates(s, d)
        for s, d in itertools.product(list(Subset),
                                      RedditConfig().all_subreddits)
    ]
def output(self):
    return luigi.LocalTarget(RedditConfig().make_zip_path())
def output(self):
    dest_fp = RedditConfig().make_raw_filepath(self.date, self.filetype)
    return luigi.LocalTarget(dest_fp)
def requires(self):
    return [SplitDialogues(d) for d in RedditConfig().make_all_dates()]
def output(self):
    return luigi.LocalTarget(RedditConfig().make_split_domain_path(
        split=self.split, domain=self.subreddit))
def run(self):
    random.seed(RedditConfig().random_seed)

    turn_limit, limit_chars, limit_tokens = -1, False, False
    if RedditConfig().turn_token_limit > 0:
        turn_limit, limit_tokens = RedditConfig().turn_token_limit, True
    elif RedditConfig().turn_char_limit > 0:
        turn_limit, limit_chars = RedditConfig().turn_char_limit, True
    filterer = SingleDialogueFilterer(
        turn_limit=turn_limit,
        tokens=limit_tokens,
        chars=limit_chars,
        min_turns=RedditConfig().min_dialogue_length)

    dlgs = []
    submission2subreddit = {}
    # First we need to read all dialogues.
    # In the future we may need to group by subreddit or post id to prevent
    # memory errors.
    with gzip.open(RedditConfig().make_raw_dialogues_filepath(self.date),
                   'rt', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                dlg_obj = json.loads(line)
            except json.JSONDecodeError as e:
                logging.debug(
                    f"[sample] Error parsing line ({str(e)})\n - {line}")
                continue
            if filterer(dlg_obj) is None:
                continue
            dlgs.append(dlg_obj['turns_with_ids'])
            # Key submission2subreddit by submission ID
            submission_id = dlg_obj['turns_with_ids'][0][0]
            submission2subreddit[submission_id] = dlg_obj['domain']

    # Next, sample and write the dialogues
    with self.output().temporary_path() as tmp_path:
        with gzip.open(tmp_path, 'wt', encoding='utf-8') as f:
            sampler = DialogueGrouperSampler()
            grouping_configs = [
                # First group by post; don't care about shuffling or group limits
                GrouperSamplerCfg(group_level=0),
                # Next group by a top-level comment.
                # With default parameters: group by top comment, choose one
                # dialogue per top-comment group, only 2 top-comment groups,
                # with shuffling.
                GrouperSamplerCfg(
                    group_level=RedditConfig().sampling_group_level,
                    n_groups=RedditConfig().sampling_n_groups,
                    n_per_group=RedditConfig().sampling_n_per_group,
                    shuffle_groups=RedditConfig().sampling_shuffle_groups,
                    shuffle_within_groups=RedditConfig().sampling_shuffle_within_groups)
            ]
            sampled_dlgs = sampler(dlgs, grouping_configs)

            def to_json(dlg_obj):
                # Validate before serializing
                SessionItem(**dlg_obj)
                return json.dumps(dlg_obj)

            # Can't write all dialogues at once, else MemoryError
            for i in range(0, len(sampled_dlgs), RedditConfig().dump_interval):
                dlgs_to_write = [
                    to_json({
                        'domain': submission2subreddit[d[0][0]],
                        'task_id': md5(submission2subreddit[d[0][0]].encode(
                            'utf-8')).hexdigest()[:8],
                        'turns': [turn for _, turn in d],
                        # id is the hash of the joined turn ids
                        'id': md5('_'.join([tid for tid, t in d]).encode(
                            'utf-8')).hexdigest(),
                        'bot_id': '',
                        'user_id': '',
                    }) for d in sampled_dlgs[i:i + RedditConfig().dump_interval]
                ]
                f.write('\n'.join(dlgs_to_write) + '\n')

    logging.debug(
        f" > [{self.date}] # DLGS: before sample={len(dlgs)}, "
        f"after sample={len(sampled_dlgs)}")
    lens = [len(d) for d in sampled_dlgs]
    logging.debug(
        f" > [{self.date}] DLG LENGTHS: max={max(lens)}, min={min(lens)}, "
        f"avg={sum(lens) / len(lens):2.2f}")
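# `GrouperSamplerCfg` and `DialogueGrouperSampler` are defined elsewhere. A
# sketch of the config object, with field names taken from the calls above and
# defaults guessed from the "don't care about shuffling or group limits" usage;
# the sentinel values are assumptions:
from dataclasses import dataclass


@dataclass
class GrouperSamplerCfg:
    group_level: int = 0               # index of the turn id to group dialogues by
    n_groups: int = -1                 # max groups to keep (-1 = no limit)
    n_per_group: int = -1              # max dialogues kept per group (-1 = no limit)
    shuffle_groups: bool = False
    shuffle_within_groups: bool = False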
def on_success(self):
    if RedditConfig().delete_intermediate_data:
        delete_requires(self.requires())
def run(self):
    turns = {}
    submission2subreddit = {}
    with gzip.open(RedditConfig().make_filtered_filepath(
            self.date, 'submissions'), 'rt', encoding='utf-8') as f:
        for line in f:
            parsed = json.loads(line)
            rid, content, subreddit = (parsed['id'], parsed['body'],
                                       parsed['subreddit'])
            submission2subreddit[rid] = subreddit
            turns[rid] = (content, rid, False)

    with gzip.open(RedditConfig().make_filtered_filepath(
            self.date, 'comments'), 'rt', encoding='utf-8') as f:
        for line in f:
            try:
                parsed = json.loads(line)
            except json.JSONDecodeError as e:
                logging.debug(
                    f"[build] Error parsing line ({str(e)})\n - {line}")
                continue
            rid, content, parent_comment_id = (parsed['id'], parsed['body'],
                                               parsed['parent_id'])
            # Third element of the tuple represents whether the turn has a child
            this_content, this_parent, this_has_child = \
                content, parent_comment_id, False
            if rid in turns:
                _, _, this_has_child = turns[rid]
            turns[rid] = (this_content, this_parent, this_has_child)
            ref_content, ref_parent, ref_has_child = '', '', True
            if parent_comment_id in turns:
                ref_content, ref_parent, _ = turns[parent_comment_id]
            turns[parent_comment_id] = (ref_content, ref_parent, ref_has_child)

    with self.output().temporary_path() as tmp_path:
        with gzip.open(tmp_path, 'wt', encoding='utf-8') as f:
            dlgs_to_write = []
            for rid in turns:
                # Start building the dialogue from the leaf.
                # Also ignore empty turns (placeholders).
                content, parent, has_child = turns[rid]
                if not content or has_child:
                    continue
                dlg, ids = [], []
                while True:
                    dlg.append(content)
                    ids.append(rid)
                    try:
                        rid = parent
                        content, parent, has_child = turns[rid]
                        if not content or len(content.strip()) == 0:
                            dlg = []
                    except KeyError:
                        dlg = []
                    finally:
                        if not dlg or not has_child:
                            break
                        if rid == parent:
                            # Reached the submission (its parent is its own id)
                            dlg.append(content)
                            ids.append(rid)
                            break
                # Some validation
                if (not dlg or len(dlg) < RedditConfig().min_dialogue_length
                        or not all(t.strip() for t in dlg)):
                    continue
                if (not ids or len(ids) != len(dlg)
                        or not all(i.strip() for i in ids)):
                    continue
                try:
                    # Lowercase the subreddit
                    subreddit = submission2subreddit[ids[-1]].strip().lower()
                except KeyError:
                    continue
                if not subreddit:
                    continue
                dlg_obj = {
                    'domain': subreddit,
                    'turns_with_ids': list(zip(ids, dlg))[::-1],
                }
                dlgs_to_write.append(json.dumps(dlg_obj) + '\n')
                if len(dlgs_to_write) >= RedditConfig().dump_interval:
                    f.write(''.join(dlgs_to_write))
                    dlgs_to_write = []
            # Flush anything remaining in the buffer
            if dlgs_to_write:
                f.write(''.join(dlgs_to_write))
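# To illustrate the chain-building above with a tiny made-up thread:
#
#   turns = {
#       'sub1':  ('post text', 'sub1',  True),   # submission: parent is its own id
#       'cmt_a': ('reply 1',   'sub1',  True),
#       'cmt_b': ('reply 2',   'cmt_a', False),  # leaf: has no child
#   }
#
# Starting from the leaf 'cmt_b', the loop follows parent pointers and collects
# ['reply 2', 'reply 1', 'post text']; the `rid == parent` check detects the
# submission (whose parent points to itself) and ends the walk. The final
# `list(zip(ids, dlg))[::-1]` reverses the pairs so turns appear in posting
# order, submission first.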