Example #1
  def _build_pcollection(
      self, unused_pipeline, split, page_content, hashed_url_predicate):
    beam = tfds.core.lazy_imports.apache_beam

    def _emit_examples(el):
      # Increment the per-split "examples" counter for Beam metrics.
      c4_utils.get_counter_inc_fn(split)("examples")
      _, features = el
      return features["url"], {
          "url": features["url"],
          "text": features["text"],
          "content-type": features["content-type"],
          "content-length": features["content-length"],
          "timestamp": features["timestamp"]
      }
    # Keep only pages whose hashed URL satisfies this split's predicate,
    # then key each emitted example by its URL.
    return (page_content
            | beam.Filter(
                c4_utils.get_hashed_url_filter_fn(hashed_url_predicate))
            | beam.Map(_emit_examples))
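
To see the shape of this transform in isolation, here is a minimal, self-contained sketch that runs the same Filter/Map pattern in a local Beam pipeline. The hash-based URL filter is a hypothetical stand-in for c4_utils.get_hashed_url_filter_fn, whose implementation is not shown in these examples:

import hashlib

import apache_beam as beam


def hashed_url_filter_fn(predicate):  # stand-in, not the real c4_utils helper
  def _keep(el):
    url, _ = el
    # Hash the URL so that split membership is deterministic per URL.
    return predicate(int(hashlib.md5(url.encode("utf-8")).hexdigest(), 16))
  return _keep


with beam.Pipeline() as p:  # uses the local DirectRunner by default
  _ = (p
       | beam.Create([("http://a.example", {"text": "a"}),
                      ("http://b.example", {"text": "b"})])
       | beam.Filter(hashed_url_filter_fn(lambda x: x % 1000 != 0))
       | beam.Map(print))
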
Example #2
  def _split_generators(self, dl_manager, pipeline):
    dl_manager.download_checksums(_CHECKSUMS_URL)

    # The default Common Crawl version is downloaded automatically; any
    # other requested versions must be downloaded manually.
    cc_versions = set(self.builder_config.cc_versions)
    default_version = {DEFAULT_CC_VERSION}
    auto_cc_versions = cc_versions & default_version
    manual_cc_versions = cc_versions - default_version

    files_to_download = {}
    files_to_download["wet_path_urls"] = [
        _WET_PATH_URL.format(cc_version=cc_version)
        for cc_version in auto_cc_versions]
    files_to_download["manual_wet_paths"] = {
        cc_version: _WET_PATH_URL.format(cc_version=cc_version)
        for cc_version in manual_cc_versions
    }
    if self.builder_config.badwords_filter:
      files_to_download["badwords"] = {
          lang: _BADWORDS_URL.format(lang=lang)
          for lang in _BADWORDS_LANGS if lang != "en"
      }
      # Use older "en" file for reproducibility of the original C4.
      files_to_download["badwords"]["en"] = _EN_BADWORDS_URL
    if self.builder_config.realnewslike:
      files_to_download["realnews_domains"] = _REALNEWS_DOMAINS_URL
    file_paths = dl_manager.download_and_extract(files_to_download)

    if self.builder_config.webtextlike:
      owt_path = os.path.join(dl_manager.manual_dir, _OPENWEBTEXT_URLS_ZIP)
      if not tf.io.gfile.exists(owt_path):
        raise AssertionError(
            "For the WebText-like config, you must manually download the "
            "following file from {0} and place it in {1}: {2}".format(
                _OPENWEBTEXT_URLS_URL, dl_manager.manual_dir,
                _OPENWEBTEXT_URLS_ZIP))
      file_paths["openwebtext_urls_zip"] = dl_manager.extract(owt_path)

    wet_urls = []
    for wet_path_url in file_paths["wet_path_urls"]:
      with tf.io.gfile.GFile(wet_path_url) as f:
        wet_urls.extend(["%s/%s" % (_DOWNLOAD_HOST, l.strip()) for l in f])
    if dl_manager.register_checksums:
      # Download locally to register checksums.
      file_paths.update(dl_manager.download({"wet_files": wet_urls}))
    else:
      # Download on the beam workers.
      file_paths["wet_urls"] = wet_urls
      file_paths["wet_files"] = []

    for cc_version, wet_path_url in file_paths["manual_wet_paths"].items():
      crawl_dir = os.path.join(
          dl_manager.manual_dir, "crawl-data", f"CC-MAIN-{cc_version}")
      if not tf.io.gfile.exists(crawl_dir):
        raise AssertionError(
            "For the non-default Common Crawl version {0}, you must manually "
            "download the WET files to the directory {1}.".format(
                cc_version, crawl_dir))
      with tf.io.gfile.GFile(wet_path_url) as f:
        wet_files = [
            os.path.join(dl_manager.manual_dir, line.strip())
            for line in f if line.strip() not in _KNOWN_CORRUPT_WET_FILES
        ]
      logging.info(
          "Adding %d WET files for manually downloaded version %s.",
          len(wet_files), cc_version)
      file_paths["wet_files"].extend(wet_files)

    page_content_pcollection = self._get_page_content(
        pipeline, file_paths, dl_manager, self.builder_config.languages)

    def _lang_filter(url_and_page, lang):
      _, page = url_and_page
      return page["language"] == lang

    def _filter(url_and_page, lang, predicate_fn):
      return (_lang_filter(url_and_page, lang) and
              c4_utils.get_hashed_url_filter_fn(predicate_fn)(url_and_page))

    train_predicate_fn = lambda x: x % 1000 != 0  # 99.9%
    validation_predicate_fn = lambda x: x % 1000 == 0  # 0.1%

    if len(self.builder_config.languages) == 1:
      # Single-language version.
      return [
          tfds.core.SplitGenerator(
              name=tfds.Split.TRAIN,
              gen_kwargs=dict(
                  split="train",
                  page_content=page_content_pcollection,
                  split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                      predicate_fn=train_predicate_fn
                  )
              ),
          ),
          tfds.core.SplitGenerator(
              name=tfds.Split.VALIDATION,
              gen_kwargs=dict(
                  split="validation",
                  page_content=page_content_pcollection,
                  split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                      predicate_fn=validation_predicate_fn
                  )
              ),
          ),
      ]

    splits = []
    for lang in self.builder_config.languages + [c4_utils.UNKNOWN_LANGUAGE]:
      splits.extend([
          tfds.core.SplitGenerator(
              name=lang,
              gen_kwargs=dict(
                  split=lang,
                  page_content=page_content_pcollection,
                  split_filter_fn=functools.partial(
                      _filter, lang=lang,
                      predicate_fn=train_predicate_fn
                  ),
              )
          ),
          tfds.core.SplitGenerator(
              name=f"{lang}-validation",
              gen_kwargs=dict(
                  split=f"{lang}-validation",
                  page_content=page_content_pcollection,
                  split_filter_fn=functools.partial(
                      _filter, lang=lang,
                      predicate_fn=validation_predicate_fn
                  ),
              )
          )
      ])
    return splits
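
The train/validation split above is decided purely by the two predicates: a hash of the URL modulo 1000 sends one residue class to validation and the other 999 to train, so the splits are disjoint and deterministic. A quick self-contained check of the advertised 99.9%/0.1% ratio:

train_predicate_fn = lambda x: x % 1000 != 0
validation_predicate_fn = lambda x: x % 1000 == 0

hashes = range(1_000_000)  # stand-in hash values
train = sum(map(train_predicate_fn, hashes))
validation = sum(map(validation_predicate_fn, hashes))
print(train / len(hashes))       # 0.999
print(validation / len(hashes))  # 0.001
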
Example #3
def _filter(url_and_page, lang, predicate_fn):
  return (_lang_filter(url_and_page, lang) and
          c4_utils.get_hashed_url_filter_fn(predicate_fn)(url_and_page))
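
As Example #2 shows, functools.partial binds lang and predicate_fn up front, leaving a one-argument callable that beam.Filter can apply to each (url, page) element. A runnable sketch of that composition, with a hypothetical stand-in for the c4_utils helper:

import functools
import hashlib


def _lang_filter(url_and_page, lang):
  _, page = url_and_page
  return page["language"] == lang


def get_hashed_url_filter_fn(predicate_fn):  # stand-in, not the real helper
  def _fn(url_and_page):
    url, _ = url_and_page
    return predicate_fn(int(hashlib.md5(url.encode("utf-8")).hexdigest(), 16))
  return _fn


def _filter(url_and_page, lang, predicate_fn):
  return (_lang_filter(url_and_page, lang) and
          get_hashed_url_filter_fn(predicate_fn)(url_and_page))


# Bind the per-split arguments, leaving only the (url, page) element free:
en_train = functools.partial(
    _filter, lang="en", predicate_fn=lambda x: x % 1000 != 0)
print(en_train(("http://a.example", {"language": "en"})))
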
Example #4
    def _split_generators(self, dl_manager, pipeline):
        dl_manager.download_checksums(_CHECKSUMS_URL)

        # In this variant, the WET path listings for all requested Common
        # Crawl versions are downloaded automatically.
        cc_versions = set(self.builder_config.cc_versions)
        files_to_download = {}
        files_to_download["wet_path_urls"] = [
            _WET_PATH_URL.format(cc_version=cc_version)
            for cc_version in cc_versions
        ]
        if self.builder_config.badwords_filter:
            files_to_download["badwords"] = {
                lang: _BADWORDS_URL.format(lang=lang)
                for lang in _BADWORDS_LANGS if lang != "en"
            }
            # Use older "en" file for reproducibility of the original C4.
            files_to_download["badwords"]["en"] = _EN_BADWORDS_URL
        if self.builder_config.realnewslike:
            files_to_download["realnews_domains"] = _REALNEWS_DOMAINS_URL
        file_paths = dl_manager.download_and_extract(files_to_download)

        if self.builder_config.webtextlike:
            owt_path = os.path.join(dl_manager.manual_dir,
                                    _OPENWEBTEXT_URLS_ZIP)
            if not tf.io.gfile.exists(owt_path):
                raise AssertionError(
                    "For the WebText-like config, you must manually download the "
                    "following file from {0} and place it in {1}: {2}".format(
                        _OPENWEBTEXT_URLS_URL, dl_manager.manual_dir,
                        _OPENWEBTEXT_URLS_ZIP))
            file_paths["openwebtext_urls_zip"] = dl_manager.extract(owt_path)

        file_paths = tf.nest.map_structure(os.fspath, file_paths)

        page_content_pcollection = self._get_page_content(
            pipeline, file_paths, dl_manager)

        def _lang_filter(url_and_page, lang):
            _, page = url_and_page
            return page["language"] == lang

        def _filter(url_and_page, lang, predicate_fn):
            return (
                _lang_filter(url_and_page, lang) and
                c4_utils.get_hashed_url_filter_fn(predicate_fn)(url_and_page))

        train_predicate_fn = lambda x: x % 1000 != 0  # 99.9%
        validation_predicate_fn = lambda x: x % 1000 == 0  # 0.1%

        if len(self.builder_config.languages) == 1:
            # Single-language version.
            return [
                tfds.core.SplitGenerator(
                    name=tfds.Split.TRAIN,
                    gen_kwargs=dict(
                        split="train",
                        page_content=page_content_pcollection,
                        split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                            predicate_fn=train_predicate_fn)),
                ),
                tfds.core.SplitGenerator(
                    name=tfds.Split.VALIDATION,
                    gen_kwargs=dict(
                        split="validation",
                        page_content=page_content_pcollection,
                        split_filter_fn=c4_utils.get_hashed_url_filter_fn(
                            predicate_fn=validation_predicate_fn)),
                ),
            ]

        splits = []
        for lang in self.builder_config.languages + [
                c4_utils.UNKNOWN_LANGUAGE
        ]:
            splits.extend([
                tfds.core.SplitGenerator(
                    name=lang,
                    gen_kwargs=dict(
                        split=lang,
                        page_content=page_content_pcollection,
                        split_filter_fn=functools.partial(
                            _filter,
                            lang=lang,
                            predicate_fn=train_predicate_fn),
                    )),
                tfds.core.SplitGenerator(
                    name=f"{lang}-validation",
                    gen_kwargs=dict(
                        split=f"{lang}-validation",
                        page_content=page_content_pcollection,
                        split_filter_fn=functools.partial(
                            _filter,
                            lang=lang,
                            predicate_fn=validation_predicate_fn),
                    ))
            ])
        return splits
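
One detail that distinguishes this variant is the tf.nest.map_structure(os.fspath, file_paths) call: download_and_extract may return a nested structure whose leaves are path-like objects, and mapping os.fspath over every leaf produces plain strings that serialize cleanly when the structure is shipped to Beam workers. A minimal illustration (the paths are made up):

import os
import pathlib

import tensorflow as tf

file_paths = {
    "wet_path_urls": [pathlib.Path("/tmp/wet.paths")],
    "badwords": {"en": pathlib.Path("/tmp/badwords_en.txt")},
}
file_paths = tf.nest.map_structure(os.fspath, file_paths)
print(file_paths)  # every leaf is now a plain str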