Example #1
    def GetContentFileRoot(self, config: corpus_pb2.Corpus) -> pathlib.Path:
        """Get the path of the directory containing content files.

    If the corpus is a local directory, this simply returns the path. Otherwise,
    this method creates a temporary copy of the files which can be used within
    the scope of this context.

    Args:
      config: The corpus config proto.

    Returns:
      The path of a directory containing content files.
    """
        if config.HasField('local_directory'):
            yield pathlib.Path(ExpandConfigPath(config.local_directory))
        elif config.HasField('local_tar_archive'):
            with tempfile.TemporaryDirectory(prefix='clgen_corpus_') as d:
                start_time = time.time()
                cmd = [
                    'tar', '-xf',
                    str(ExpandConfigPath(config.local_tar_archive)), '-C', d
                ]
                subprocess.check_call(cmd)
                logging.info(
                    'Unpacked %s in %s ms',
                    ExpandConfigPath(config.local_tar_archive).name,
                    humanize.intcomma(int((time.time() - start_time) * 1000)))
                yield pathlib.Path(d)
        else:
            raise NotImplementedError
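
Note: the method above uses `yield`, so it only behaves as described when wrapped as a context manager (the decorator is outside this excerpt). A minimal, self-contained sketch of the same pattern, using hypothetical names, might look like:

import contextlib
import pathlib
import subprocess
import tempfile

@contextlib.contextmanager
def content_file_root(archive: pathlib.Path):
    # Unpack the archive into a temporary directory that exists only
    # for the duration of the `with` block.
    with tempfile.TemporaryDirectory(prefix="clgen_corpus_") as d:
        subprocess.check_call(["tar", "-xf", str(archive), "-C", d])
        yield pathlib.Path(d)

# Usage (the directory is deleted when the block exits):
# with content_file_root(pathlib.Path("corpus.tar.bz2")) as root:
#     contentfiles = list(root.iterdir())
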
Example #2
def ResolveContentId(config: corpus_pb2.Corpus,
                     hc: hashcache.HashCache) -> str:
    """Compute the hash of the input contentfiles.

  This function resolves the unique sha1 checksum of a set of content files.

  Args:
    config: The corpus config proto.
    hc: A hashcache database instance, used for resolving directory hashes.

  Returns:
    A hex encoded sha1 string.
  """
    # We can take a massive shortcut if the content ID is already set in the
    # config proto.
    if config.HasField('content_id'):
        return config.content_id

    start_time = time.time()
    if config.HasField('local_directory'):
        # After the first time we compute the hash of a directory, we write it into
        # a file. This is a shortcut to work around the fact that computing the
        # directory checksum is O(n) with respect to the number of files in the
        # directory (even if the directory is already cached by the hash cache).
        # This means that it is the responsibility of the user to delete this cached
        # file if the directory is changed.
        hash_file_path = pathlib.Path(
            str(pathlib.Path(config.local_directory)) + '.sha1.txt')
        if hash_file_path.is_file():
            logging.info("Reading directory hash: '%s'.", hash_file_path)
            with open(hash_file_path) as f:
                content_id = f.read().rstrip()
        else:
            # No hash file, so compute the directory hash and create it.
            try:
                content_id = hc.GetHash(
                    ExpandConfigPath(
                        config.local_directory,
                        path_prefix=FLAGS.clgen_local_path_prefix))
            except FileNotFoundError as e:
                raise errors.UserError(e)
            # Create the hash file in the directory so that next time we don't need
            # to reference the hash cache.
            with open(hash_file_path, 'w') as f:
                print(content_id, file=f)
            logging.info("Wrote directory hash: '%s'.", hash_file_path)
    elif config.HasField('local_tar_archive'):
        # This is not an efficient means of getting the hash, as it requires always
        # unpacking the archive and reading the entire contents. It would be nicer
        # to maintain a cache which maps the mtime of tarballs to their content ID,
        # similar to how local_directory is implemented.
        content_id = GetHashOfArchiveContents(
            ExpandConfigPath(config.local_tar_archive,
                             path_prefix=FLAGS.clgen_local_path_prefix))
    else:
        raise NotImplementedError(
            'Unsupported Corpus.contentfiles field value')
    logging.debug('Resolved Content ID %s in %s ms.', content_id,
                  humanize.intcomma(int((time.time() - start_time) * 1000)))
    return content_id
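
Note: a standalone sketch of the `.sha1.txt` shortcut used above, assuming a plain SHA-1 walk in place of `hashcache.HashCache` (the helper name `cached_directory_hash` is illustrative, not part of the project):

import hashlib
import pathlib

def cached_directory_hash(directory: pathlib.Path) -> str:
    # Reuse a previously written hash file if it exists; otherwise hash the
    # directory contents and persist the result next to the directory.
    hash_file = pathlib.Path(str(directory) + ".sha1.txt")
    if hash_file.is_file():
        return hash_file.read_text().rstrip()
    sha1 = hashlib.sha1()
    for path in sorted(directory.rglob("*")):
        if path.is_file():
            sha1.update(path.read_bytes())
    content_id = sha1.hexdigest()
    hash_file.write_text(content_id + "\n")
    return content_id
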
Example #3
  def GetContentFileRoot(self, config: corpus_pb2.Corpus) -> pathlib.Path:
    """Get the path of the directory containing content files.

    If the corpus is a local directory, this simply returns the path. Otherwise,
    this method creates a temporary copy of the files which can be used within
    the scope of this context.

    Args:
      config: The corpus config proto.

    Returns:
      The path of a directory containing content files.
    """
    if config.HasField("local_directory"):
      yield pathlib.Path(ExpandConfigPath(config.local_directory))
    elif config.HasField("local_tar_archive"):
      with tempfile.TemporaryDirectory(prefix="clgen_corpus_") as d:
        start_time = time.time()
        cmd = [
          "tar",
          "-xf",
          str(ExpandConfigPath(config.local_tar_archive)),
          "-C",
          d,
        ]
        subprocess.check_call(cmd)
        app.Log(
          1,
          "Unpacked %s in %s ms",
          ExpandConfigPath(config.local_tar_archive).name,
          humanize.Commas(int((time.time() - start_time) * 1000)),
        )
        yield pathlib.Path(d)
    else:
      raise NotImplementedError
Example #4
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
    """Assert that config proto is valid.

  Args:
    config: A Corpus proto.

  Returns:
    The Corpus proto.

  Raises:
    UserError: If the config is invalid.
  """
    try:
        pbutil.AssertFieldIsSet(config, 'contentfiles')
        pbutil.AssertFieldIsSet(config, 'atomizer')
        pbutil.AssertFieldIsSet(config, 'contentfile_separator')
        # Check that the preprocessor pipeline resolves to preprocessor functions.
        [preprocessors.GetPreprocessorFunction(p) for p in config.preprocessor]

        if config.HasField('greedy_multichar_atomizer'):
            if not config.greedy_multichar_atomizer.tokens:
                raise errors.UserError(
                    'GreedyMulticharAtomizer.tokens is empty')
            for atom in config.greedy_multichar_atomizer.tokens:
                if not atom:
                    raise errors.UserError(
                        'Empty string found in GreedyMulticharAtomizer.tokens')

        return config
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
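
Note: a reduced sketch of the token validation above, using a plain exception in place of `errors.UserError` (names here are illustrative, not the project's API):

class UserError(ValueError):
    """Raised when user-supplied configuration is invalid."""

def assert_atomizer_tokens_valid(tokens):
    # Mirrors the checks on GreedyMulticharAtomizer.tokens above.
    if not tokens:
        raise UserError("GreedyMulticharAtomizer.tokens is empty")
    for atom in tokens:
        if not atom:
            raise UserError("Empty string found in GreedyMulticharAtomizer.tokens")
    return tokens
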
Example #5
  def GetContentFileRoot(self, config: corpus_pb2.Corpus) -> pathlib.Path:
    """Get the path of the directory containing content files.

    If the corpus is a local directory, this simply returns the path. Otherwise,
    this method creates a temporary copy of the files which can be used within
    the scope of this context.

    Args:
      config: The corpus config proto.

    Returns:
      The path of a directory containing content files.
    """
    if config.HasField("local_directory"):
      yield pathlib.Path(ExpandConfigPath(config.local_directory))
    elif config.HasField("local_tar_archive"):
      with tempfile.TemporaryDirectory(prefix="clgen_corpus_", dir = FLAGS.local_filesystem) as d:
        l.logger().info("Unpacking {}...".format(ExpandConfigPath(config.local_tar_archive).name))
        start_time = time.time()
        if environment.WORLD_RANK == 0:
          cmd = [
            "tar",
            "-xf",
            str(ExpandConfigPath(config.local_tar_archive)),
            "-C",
            d,
          ]
          subprocess.check_call(cmd)
        distrib.barrier()
        l.logger().info(
          "Unpacked {} in {} ms".format(
                  ExpandConfigPath(config.local_tar_archive).name,
                  humanize.intcomma(int((time.time() - start_time) * 1000)),
              )
        )
        yield pathlib.Path(d)
    elif config.HasField("bq_database"):
      input_bq = pathlib.Path(ExpandConfigPath(config.bq_database))
      if environment.WORLD_SIZE > 1:
        target_bq = self.replicated_path.parent / "bq_database_replica_{}.db".format(environment.WORLD_RANK)
        if not target_bq.exists():
          shutil.copy(input_bq, target_bq)
        yield target_bq
      else:
        yield input_bq
    else:
      raise NotImplementedError
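
Note: this variant lets only rank 0 unpack the archive while the other ranks wait at a barrier. A minimal sketch of that idiom, written here against torch.distributed as an assumed backend rather than the project's `environment`/`distrib` wrappers:

import torch.distributed as dist

def unpack_on_rank_zero(rank: int, do_unpack) -> None:
    # Only one process performs the extraction ...
    if rank == 0:
        do_unpack()
    # ... and every process waits here until the files exist.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
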
Example #6
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
    """Assert that config proto is valid.

  Args:
    config: A Corpus proto.

  Returns:
    The Corpus proto.

  Raises:
    UserError: If the config is invalid.
  """
    try:
        # Early-exit to support corpuses derived from databases of pre-encoded
        # content files.
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor after splitting
        # Corpus class.
        if config.HasField("pre_encoded_corpus_url"):
            return config

        pbutil.AssertFieldIsSet(config, "contentfiles")
        pbutil.AssertFieldIsSet(config, "atomizer")
        pbutil.AssertFieldIsSet(config, "contentfile_separator")
        # Check that the preprocessor pipeline resolves to preprocessor functions.
        [preprocessors.GetPreprocessorFunction(p) for p in config.preprocessor]

        if config.HasField("greedy_multichar_atomizer"):
            if not config.greedy_multichar_atomizer.tokens:
                raise errors.UserError(
                    "GreedyMulticharAtomizer.tokens is empty")
            for atom in config.greedy_multichar_atomizer.tokens:
                if not atom:
                    raise errors.UserError(
                        "Empty string found in GreedyMulticharAtomizer.tokens")

        return config
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
Example #7
def ResolveContentId(config: corpus_pb2.Corpus,
                     hc: typing.Optional[hashcache.HashCache] = None) -> str:
    """Compute the hash of the input contentfiles.

  This function resolves the unique sha1 checksum of a set of content files.

  Args:
    config: The corpus config proto.
    hc: A hashcache database instance, used for resolving directory hashes. If
      the corpus has pre_encoded_corpus_url field set, this may be omitted.

  Returns:
    A hex encoded sha1 string.
  """
    # We can take a massive shortcut if the content ID is already set in the
    # config proto.
    if config.HasField("content_id"):
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this after splitting
        # out Corpus class.
        return config.content_id
    elif config.HasField("pre_encoded_corpus_url"):
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this after splitting
        # out Corpus class.
        return crypto.sha1_str(config.pre_encoded_corpus_url)

    start_time = time.time()
    if config.HasField("local_directory"):
        local_directory = ExpandConfigPath(
            config.local_directory, path_prefix=FLAGS.clgen_local_path_prefix)

        # After the first time we compute the hash of a directory, we write it into
        # a file. This is a shortcut to work around the fact that computing the
        # directory checksum is O(n) with respect to the number of files in the
        # directory (even if the directory is already cached by the hash cache).
        # This means that it is the responsibility of the user to delete this cached
        # file if the directory is changed.
        hash_file_path = pathlib.Path(str(local_directory) + ".sha1.txt")
        if hash_file_path.is_file():
            app.Log(1, "Reading directory hash: '%s'.", hash_file_path)
            with open(hash_file_path) as f:
                content_id = f.read().rstrip()
        else:
            # No hash file, so compute the directory hash and create it.
            try:
                content_id = hc.GetHash(local_directory)
            except FileNotFoundError as e:
                raise errors.UserError(e)
            # Create the hash file in the directory so that next time we don't need
            # to reference the hash cache.
            with open(hash_file_path, "w") as f:
                print(content_id, file=f)
            app.Log(1, "Wrote directory hash: '%s'.", hash_file_path)
    elif config.HasField("local_tar_archive"):
        # This is not an efficient means of getting the hash, as it requires always
        # unpacking the archive and reading the entire contents. It would be nicer
        # to maintain a cache which maps the mtime of tarballs to their content ID,
        # similar to how local_directory is implemented.
        content_id = GetHashOfArchiveContents(
            ExpandConfigPath(config.local_tar_archive,
                             path_prefix=FLAGS.clgen_local_path_prefix))
    else:
        raise NotImplementedError(
            "Unsupported Corpus.contentfiles field value")
    app.Log(
        2,
        "Resolved Content ID %s in %s ms.",
        content_id,
        humanize.Commas(int((time.time() - start_time) * 1000)),
    )
    return content_id
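
Note: `crypto.sha1_str` above is assumed to hash a UTF-8 string and return a hex digest; a minimal stand-in would be:

import hashlib

def sha1_str(text: str) -> str:
    # Hex-encoded SHA-1 of the UTF-8 encoding of `text`.
    return hashlib.sha1(text.encode("utf-8")).hexdigest()
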
Example #8
    def __init__(self, config: corpus_pb2.Corpus):
        """Instantiate a corpus from a proto config.

    If this is a new corpus, a number of files will be created, which may
    take some time.

    Args:
      config: A Corpus message.

    Raises:
      TypeError: If the config argument is not a Corpus proto.
      UserError: In case the corpus is not found, or config contains invalid
        options.
      EmptyCorpusException: In case the corpus contains no data.
    """
        if not isinstance(config, corpus_pb2.Corpus):
            t = type(config).__name__
            raise TypeError(f"Config must be a Corpus proto. Received: '{t}'")

        # Make a local copy of the configuration.
        self.config = corpus_pb2.Corpus()
        self.config.CopyFrom(AssertConfigIsValid(config))
        self._atomizer = None
        self._created = False
        self.dashboard_db = dashboard_db.GetDatabase()
        self._dashboard_db_id: typing.Optional[int] = None  # Set in Create()

        # An in-memory cache of the encoded contentfiles indices arrays.
        # Set and used in GetTrainingData().
        self._indices_arrays: typing.Optional[typing.List[np.array]] = None

        cache.cachepath("corpus").mkdir(parents=True, exist_ok=True)
        hc = hashcache.HashCache(cache.cachepath("hashcache.db"), "sha1")
        self.content_id = ResolveContentId(self.config, hc)
        # Database of pre-processed files.
        preprocessed_id = ResolvePreprocessedId(self.content_id, self.config)
        cache.cachepath("corpus", "preprocessed",
                        preprocessed_id).mkdir(exist_ok=True, parents=True)
        preprocessed_db_path = cache.cachepath("corpus", "preprocessed",
                                               preprocessed_id,
                                               "preprocessed.db")
        if (self.config.HasField("content_id")
                and not preprocessed_db_path.is_file()):
            raise errors.UserError(
                f"Content ID not found: '{self.content_id}'")
        self.preprocessed = preprocessed.PreprocessedContentFiles(
            f"sqlite:///{preprocessed_db_path}")
        # Create symlink to contentfiles.
        symlink = (
            pathlib.Path(self.preprocessed.url[len("sqlite:///"):]).parent /
            "contentfiles")
        if not symlink.is_symlink():
            if config.HasField("local_directory"):
                os.symlink(
                    str(
                        ExpandConfigPath(
                            config.local_directory,
                            path_prefix=FLAGS.clgen_local_path_prefix)),
                    symlink,
                )
            elif config.HasField("local_tar_archive"):
                os.symlink(
                    str(
                        ExpandConfigPath(
                            config.local_tar_archive,
                            path_prefix=FLAGS.clgen_local_path_prefix,
                        )),
                    symlink,
                )
        # Database of encoded preprocessed files.
        encoded_id = ResolveEncodedId(self.content_id, self.config)
        cache.cachepath("corpus", "encoded", encoded_id).mkdir(exist_ok=True,
                                                               parents=True)
        db_path = cache.cachepath("corpus", "encoded", encoded_id,
                                  "encoded.db")
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this conditional
        # logic by making Corpus an abstract class and creating concrete subclasses
        # for the different types of corpus.
        if self.config.HasField("pre_encoded_corpus_url"):
            self.encoded = encoded.EncodedContentFiles(
                config.pre_encoded_corpus_url)
        else:
            self.encoded = encoded.EncodedContentFiles(f"sqlite:///{db_path}")
        self.atomizer_path = cache.cachepath("corpus", "encoded", encoded_id,
                                             "atomizer.pkl")
        # Create symlink to preprocessed files.
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this conditional
        # logic after splitting Corpus class.
        if not self.config.HasField("pre_encoded_corpus_url"):
            symlink = (
                pathlib.Path(self.encoded.url[len("sqlite:///"):]).parent /
                "preprocessed")
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(
                        pathlib.Path(
                            self.preprocessed.url[len("sqlite:///"):]).parent,
                        pathlib.Path(
                            self.encoded.url[len("sqlite:///"):]).parent,
                    ),
                    symlink,
                )
        self.hash = encoded_id
        self.cache = cache.mkcache("corpus", "encoded", encoded_id)
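
Note: the constructor above links the encoded cache back to the preprocessed cache with a relative symlink; a standalone sketch of that step (directory arguments are hypothetical):

import os
import pathlib

def link_preprocessed(encoded_dir: pathlib.Path, preprocessed_dir: pathlib.Path) -> None:
    # Create an encoded_dir/preprocessed link pointing at the preprocessed
    # cache via a relative path, skipping the call if the link already exists.
    symlink = encoded_dir / "preprocessed"
    if not symlink.is_symlink():
        os.symlink(os.path.relpath(preprocessed_dir, encoded_dir), symlink)
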
Example #9
    def __init__(self, config: corpus_pb2.Corpus):
        """Instantiate a corpus from a proto config.

    If this is a new corpus, a number of files will be created, which may
    take some time.

    Args:
      config: A Corpus message.

    Raises:
      TypeError: If the config argument is not a Corpus proto.
      UserError: In case the corpus is not found, or config contains invalid
        options.
      EmptyCorpusException: In case the corpus contains no data.
    """
        if not isinstance(config, corpus_pb2.Corpus):
            t = type(config).__name__
            raise TypeError(f"Config must be a Corpus proto. Received: '{t}'")

        # Make a local copy of the configuration.
        self.config = corpus_pb2.Corpus()
        self.config.CopyFrom(AssertConfigIsValid(config))
        self._atomizer = None
        self._created = False

        cache.cachepath('corpus').mkdir(parents=True, exist_ok=True)
        hc = hashcache.HashCache(cache.cachepath('hashcache.db'), 'sha1')
        self.content_id = ResolveContentId(self.config, hc)
        # Database of pre-processed files.
        preprocessed_id = ResolvePreprocessedId(self.content_id, self.config)
        cache.cachepath('corpus', 'preprocessed',
                        preprocessed_id).mkdir(exist_ok=True, parents=True)
        preprocessed_db_path = cache.cachepath('corpus', 'preprocessed',
                                               preprocessed_id,
                                               'preprocessed.db')
        if (self.config.HasField('content_id')
                and not preprocessed_db_path.is_file()):
            raise errors.UserError(
                f"Content ID not found: '{self.content_id}'")
        self.preprocessed = preprocessed.PreprocessedContentFiles(
            preprocessed_db_path)
        # Create symlink to contentfiles.
        symlink = self.preprocessed.database_path.parent / 'contentfiles'
        if not symlink.is_symlink():
            if config.HasField('local_directory'):
                os.symlink(
                    str(
                        ExpandConfigPath(
                            config.local_directory,
                            path_prefix=FLAGS.clgen_local_path_prefix)),
                    symlink)
            elif config.HasField('local_tar_archive'):
                os.symlink(
                    str(
                        ExpandConfigPath(
                            config.local_tar_archive,
                            path_prefix=FLAGS.clgen_local_path_prefix)),
                    symlink)
        # Database of encoded preprocessed files.
        encoded_id = ResolveEncodedId(self.content_id, self.config)
        cache.cachepath('corpus', 'encoded', encoded_id).mkdir(exist_ok=True,
                                                               parents=True)
        self.encoded = encoded.EncodedContentFiles(
            cache.cachepath('corpus', 'encoded', encoded_id, 'encoded.db'))
        self.atomizer_path = cache.cachepath('corpus', 'encoded', encoded_id,
                                             'atomizer.pkl')
        # Create symlink to preprocessed files.
        symlink = self.encoded.database_path.parent / 'preprocessed'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(self.preprocessed.database_path.parent,
                                self.encoded.database_path.parent), symlink)
        self.hash = encoded_id
        self.cache = cache.mkcache('corpus', 'encoded', encoded_id)
Example #10
  def Import(self, session: sqlutil.Session, config: corpus_pb2.Corpus) -> None:
    with self.GetContentFileRoot(config) as contentfile_root:
      if not config.HasField("bq_database"):
        if environment.WORLD_RANK == 0:
          relpaths = set(self.GetImportRelpaths(contentfile_root))
          done = set(
            [x[0] for x in session.query(PreprocessedContentFile.input_relpath)]
          )
          todo = relpaths - done
          l.logger().info(
            "Preprocessing {} of {} content files".format(
                    humanize.intcomma(len(todo)),
                    humanize.intcomma(len(relpaths)),
                )
          )
          chunk_size = 100000
          jobs, total = [], 0
          for idx, t in enumerate(todo):
            if idx % chunk_size == 0:
              jobs.append([t])
            else:
              jobs[-1].append(t)
            total += 1
          bar = tqdm.tqdm(total = total, desc = "Preprocessing", leave = True)
          c = 0
          last_commit = time.time()
          wall_time_start = time.time()
          for job_chunk in jobs:
            try:
              pool = multiprocessing.Pool()
              for preprocessed_list in pool.imap_unordered(
                                         functools.partial(
                                           PreprocessorWorker,
                                           contentfile_root = contentfile_root,
                                           preprocessors = list(config.preprocessor)
                                         ),
                                         job_chunk
                                        ):
                for preprocessed_cf in preprocessed_list:
                  wall_time_end = time.time()
                  preprocessed_cf.wall_time_ms = int(
                    (wall_time_end - wall_time_start) * 1000
                  )
                  wall_time_start = wall_time_end
                  session.add(preprocessed_cf)
                  if wall_time_end - last_commit > 10:
                    session.commit()
                    last_commit = wall_time_end
                c += 1
                bar.update(1)
              pool.close()
            except KeyboardInterrupt as e:
              pool.terminate()
              raise e
            except Exception as e:
              pool.terminate()
              raise e
          session.commit()
      else:
        db  = bqdb.bqDatabase("sqlite:///{}".format(contentfile_root), must_exist = True)
        total = db.mainfile_count                        # Total number of files in BQ database.
        total_per_node = total // environment.WORLD_SIZE # In distributed nodes, this is the total files to be processed per node.
        if total == 0:
          raise ValueError("Input BQ database {} is empty!".format(contentfile_root))

        # Set of IDs that have been completed.
        done = set(
          [x[0].replace("main_files/", "") for x in session.query(PreprocessedContentFile.input_relpath)]
        )

        chunk, idx = min(total_per_node, 100000), environment.WORLD_RANK * total_per_node
        limit = (environment.WORLD_RANK + 1) * total_per_node + (total % total_per_node if environment.WORLD_RANK == environment.WORLD_SIZE - 1 else 0)

        if environment.WORLD_SIZE > 1:
          bar = distrib.ProgressBar(total = total, offset = idx, decs = "Preprocessing DB")
        else:
          bar = tqdm.tqdm(total = total, desc = "Preprocessing DB", leave = True)

        last_commit     = time.time()
        wall_time_start = time.time()

        while idx < limit:
          try:
            chunk = min(chunk, limit - idx) # Clamp the chunk so the last node, which also receives the remainder, does not read past its limit.
            batch = db.main_files_batch(chunk, idx, exclude_id = done)
            idx += chunk - len(batch) # This difference will be the number of already done files.
            pool = multiprocessing.Pool()
            for preprocessed_list in pool.imap_unordered(
                                      functools.partial(
                                        BQPreprocessorWorker,
                                        preprocessors = list(config.preprocessor)
                                    ), batch):
              for preprocessed_cf in preprocessed_list:
                wall_time_end = time.time()
                preprocessed_cf.wall_time_ms = int(
                  (wall_time_end - wall_time_start) * 1000
                )
                wall_time_start = wall_time_end
                session.add(preprocessed_cf)
                if wall_time_end - last_commit > 10:
                  session.commit()
                  last_commit = wall_time_end
              idx += 1
              bar.update(idx - bar.n)
            pool.close()
          except KeyboardInterrupt as e:
            pool.terminate()
            raise e
          except Exception as e:
            l.logger().error(e, ddp_nodes = True)
            pool.terminate()
            raise e
        session.commit()
        if environment.WORLD_SIZE > 1:
          bar.finalize(idx)
    return
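
Note: a reduced sketch of the chunk-then-pool pattern used in Import above, with a hypothetical worker standing in for PreprocessorWorker:

import functools
import multiprocessing

def chunk_jobs(todo, chunk_size=100000):
    # Split the work items into lists of at most chunk_size entries.
    jobs = []
    for idx, item in enumerate(todo):
        if idx % chunk_size == 0:
            jobs.append([item])
        else:
            jobs[-1].append(item)
    return jobs

def worker(item, prefix):
    # Stand-in for PreprocessorWorker: returns a list of results per item.
    return ["{}:{}".format(prefix, item)]

def preprocess_all(todo, prefix="preprocessed"):
    results = []
    for job_chunk in chunk_jobs(todo):
        with multiprocessing.Pool() as pool:
            fn = functools.partial(worker, prefix=prefix)
            for out in pool.imap_unordered(fn, job_chunk):
                results.extend(out)
    return results
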