def _build_pcollection(self, unused_pipeline, split, page_content,
                       hashed_url_predicate):
  beam = tfds.core.lazy_imports.apache_beam

  def _emit_examples(el):
    c4_utils.get_counter_inc_fn(split)("examples")
    _, features = el
    return features["url"], {
        "url": features["url"],
        "text": features["text"],
        "content-type": features["content-type"],
        "content-length": features["content-length"],
        "timestamp": features["timestamp"]
    }

  return (page_content
          | beam.Filter(
              c4_utils.get_hashed_url_filter_fn(hashed_url_predicate))
          | beam.Map(_emit_examples))
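# For context: the Beam filter above relies on
# c4_utils.get_hashed_url_filter_fn, which keeps only pages whose hashed URL
# satisfies a predicate, so split membership is deterministic across runs and
# workers. A minimal sketch of such a helper is shown below; it assumes MD5
# hashing of the URL and is not necessarily the actual c4_utils
# implementation.
import hashlib


def _sketch_hashed_url_filter_fn(predicate_fn):
  """Returns a filter over (url, features) elements keyed on a URL hash."""

  def filter_fn(el):
    url, _ = el
    # Hash the URL to a stable integer, then apply the split predicate.
    url_hash = int(hashlib.md5(url.encode("utf-8")).hexdigest(), 16)
    return predicate_fn(url_hash)

  return filter_fn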
def _split_generators(self, dl_manager, pipeline):
  dl_manager.download_checksums(_CHECKSUMS_URL)

  # We automatically download the default CC version; any others must be
  # manually downloaded.
  cc_versions = set(self.builder_config.cc_versions)
  default_version = set([DEFAULT_CC_VERSION])
  auto_cc_versions = cc_versions & default_version
  manual_cc_versions = cc_versions - default_version

  files_to_download = {}
  files_to_download["wet_path_urls"] = [
      _WET_PATH_URL.format(cc_version=cc_version)
      for cc_version in auto_cc_versions
  ]
  files_to_download["manual_wet_paths"] = {
      cc_version: _WET_PATH_URL.format(cc_version=cc_version)
      for cc_version in manual_cc_versions
  }
  if self.builder_config.badwords_filter:
    files_to_download["badwords"] = {
        lang: _BADWORDS_URL.format(lang=lang)
        for lang in _BADWORDS_LANGS
        if lang != "en"
    }
    # Use older "en" file for reproducibility of the original C4.
    files_to_download["badwords"]["en"] = _EN_BADWORDS_URL
  if self.builder_config.realnewslike:
    files_to_download["realnews_domains"] = _REALNEWS_DOMAINS_URL
  file_paths = dl_manager.download_and_extract(files_to_download)

  if self.builder_config.webtextlike:
    owt_path = os.path.join(dl_manager.manual_dir, _OPENWEBTEXT_URLS_ZIP)
    if not tf.io.gfile.exists(owt_path):
      raise AssertionError(
          "For the WebText-like config, you must manually download the "
          "following file from {0} and place it in {1}: {2}".format(
              _OPENWEBTEXT_URLS_URL, dl_manager.manual_dir,
              _OPENWEBTEXT_URLS_ZIP))
    file_paths["openwebtext_urls_zip"] = dl_manager.extract(owt_path)

  wet_urls = []
  for wet_path_url in file_paths["wet_path_urls"]:
    with tf.io.gfile.GFile(wet_path_url) as f:
      wet_urls.extend(["%s/%s" % (_DOWNLOAD_HOST, l.strip()) for l in f])
  if dl_manager.register_checksums:
    # Download locally to register checksums.
    file_paths.update(dl_manager.download({"wet_files": wet_urls}))
  else:
    # Download on the beam workers.
    file_paths["wet_urls"] = wet_urls
    file_paths["wet_files"] = []

  for cc_version, wet_path_url in file_paths["manual_wet_paths"].items():
    crawl_dir = os.path.join(dl_manager.manual_dir, "crawl-data",
                             f"CC-MAIN-{cc_version}")
    if not tf.io.gfile.exists(crawl_dir):
      raise AssertionError(
          "For the non-default Common Crawl version {0}, you must manually "
          "download the WET files to the directory {1}.".format(
              cc_version, crawl_dir))
    with tf.io.gfile.GFile(wet_path_url) as f:
      wet_files = [
          os.path.join(dl_manager.manual_dir, line.strip())
          for line in f
          if line.strip() not in _KNOWN_CORRUPT_WET_FILES
      ]
    logging.info("Adding %d WET files for manually downloaded version %s.",
                 len(wet_files), cc_version)
    file_paths["wet_files"].extend(wet_files)

  page_content_pcollection = self._get_page_content(
      pipeline, file_paths, dl_manager, self.builder_config.languages)

  def _lang_filter(url_and_page, lang):
    _, page = url_and_page
    return page["language"] == lang

  def _filter(url_and_page, lang, predicate_fn):
    return (_lang_filter(url_and_page, lang) and
            c4_utils.get_hashed_url_filter_fn(predicate_fn)(url_and_page))

  train_predicate_fn = lambda x: x % 1000 != 0  # 99.9%
  validation_predicate_fn = lambda x: x % 1000 == 0  # 0.1%

  if len(self.builder_config.languages) == 1:
    # Single-language version.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs=dict(
                split="train",
                page_content=page_content_pcollection,
                split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                    predicate_fn=train_predicate_fn)),
        ),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION,
            gen_kwargs=dict(
                split="validation",
                page_content=page_content_pcollection,
                split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                    predicate_fn=validation_predicate_fn)),
        ),
    ]

  splits = []
  for lang in self.builder_config.languages + [c4_utils.UNKNOWN_LANGUAGE]:
    splits.extend([
        tfds.core.SplitGenerator(
            name=lang,
            gen_kwargs=dict(
                split=lang,
                page_content=page_content_pcollection,
                split_filter_fn=functools.partial(
                    _filter, lang=lang, predicate_fn=train_predicate_fn),
            )),
        tfds.core.SplitGenerator(
            name=f"{lang}-validation",
            gen_kwargs=dict(
                split=f"{lang}-validation",
                page_content=page_content_pcollection,
                split_filter_fn=functools.partial(
                    _filter,
                    lang=lang,
                    predicate_fn=validation_predicate_fn),
            )),
    ])
  return splits
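# Illustrative only (not part of the builder): the two predicates above split
# URLs deterministically by hash value. A hash divisible by 1000 routes its
# URL to validation (~0.1% of pages); everything else goes to train (~99.9%).
# Every hash value matches exactly one of the two predicates:
_train_fn = lambda x: x % 1000 != 0
_validation_fn = lambda x: x % 1000 == 0
assert all(_train_fn(h) != _validation_fn(h) for h in range(10000))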
def _split_generators(self, dl_manager, pipeline):
  dl_manager.download_checksums(_CHECKSUMS_URL)

  # All requested CC versions are downloaded automatically here.
  cc_versions = set(self.builder_config.cc_versions)
  files_to_download = {}
  files_to_download["wet_path_urls"] = [
      _WET_PATH_URL.format(cc_version=cc_version)
      for cc_version in cc_versions
  ]
  if self.builder_config.badwords_filter:
    files_to_download["badwords"] = {
        lang: _BADWORDS_URL.format(lang=lang)
        for lang in _BADWORDS_LANGS
        if lang != "en"
    }
    # Use older "en" file for reproducibility of the original C4.
    files_to_download["badwords"]["en"] = _EN_BADWORDS_URL
  if self.builder_config.realnewslike:
    files_to_download["realnews_domains"] = _REALNEWS_DOMAINS_URL
  file_paths = dl_manager.download_and_extract(files_to_download)

  if self.builder_config.webtextlike:
    owt_path = os.path.join(dl_manager.manual_dir, _OPENWEBTEXT_URLS_ZIP)
    if not tf.io.gfile.exists(owt_path):
      raise AssertionError(
          "For the WebText-like config, you must manually download the "
          "following file from {0} and place it in {1}: {2}".format(
              _OPENWEBTEXT_URLS_URL, dl_manager.manual_dir,
              _OPENWEBTEXT_URLS_ZIP))
    file_paths["openwebtext_urls_zip"] = dl_manager.extract(owt_path)

  file_paths = tf.nest.map_structure(os.fspath, file_paths)

  page_content_pcollection = self._get_page_content(
      pipeline, file_paths, dl_manager)

  def _lang_filter(url_and_page, lang):
    _, page = url_and_page
    return page["language"] == lang

  def _filter(url_and_page, lang, predicate_fn):
    return (_lang_filter(url_and_page, lang) and
            c4_utils.get_hashed_url_filter_fn(predicate_fn)(url_and_page))

  train_predicate_fn = lambda x: x % 1000 != 0  # 99.9%
  validation_predicate_fn = lambda x: x % 1000 == 0  # 0.1%

  if len(self.builder_config.languages) == 1:
    # Single-language version.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs=dict(
                split="train",
                page_content=page_content_pcollection,
                split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                    predicate_fn=train_predicate_fn)),
        ),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION,
            gen_kwargs=dict(
                split="validation",
                page_content=page_content_pcollection,
                split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                    predicate_fn=validation_predicate_fn)),
        ),
    ]

  splits = []
  for lang in self.builder_config.languages + [c4_utils.UNKNOWN_LANGUAGE]:
    splits.extend([
        tfds.core.SplitGenerator(
            name=lang,
            gen_kwargs=dict(
                split=lang,
                page_content=page_content_pcollection,
                split_filter_fn=functools.partial(
                    _filter, lang=lang, predicate_fn=train_predicate_fn),
            )),
        tfds.core.SplitGenerator(
            name=f"{lang}-validation",
            gen_kwargs=dict(
                split=f"{lang}-validation",
                page_content=page_content_pcollection,
                split_filter_fn=functools.partial(
                    _filter,
                    lang=lang,
                    predicate_fn=validation_predicate_fn),
            )),
    ])
  return splits
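# Usage sketch (illustrative, not part of this module): C4 is generated with
# Apache Beam via TFDS's DownloadConfig. The DirectRunner below only shows the
# wiring; processing the full Common Crawl input requires a distributed runner
# such as Dataflow.
from apache_beam.options.pipeline_options import PipelineOptions
import tensorflow_datasets as tfds

dl_config = tfds.download.DownloadConfig(
    beam_options=PipelineOptions(flags=["--runner=DirectRunner"]))
tfds.builder("c4/en").download_and_prepare(download_config=dl_config)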