Code example #1
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
    """Assert that config proto is valid.

  Args:
    config: A Corpus proto.

  Returns:
    The Corpus proto.

  Raises:
    UserError: If the config is invalid.
  """
    try:
        pbutil.AssertFieldIsSet(config, 'contentfiles')
        pbutil.AssertFieldIsSet(config, 'atomizer')
        pbutil.AssertFieldIsSet(config, 'contentfile_separator')
        # Check that the preprocessor pipeline resolves to preprocessor functions.
        [preprocessors.GetPreprocessorFunction(p) for p in config.preprocessor]

        if config.HasField('greedy_multichar_atomizer'):
            if not config.greedy_multichar_atomizer.tokens:
                raise errors.UserError(
                    'GreedyMulticharAtomizer.tokens is empty')
            for atom in config.greedy_multichar_atomizer.tokens:
                if not atom:
                    raise errors.UserError(
                        'Empty string found in GreedyMulticharAtomizer.tokens')

        return config
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
Code example #2
def GetPreprocessorFunction(name: str) -> public.PreprocessorFunction:
    """Lookup a preprocess function by name.

  A preprocessor is a function which takes a single argument 'text' of type str,
  and returns a str. The name is the fully qualified name of the python
  function which implements it, in the form <module>:<name>. For example,
  the name 'deeplearning.clgen.preprocessors.cxx:Compile' will return the
  function 'Compile' in the module 'deeplearning.clgen.preprocessors.cxx'.

  Args:
    name: The name of the preprocessor to get.

  Returns:
    The python preprocessor function.

  Raises:
    UserError: If the requested name cannot be found or is not a
      @clgen_preprocessor decorated function.
  """
    components = name.split(':')
    if len(components) != 2:
        raise errors.UserError(f'Invalid preprocessor name {name}')
    module_name, function_name = components
    try:
        module = importlib.import_module(module_name)
        function_ = getattr(module, function_name)
    except (ModuleNotFoundError, AttributeError):
        raise errors.UserError(f'Preprocessor {name} not found.')
    if not function_.__dict__.get('is_clgen_preprocessor'):
        raise errors.UserError(
            f'Preprocessor {name} not decorated with @clgen_preprocessor')
    return function_
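The lookup above reduces to splitting the name on ':' and importing the module dynamically. A minimal, self-contained sketch of the same technique, using 'os.path:join' purely as an illustrative target:

import importlib

def lookup(name: str):
    # Split '<module>:<name>' into its two components and resolve them.
    module_name, function_name = name.split(':', 1)
    module = importlib.import_module(module_name)
    return getattr(module, function_name)

print(lookup('os.path:join')('foo', 'bar'))  # 'foo/bar' on POSIX systems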
Code example #3
File: samplers.py Project: SpringRi/phd
    def __init__(self, config: sampler_pb2.SymmetricalTokenDepth):
        try:
            self.left_token = pbutil.AssertFieldConstraint(
                config, 'depth_increase_token', lambda s: len(s) > 0,
                'SymmetricalTokenDepth.depth_increase_token must be a non-empty string')
            self.right_token = pbutil.AssertFieldConstraint(
                config, 'depth_decrease_token', lambda s: len(s) > 0,
                'SymmetricalTokenDepth.depth_decrease_token must be a non-empty string')
        except pbutil.ProtoValueError as e:
            raise errors.UserError(e)
        if self.left_token == self.right_token:
            raise errors.UserError(
                'SymmetricalTokenDepth tokens must be different')
Code example #4
def _ImportPreprocessorFromModule(module_name: str, function_name: str):
    """Import module from a fully qualified module name, e.g. 'foo.bar'."""
    try:
        module = importlib.import_module(module_name)
    except (ModuleNotFoundError, AttributeError):
        raise errors.UserError(f'Module {module_name} not found.')
    if not hasattr(module, function_name):
        raise errors.UserError(
            f'Function {function_name} not found in module {module_name}')
    function_ = getattr(module, function_name)
    if not function_.__dict__.get('is_clgen_preprocessor'):
        raise errors.UserError(
            f'Preprocessor {function_name} not decorated with @clgen_preprocessor'
        )
    return function_
Code example #5
def _ImportPreprocessorFromFile(module_path: pathlib.Path, function_name: str):
    """Import module from an absolute path to file, e.g. '/foo/bar.py'."""
    if not module_path.is_file():
        raise errors.UserError(f"File not found: {module_path}")
    try:
        spec = importlib_util.spec_from_file_location("module",
                                                      str(module_path))
        module = importlib_util.module_from_spec(spec)
        spec.loader.exec_module(module)
    except ImportError as e:
        raise errors.UserError(f'Failed to import module {module_path}: {e}')
    if not hasattr(module, function_name):
        raise errors.UserError(
            f'Function {function_name} not found in module {module_path}')
    return getattr(module, function_name)
Code example #6
def GetPreprocessorFunction(name: str) -> public.PreprocessorFunction:
    """Lookup a preprocess function by name.

  A preprocessor is a function which takes a single argument 'text' of type str,
  and returns a str. The name is in the form <module>:<name>, where <name> is
  the name of a python function, and <module> is either a fully qualified module
  name, or an absolute path to the module file. For example, the name
  'deeplearning.clgen.preprocessors.cxx:Compile' will return the function
  'Compile' in the module 'deeplearning.clgen.preprocessors.cxx'. The name
  '/tmp/my_preprocessors.py:Transform' will return the function Transform() in
  the module defined at '/tmp/my_preprocessors.py'.

  Args:
    name: The name of the preprocessor to get.

  Returns:
    The python preprocessor function.

  Raises:
    UserError: If the requested name cannot be found or is not a
      @clgen_preprocessor decorated function.
  """
    components = name.split(':')
    if len(components) != 2:
        raise errors.UserError(f'Invalid preprocessor name {name}')
    module_name, function_name = components
    if module_name[0] == '/':
        return _ImportPreprocessorFromFile(pathlib.Path(module_name),
                                           function_name)
    else:
        return _ImportPreprocessorFromModule(module_name, function_name)
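For the absolute-path form, the work is done by importlib's util module (aliased as importlib_util in the examples above). A standalone sketch of that branch, assuming only that the target file defines the named function:

import importlib.util
import pathlib

def load_function_from_file(module_path: pathlib.Path, function_name: str):
    # Build a module object from an arbitrary .py file, execute it, and
    # return one attribute from it.
    spec = importlib.util.spec_from_file_location('user_preprocessors',
                                                  str(module_path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, function_name)

# Hypothetical usage:
#   transform = load_function_from_file(
#       pathlib.Path('/tmp/my_preprocessors.py'), 'Transform')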
Code example #7
File: samplers.py Project: SpringRi/phd
    def __init__(self, config: sampler_pb2.MaxTokenLength):
        try:
            self.max_len = pbutil.AssertFieldConstraint(
                config, 'maximum_tokens_in_sample', lambda x: x > 1,
                'MaxTokenLength.maximum_tokens_in_sample must be > 1')
        except pbutil.ProtoValueError as e:
            raise errors.UserError(e)
Code example #8
File: samplers.py Project: zhangheyu518/clgen
def AssertConfigIsValid(config: sampler_pb2.Sampler) -> sampler_pb2.Sampler:
    """Assert that a sampler configuration contains no invalid values.

  Args:
    config: A sampler configuration proto.

  Returns:
    The sampler configuration proto.

  Raises:
    UserError: If there are configuration errors.
  """
    try:
        pbutil.AssertFieldConstraint(config, 'start_text', lambda s: len(s),
                                     'Sampler.start_text must be a non-empty string')
        pbutil.AssertFieldConstraint(config, 'batch_size', lambda x: 0 < x,
                                     'Sampler.batch_size must be > 0')
        pbutil.AssertFieldConstraint(config, 'sequence_length',
                                     lambda x: 0 < x,
                                     'Sampler.sequence_length must be > 0')
        pbutil.AssertFieldConstraint(config, 'temperature_micros',
                                     lambda x: 0 < x,
                                     'Sampler.temperature_micros must be > 0')
        return config
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
Code example #9
    def Sample(
        self,
        sampler: samplers.Sampler,
        sample_observers: typing.List[sample_observers_lib.SampleObserver],
        seed: typing.Optional[int] = None,
    ) -> None:
        """Sample a model.

    This method uses the observer model, returning nothing. To access the
    samples produced, implement a SampleObserver and pass it in as an argument.
    Sampling continues indefinitely until one of the sample observers returns
    False when notified of a new sample.

    If the model is not already trained, calling Sample() first trains the
    model. Thus a call to Sample() is equivalent to calling Train() then
    Sample().

    Args:
      sampler: The sampler to sample using.
      sample_observers: A list of SampleObserver objects that are notified of
        new generated samples.
      seed: A numeric value to seed the RNG with. If not present, the RNG is
        seeded randomly.

    Raises:
      UserError: If called with no sample observers.
      UnableToAcquireLockError: If the model is locked (i.e. there is another
        process currently modifying the model).
      InvalidStartText: If the sampler start text cannot be encoded.
      InvalidSymtokTokens: If the sampler symmetrical depth tokens cannot be
        encoded.
    """
        if not sample_observers:
            raise errors.UserError("Cannot sample without any observers")

        sample_start_time = labdate.MillisecondsTimestamp()

        self.Train()

        with logutil.TeeLogsToFile(f"sampler_{sampler.hash}",
                                   self.cache.path / "logs"):
            app.Log(1, "Sampling: '%s'", sampler.start_text)

            atomizer = self.corpus.atomizer
            sampler.Specialize(atomizer)
            self.backend.InitSampling(sampler, seed)
            [obs.Specialize(self, sampler) for obs in sample_observers]

            batch_count = 1
            while self._SampleBatch(sampler, atomizer, sample_observers):
                batch_count += 1

            time_now = labdate.MillisecondsTimestamp()
            app.Log(
                1,
                "Produced %s sample batches at a rate of %s ms / batch.",
                humanize.Commas(batch_count),
                humanize.Commas(
                    int((time_now - sample_start_time) / max(batch_count, 1))),
            )
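To consume samples, callers implement a SampleObserver, as the docstring above describes. A hypothetical observer sketch, assuming the base class's OnSample() hook is invoked once per sample and a False return stops sampling:

class FirstNSamplesObserver(sample_observers_lib.SampleObserver):
    """Illustrative observer: stop sampling after n samples have been produced."""

    def __init__(self, n: int):
        self._remaining = n

    def OnSample(self, sample) -> bool:
        # Returning False tells Sample() to stop.
        self._remaining -= 1
        return self._remaining > 0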
Code example #10
File: sample.py Project: BeauJoh/phd
    def __init__(self, config: clgen_pb2.Instance):
        """Instantiate an instance.

    Args:
      config: An Instance proto.

    Raises:
      UserError: If the instance proto contains invalid values, or is missing
        model or sampler fields.
    """
        try:
            pbutil.AssertFieldIsSet(config, 'pretrained_model')
            pbutil.AssertFieldIsSet(config, 'sampler')
        except pbutil.ProtoValueError as e:
            raise errors.UserError(e)

        self.working_dir = None
        if config.HasField('working_dir'):
            self.working_dir: pathlib.Path = pathlib.Path(
                os.path.expandvars(
                    config.working_dir)).expanduser().absolute()
        # Enter a session so that the cache paths are set relative to any requested
        # working directory.
        with self.Session():
            self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
                pathlib.Path(config.pretrained_model))
            self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)
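The working_dir handling above chains three expansions. A tiny standalone sketch of that chain; the path literal is illustrative only:

import os
import pathlib

raw = '$HOME/clgen/working_dir'  # hypothetical config.working_dir value
working_dir = pathlib.Path(os.path.expandvars(raw)).expanduser().absolute()
# Environment variables are expanded first, then '~', then the path is made absolute.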
Code example #11
def ResolveContentId(config: corpus_pb2.Corpus,
                     hc: hashcache.HashCache) -> str:
    """Compute the hash of the input contentfiles.

  This function resolves the unique sha1 checksum of a set of content files.

  Args:
    config: The corpus config proto.
    hc: A hashcache database instance, used for resolving directory hashes.

  Returns:
    A hex encoded sha1 string.
  """
    # We can take a massive shortcut if the content ID is already set in the
    # config proto.
    if config.HasField('content_id'):
        return config.content_id

    start_time = time.time()
    if config.HasField('local_directory'):
        # After the first time we compute the hash of a directory, we write it into
        # a file. This is a shortcut to work around the fact that computing the
        # directory checksum is O(n) with respect to the number of files in the
        # directory (even if the directory is already cached by the hash cache).
        # This means that it is the responsibility of the user to delete this cached
        # file if the directory is changed.
        hash_file_path = pathlib.Path(
            str(pathlib.Path(config.local_directory)) + '.sha1.txt')
        if hash_file_path.is_file():
            logging.info("Reading directory hash: '%s'.", hash_file_path)
            with open(hash_file_path) as f:
                content_id = f.read().rstrip()
        else:
            # No hash file, so compute the directory hash and create it.
            try:
                content_id = hc.GetHash(
                    ExpandConfigPath(
                        config.local_directory,
                        path_prefix=FLAGS.clgen_local_path_prefix))
            except FileNotFoundError as e:
                raise errors.UserError(e)
            # Create the hash file in the directory so that next time we don't need
            # to reference the hash cache.
            with open(hash_file_path, 'w') as f:
                print(content_id, file=f)
            logging.info("Wrote directory hash: '%s'.", hash_file_path)
    elif config.HasField('local_tar_archive'):
        # This is not an efficient means of getting the hash, as it requires always
        # unpacking the archive and reading the entire contents. It would be nicer
        # to maintain a cache which maps the mtime of tarballs to their content ID,
        # similar to how local_directory is implemented.
        content_id = GetHashOfArchiveContents(
            ExpandConfigPath(config.local_tar_archive,
                             path_prefix=FLAGS.clgen_local_path_prefix))
    else:
        raise NotImplementedError(
            'Unsupported Corpus.contentfiles field value')
    logging.debug('Resolved Content ID %s in %s ms.', content_id,
                  humanize.intcomma(int((time.time() - start_time) * 1000)))
    return content_id
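The local_directory branch is a sidecar-file cache: the expensive directory hash is computed once, written next to the directory as '<dir>.sha1.txt', and read back on later runs. A self-contained sketch of that pattern, independent of the hashcache module (the hashing scheme here is illustrative, not hashcache's):

import hashlib
import pathlib

def cached_dir_hash(directory: pathlib.Path) -> str:
    sidecar = pathlib.Path(str(directory) + '.sha1.txt')
    if sidecar.is_file():
        # Fast path: reuse the previously computed hash.
        return sidecar.read_text().rstrip()
    digest = hashlib.sha1()
    for path in sorted(directory.rglob('*')):
        if path.is_file():
            digest.update(path.read_bytes())
    content_id = digest.hexdigest()
    sidecar.write_text(content_id + '\n')
    return content_id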
Code example #12
    def __init__(self, config: clgen_pb2.Instance, dashboard_opts={}):
        """Instantiate an instance.

    Args:
      config: An Instance proto.

    Raises:
      UserError: If the instance proto contains invalid values, or is missing
        model or sampler fields.
    """
        try:
            pbutil.AssertFieldIsSet(config, "model_specification")
            pbutil.AssertFieldIsSet(config, "sampler")
        except pbutil.ProtoValueError as e:
            raise errors.UserError(e)

        self.config = config
        self.working_dir = None
        if config.HasField("working_dir"):
            self.working_dir: pathlib.Path = pathlib.Path(
                os.path.expandvars(
                    config.working_dir)).expanduser().absolute()
        # Enter a session so that the cache paths are set relative to any requested
        # working directory.
        with self.Session():
            if config.HasField("model"):
                self.model: models.Model = models.Model(config.model)
            else:
                self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
                    pathlib.Path(config.pretrained_model))
            self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)

        self.dashboard = dashboard.Launch(**dashboard_opts)
Code example #13
File: data_generators.py Project: whatsmyname/clgen
def GetTrainingCorpus(
    corpus: "corpuses.Corpus", training_opts: model_pb2.TrainingOptions
) -> typing.Tuple[np.ndarray, np.ndarray, int]:
    """Get the corpus to train over.

  Args:
    corpus: A Corpus instance.
    training_opts: A TrainingOptions proto.

  Returns:
    An X, y pair of data for an epoch, and the number of steps in the epoch.

  Raises:
    UserError: If batch_size and sequence_length are too large for the corpus,
      yielding no batches.
  """
    start_time = time.time()
    encoded_corpus = corpus.GetTrainingData(
        shuffle=training_opts.shuffle_corpus_contentfiles_between_epochs)
    corpus_length = len(encoded_corpus)
    steps_per_epoch = (corpus_length - 1) // (training_opts.batch_size *
                                              training_opts.sequence_length)
    if not steps_per_epoch:
        raise errors.UserError(
            f"Requested batch size ({training_opts.batch_size}) and "
            f"sequence length ({training_opts.sequence_length}) are too large for "
            f"corpus of size {corpus_length}.")

    clipped_corpus_length = (steps_per_epoch * training_opts.batch_size *
                             training_opts.sequence_length)

    x = np.reshape(
        encoded_corpus[:clipped_corpus_length],
        [
            training_opts.batch_size,
            steps_per_epoch * training_opts.sequence_length
        ],
    )
    y = np.reshape(
        encoded_corpus[1:clipped_corpus_length + 1],
        [
            training_opts.batch_size,
            steps_per_epoch * training_opts.sequence_length
        ],
    )

    app.Log(
        1,
        "Encoded corpus of %s tokens (clipped last %s tokens) in %s ms.",
        humanize.Commas(clipped_corpus_length),
        humanize.Commas(corpus_length - clipped_corpus_length),
        humanize.Commas(int((time.time() - start_time) * 1000)),
    )
    return x, y, steps_per_epoch
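The clipping and reshaping above are easiest to follow with concrete numbers. A self-contained sketch using made-up sizes (batch_size=2, sequence_length=3, a 17-token corpus):

import numpy as np

corpus = np.arange(17)
batch_size, sequence_length = 2, 3
steps_per_epoch = (len(corpus) - 1) // (batch_size * sequence_length)   # 2
clipped_corpus_length = steps_per_epoch * batch_size * sequence_length  # 12
x = corpus[:clipped_corpus_length].reshape(
    batch_size, steps_per_epoch * sequence_length)
y = corpus[1:clipped_corpus_length + 1].reshape(
    batch_size, steps_per_epoch * sequence_length)
# y is x shifted by one token: each target is the next token after its input.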
Code example #14
File: corpuses.py Project: whatsmyname/clgen
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
    """Assert that config proto is valid.

  Args:
    config: A Corpus proto.

  Returns:
    The Corpus proto.

  Raises:
    UserError: If the config is invalid.
  """
    try:
        # Early-exit to support corpuses derived from databases of pre-encoded
        # content files.
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor after splitting
        # Corpus class.
        if config.HasField("pre_encoded_corpus_url"):
            return config

        pbutil.AssertFieldIsSet(config, "contentfiles")
        pbutil.AssertFieldIsSet(config, "atomizer")
        pbutil.AssertFieldIsSet(config, "contentfile_separator")
        # Check that the preprocessor pipeline resolves to preprocessor functions.
        [preprocessors.GetPreprocessorFunction(p) for p in config.preprocessor]

        if config.HasField("greedy_multichar_atomizer"):
            if not config.greedy_multichar_atomizer.tokens:
                raise errors.UserError(
                    "GreedyMulticharAtomizer.tokens is empty")
            for atom in config.greedy_multichar_atomizer.tokens:
                if not atom:
                    raise errors.UserError(
                        "Empty string found in GreedyMulticharAtomizer.tokens")

        return config
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
Code example #15
File: corpuses.py Project: whatsmyname/clgen
def GetHashOfArchiveContents(archive: pathlib.Path) -> str:
    """Compute the checksum of the contents of a directory.

  Args:
    archive: Path of the archive.

  Returns:
    Checksum of the archive.

  Raises:
    UserError: If the requested archive does not exist, or cannot be unpacked.
  """
    if not archive.is_file():
        raise errors.UserError(f"Archive not found: '{archive}'")

    with tempfile.TemporaryDirectory(prefix="clgen_corpus_") as d:
        cmd = ["tar", "-xf", str(archive), "-C", d]
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError:
            raise errors.UserError(f"Archive unpack failed: '{archive}'")
        return checksumdir.dirhash(d, "sha1")
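The same unpack-then-hash idea can be written without shelling out to tar. A hypothetical alternative sketch using the standard-library tarfile module; note it hashes file contents directly rather than delegating to checksumdir, so its digests are not comparable to the ones above:

import hashlib
import pathlib
import tarfile
import tempfile

def hash_of_archive_contents(archive: pathlib.Path) -> str:
    digest = hashlib.sha1()
    with tempfile.TemporaryDirectory(prefix='clgen_corpus_') as d:
        with tarfile.open(archive) as tar:
            tar.extractall(d)
        for path in sorted(pathlib.Path(d).rglob('*')):
            if path.is_file():
                digest.update(path.read_bytes())
    return digest.hexdigest()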
Code example #16
File: data_generators.py Project: whatsmyname/clgen
    def CreateBatches(self) -> None:
        start_time = time.time()

        # generate a kernel corpus
        self.i = 0
        if (self.encoded_corpus is None or
                self.training_opts.shuffle_corpus_contentfiles_between_epochs):
            self.encoded_corpus = self.corpus.GetTrainingData(
                shuffle=self.training_opts.
                shuffle_corpus_contentfiles_between_epochs)

        batch_size = self.training_opts.batch_size
        sequence_length = self.training_opts.sequence_length

        # set corpus size and number of batches
        self.num_batches = int(
            len(self.encoded_corpus) / (batch_size * sequence_length))
        if self.num_batches == 0:
            raise errors.UserError(
                "Not enough data. Use a smaller sequence_length and batch_size"
            )

        # split into batches
        clipped_corpus_length = self.num_batches * batch_size * sequence_length
        clipped_corpus = self.encoded_corpus[:clipped_corpus_length]
        xdata = clipped_corpus
        ydata = np.copy(clipped_corpus)

        # Wrap-around.
        ydata[:-1] = xdata[1:]
        ydata[-1] = xdata[0]
        self.batches = [
            DataBatch(x, y) for x, y in zip(
                np.split(xdata.reshape(batch_size, -1), self.num_batches, 1),
                np.split(ydata.reshape(batch_size, -1), self.num_batches, 1),
            )
        ]
        app.Log(
            1,
            "Encoded corpus of %s tokens (clipped last %s tokens) in %s ms.",
            humanize.Commas(clipped_corpus_length),
            humanize.Commas(len(self.encoded_corpus) - clipped_corpus_length),
            humanize.Commas(int((time.time() - start_time) * 1000)),
        )
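The wrap-around target construction above is a next-token shift with the first token recycled at the end. A tiny numpy sketch with illustrative values:

import numpy as np

xdata = np.arange(6)      # [0, 1, 2, 3, 4, 5]
ydata = np.copy(xdata)
ydata[:-1] = xdata[1:]    # shift left by one token
ydata[-1] = xdata[0]      # wrap the first token around to the end
# xdata -> [0 1 2 3 4 5], ydata -> [1 2 3 4 5 0]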
Code example #17
    @classmethod
    def FromText(cls, text: str, atoms: typing.Set[str]) -> "GreedyAtomizer":
        """Instantiate an atomizer from a corpus text.

    Args:
      text: Text corpus.
      atoms: A set of multi-character tokens.

    Returns:
      An atomizer instance.
    """
        if not atoms:
            raise errors.UserError("No atoms specified")

        # Instantiate a greedy atomizer using the full vocabulary.
        full_vocab = dict(zip(atoms, range(len(atoms))))
        c = GreedyAtomizer(full_vocab, determine_chars=True)
        # Derive the subset of the vocabulary required to encode the given text.
        tokens = sorted(list(set(c.TokenizeString(text))))
        vocab_subset = dict(zip(tokens, range(len(tokens))))
        end_time = labdate.MillisecondsTimestamp()
        # Return a new atomizer using the subset vocabulary.
        return GreedyAtomizer(vocab_subset)
Code example #18
File: models.py Project: zhangheyu518/clgen
    def __init__(self, config: model_pb2.Model):
        """Instantiate a model.

    Args:
      config: A Model message.

    Raises:
      TypeError: If the config argument is not a Model proto.
      UserError: In case of an invalid config.
    """
        # Error early, so that a cache isn't created.
        if not isinstance(config, model_pb2.Model):
            t = type(config).__name__
            raise TypeError(f"Config must be a Model proto. Received: '{t}'")
        # Validate config options.
        if config.training.sequence_length < 1:
            raise errors.UserError(
                'TrainingOptions.sequence_length must be >= 1')

        self.config = model_pb2.Model()
        self.config.CopyFrom(builders.AssertIsBuildable(config))
        self.corpus = corpuses.Corpus(config.corpus)
        self.hash = self._ComputeHash(self.corpus, self.config)
        self.cache = cache.mkcache('model', self.hash)
        # Create the necessary cache directories.
        (self.cache.path / 'checkpoints').mkdir(exist_ok=True)
        (self.cache.path / 'samples').mkdir(exist_ok=True)
        (self.cache.path / 'logs').mkdir(exist_ok=True)

        # Create symlink to encoded corpus.
        symlink = self.cache.path / 'corpus'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(
                    pathlib.Path(
                        self.corpus.encoded.url[len('sqlite:///'):]).parent,
                    self.cache.path), symlink)

        # Create symlink to the atomizer.
        symlink = self.cache.path / 'atomizer'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(self.corpus.atomizer_path, self.cache.path),
                symlink)

        # Validate metadata against cache.
        if self.cache.get('META.pbtxt'):
            cached_meta = pbutil.FromFile(
                pathlib.Path(self.cache['META.pbtxt']),
                internal_pb2.ModelMeta())
            # Exclude num_epochs and corpus location from metadata comparison.
            config_to_compare = model_pb2.Model()
            config_to_compare.CopyFrom(self.config)
            config_to_compare.corpus.ClearField('contentfiles')
            config_to_compare.training.ClearField('num_epochs')
            # These fields should have already been cleared, but we'll do it again
            # so that metadata comparisons don't fail when the cached meta schema
            # is updated.
            cached_to_compare = model_pb2.Model()
            cached_to_compare.CopyFrom(cached_meta.config)
            cached_to_compare.corpus.ClearField('contentfiles')
            cached_to_compare.training.ClearField('num_epochs')
            if config_to_compare != cached_to_compare:
                raise errors.InternalError('Metadata mismatch')
            self.meta = cached_meta
        else:
            self.meta = internal_pb2.ModelMeta()
            self.meta.config.CopyFrom(self.config)
            self._WriteMetafile()

        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW:
            tensorflow_backend.TensorFlowBackend,
            model_pb2.NetworkArchitecture.KERAS: keras_backend.KerasBackend,
        }[config.architecture.backend](self.config, self.cache, self.corpus)
Code example #19
File: corpuses.py Project: whatsmyname/clgen
def ResolveContentId(config: corpus_pb2.Corpus,
                     hc: typing.Optional[hashcache.HashCache] = None) -> str:
    """Compute the hash of the input contentfiles.

  This function resolves the unique sha1 checksum of a set of content files.

  Args:
    config: The corpus config proto.
    hc: A hashcache database instance, used for resolving directory hashes. If
      the corpus has pre_encoded_corpus_url field set, this may be omitted.

  Returns:
    A hex encoded sha1 string.
  """
    # We can take a massive shortcut if the content ID is already set in the
    # config proto.
    if config.HasField("content_id"):
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this after splitting
        # out Corpus class.
        return config.content_id
    elif config.HasField("pre_encoded_corpus_url"):
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this after splitting
        # out Corpus class.
        return crypto.sha1_str(config.pre_encoded_corpus_url)

    start_time = time.time()
    if config.HasField("local_directory"):
        local_directory = ExpandConfigPath(
            config.local_directory, path_prefix=FLAGS.clgen_local_path_prefix)

        # After the first time we compute the hash of a directory, we write it into
        # a file. This is a shortcut to work around the fact that computing the
        # directory checksum is O(n) with respect to the number of files in the
        # directory (even if the directory is already cached by the hash cache).
        # This means that it is the responsibility of the user to delete this cached
        # file if the directory is changed.
        hash_file_path = pathlib.Path(str(local_directory) + ".sha1.txt")
        if hash_file_path.is_file():
            app.Log(1, "Reading directory hash: '%s'.", hash_file_path)
            with open(hash_file_path) as f:
                content_id = f.read().rstrip()
        else:
            # No hash file, so compute the directory hash and create it.
            try:
                content_id = hc.GetHash(local_directory)
            except FileNotFoundError as e:
                raise errors.UserError(e)
            # Create the hash file in the directory so that next time we don't need
            # to reference the hash cache.
            with open(hash_file_path, "w") as f:
                print(content_id, file=f)
            app.Log(1, "Wrote directory hash: '%s'.", hash_file_path)
    elif config.HasField("local_tar_archive"):
        # This is not an efficient means of getting the hash, as it requires always
        # unpacking the archive and reading the entire contents. It would be nicer
        # to maintain a cache which maps the mtime of tarballs to their content ID,
        # similar to how local_directory is implemented.
        content_id = GetHashOfArchiveContents(
            ExpandConfigPath(config.local_tar_archive,
                             path_prefix=FLAGS.clgen_local_path_prefix))
    else:
        raise NotImplementedError(
            "Unsupported Corpus.contentfiles field value")
    app.Log(
        2,
        "Resolved Content ID %s in %s ms.",
        content_id,
        humanize.Commas(int((time.time() - start_time) * 1000)),
    )
    return content_id
Code example #20
File: corpuses.py Project: whatsmyname/clgen
    def __init__(self, config: corpus_pb2.Corpus):
        """Instantiate a corpus from a proto config.

    If this is a new corpus, a number of files will be created, which may
    take some time.

    Args:
      config: A Corpus message.

    Raises:
      TypeError: If the config argument is not a Corpus proto.
      UserError: In case the corpus is not found, or config contains invalid
        options.
      EmptyCorpusException: In case the corpus contains no data.
    """
        if not isinstance(config, corpus_pb2.Corpus):
            t = type(config).__name__
            raise TypeError(f"Config must be a Corpus proto. Received: '{t}'")

        # Make a local copy of the configuration.
        self.config = corpus_pb2.Corpus()
        self.config.CopyFrom(AssertConfigIsValid(config))
        self._atomizer = None
        self._created = False
        self.dashboard_db = dashboard_db.GetDatabase()
        self._dashboard_db_id: typing.Optional[int] = None  # Set in Create()

        # An in-memory cache of the encoded contentfiles indices arrays.
        # Set and used in GetTrainingData().
        self._indices_arrays: typing.Optional[typing.List[np.array]] = None

        cache.cachepath("corpus").mkdir(parents=True, exist_ok=True)
        hc = hashcache.HashCache(cache.cachepath("hashcache.db"), "sha1")
        self.content_id = ResolveContentId(self.config, hc)
        # Database of pre-processed files.
        preprocessed_id = ResolvePreprocessedId(self.content_id, self.config)
        cache.cachepath("corpus", "preprocessed",
                        preprocessed_id).mkdir(exist_ok=True, parents=True)
        preprocessed_db_path = cache.cachepath("corpus", "preprocessed",
                                               preprocessed_id,
                                               "preprocessed.db")
        if (self.config.HasField("content_id")
                and not preprocessed_db_path.is_file()):
            raise errors.UserError(
                f"Content ID not found: '{self.content_id}'")
        self.preprocessed = preprocessed.PreprocessedContentFiles(
            f"sqlite:///{preprocessed_db_path}")
        # Create symlink to contentfiles.
        symlink = (
            pathlib.Path(self.preprocessed.url[len("sqlite:///"):]).parent /
            "contentfiles")
        if not symlink.is_symlink():
            if config.HasField("local_directory"):
                os.symlink(
                    str(
                        ExpandConfigPath(
                            config.local_directory,
                            path_prefix=FLAGS.clgen_local_path_prefix)),
                    symlink,
                )
            elif config.HasField("local_tar_archive"):
                os.symlink(
                    str(
                        ExpandConfigPath(
                            config.local_tar_archive,
                            path_prefix=FLAGS.clgen_local_path_prefix,
                        )),
                    symlink,
                )
        # Database of encoded pre-processed files.
        encoded_id = ResolveEncodedId(self.content_id, self.config)
        cache.cachepath("corpus", "encoded", encoded_id).mkdir(exist_ok=True,
                                                               parents=True)
        db_path = cache.cachepath("corpus", "encoded", encoded_id,
                                  "encoded.db")
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this conditional
        # logic by making Corpus an abstract class and creating concrete subclasses
        # for the different types of corpus.
        if self.config.HasField("pre_encoded_corpus_url"):
            self.encoded = encoded.EncodedContentFiles(
                config.pre_encoded_corpus_url)
        else:
            self.encoded = encoded.EncodedContentFiles(f"sqlite:///{db_path}")
        self.atomizer_path = cache.cachepath("corpus", "encoded", encoded_id,
                                             "atomizer.pkl")
        # Create symlink to preprocessed files.
        # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this conditional
        # logic after splitting Corpus class.
        if not self.config.HasField("pre_encoded_corpus_url"):
            symlink = (
                pathlib.Path(self.encoded.url[len("sqlite:///"):]).parent /
                "preprocessed")
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(
                        pathlib.Path(
                            self.preprocessed.url[len("sqlite:///"):]).parent,
                        pathlib.Path(
                            self.encoded.url[len("sqlite:///"):]).parent,
                    ),
                    symlink,
                )
        self.hash = encoded_id
        self.cache = cache.mkcache("corpus", "encoded", encoded_id)
Code example #21
    def __init__(self, config: corpus_pb2.Corpus):
        """Instantiate a corpus from a proto config.

    If this is a new corpus, a number of files will be created, which may
    take some time.

    Args:
      config: A Corpus message.

    Raises:
      TypeError: If the config argument is not a Corpus proto.
      UserError: In case the corpus is not found, or config contains invalid
        options.
      EmptyCorpusException: In case the corpus contains no data.
    """
        if not isinstance(config, corpus_pb2.Corpus):
            t = type(config).__name__
            raise TypeError(f"Config must be a Corpus proto. Received: '{t}'")

        # Make a local copy of the configuration.
        self.config = corpus_pb2.Corpus()
        self.config.CopyFrom(AssertConfigIsValid(config))
        self._atomizer = None
        self._created = False

        cache.cachepath('corpus').mkdir(parents=True, exist_ok=True)
        hc = hashcache.HashCache(cache.cachepath('hashcache.db'), 'sha1')
        self.content_id = ResolveContentId(self.config, hc)
        # Database of pre-processed files.
        preprocessed_id = ResolvePreprocessedId(self.content_id, self.config)
        cache.cachepath('corpus', 'preprocessed',
                        preprocessed_id).mkdir(exist_ok=True, parents=True)
        preprocessed_db_path = cache.cachepath('corpus', 'preprocessed',
                                               preprocessed_id,
                                               'preprocessed.db')
        if (self.config.HasField('content_id')
                and not preprocessed_db_path.is_file()):
            raise errors.UserError(
                f"Content ID not found: '{self.content_id}'")
        self.preprocessed = preprocessed.PreprocessedContentFiles(
            preprocessed_db_path)
        # Create symlink to contentfiles.
        symlink = self.preprocessed.database_path.parent / 'contentfiles'
        if not symlink.is_symlink():
            if config.HasField('local_directory'):
                os.symlink(
                    str(
                        ExpandConfigPath(
                            config.local_directory,
                            path_prefix=FLAGS.clgen_local_path_prefix)),
                    symlink)
            elif config.HasField('local_tar_archive'):
                os.symlink(
                    str(
                        ExpandConfigPath(
                            config.local_tar_archive,
                            path_prefix=FLAGS.clgen_local_path_prefix)),
                    symlink)
        # Database of encoded pre-processed files.
        encoded_id = ResolveEncodedId(self.content_id, self.config)
        cache.cachepath('corpus', 'encoded', encoded_id).mkdir(exist_ok=True,
                                                               parents=True)
        self.encoded = encoded.EncodedContentFiles(
            cache.cachepath('corpus', 'encoded', encoded_id, 'encoded.db'))
        self.atomizer_path = cache.cachepath('corpus', 'encoded', encoded_id,
                                             'atomizer.pkl')
        # Create symlink to preprocessed files.
        symlink = self.encoded.database_path.parent / 'preprocessed'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(self.preprocessed.database_path.parent,
                                self.encoded.database_path.parent), symlink)
        self.hash = encoded_id
        self.cache = cache.mkcache('corpus', 'encoded', encoded_id)
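Both Corpus constructors create relative symlinks between cache directories so the links survive a relocation of the cache root. A minimal standalone sketch of that pattern; the paths are illustrative:

import os
import pathlib

encoded_dir = pathlib.Path('/tmp/clgen_cache/corpus/encoded/abc')            # illustrative
preprocessed_dir = pathlib.Path('/tmp/clgen_cache/corpus/preprocessed/xyz')  # illustrative
encoded_dir.mkdir(parents=True, exist_ok=True)
preprocessed_dir.mkdir(parents=True, exist_ok=True)

symlink = encoded_dir / 'preprocessed'
if not symlink.is_symlink():
    # Store the target relative to the directory containing the link.
    os.symlink(os.path.relpath(preprocessed_dir, encoded_dir), symlink)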
Code example #22
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
  """Assert that a model configuration is buildable.

  Args:
    config: A model proto.

  Returns:
    The input model proto, unmodified.

  Raises:
    UserError: If the model is not buildable.
    InternalError: If the value of the training.optimizer field is not
      understood.
  """
  # Any change to the Model proto schema will require a change to this function.
  try:
    pbutil.AssertFieldIsSet(config, 'corpus')
    pbutil.AssertFieldIsSet(config, 'architecture')
    pbutil.AssertFieldIsSet(config, 'training')
    pbutil.AssertFieldIsSet(config.architecture, 'backend')
    pbutil.AssertFieldIsSet(config.architecture, 'neuron_type')
    if config.architecture.backend == model_pb2.NetworkArchitecture.KERAS:
      pbutil.AssertFieldConstraint(
          config.architecture, 'embedding_size', lambda x: 0 < x,
          'NetworkArchitecture.embedding_size must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'neurons_per_layer', lambda x: 0 < x,
        'NetworkArchitecture.neurons_per_layer must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'num_layers', lambda x: 0 < x,
        'NetworkArchitecture.num_layers must be > 0')
    pbutil.AssertFieldConstraint(
        config.architecture, 'post_layer_dropout_micros',
        lambda x: 0 <= x <= 1000000,
        'NetworkArchitecture.post_layer_dropout_micros '
        'must be >= 0 and <= 1000000')
    pbutil.AssertFieldConstraint(
        config.training, 'num_epochs', lambda x: 0 < x,
        'TrainingOptions.num_epochs must be > 0')
    pbutil.AssertFieldIsSet(
        config.training, 'shuffle_corpus_contentfiles_between_epochs')
    pbutil.AssertFieldConstraint(
        config.training, 'batch_size', lambda x: 0 < x,
        'TrainingOptions.batch_size must be > 0')
    pbutil.AssertFieldIsSet(config.training, 'optimizer')
    if config.training.HasField('adam_optimizer'):
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer, 'initial_learning_rate_micros',
          lambda x: 0 <= x,
          'AdamOptimizer.initial_learning_rate_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
          'AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'beta_1_micros', lambda x: 0 <= x <= 1000000,
          'AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'beta_2_micros', lambda x: 0 <= x <= 1000000,
          'AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000')
      pbutil.AssertFieldConstraint(
          config.training.adam_optimizer,
          'normalized_gradient_clip_micros', lambda x: 0 <= x,
          'AdamOptimizer.normalized_gradient_clip_micros must be >= 0')
    elif config.training.HasField('rmsprop_optimizer'):
      pbutil.AssertFieldConstraint(
          config.training.rmsprop_optimizer, 'initial_learning_rate_micros',
          lambda x: 0 <= x,
          'RmsPropOptimizer.initial_learning_rate_micros must be >= 0')
      pbutil.AssertFieldConstraint(
          config.training.rmsprop_optimizer,
          'learning_rate_decay_per_epoch_micros', lambda x: 0 <= x,
          'RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0')
    else:
      raise errors.InternalError(
          "Unrecognized value: 'TrainingOptions.optimizer'")
  except pbutil.ProtoValueError as e:
    raise errors.UserError(str(e))
  return config
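Every validator in these examples leans on the same pbutil helpers. A simplified, hypothetical stand-in for the field-constraint pattern; this is an illustration of the idea, not pbutil's actual implementation:

class ProtoValueError(ValueError):
    """Illustrative stand-in for pbutil.ProtoValueError."""

def assert_field_constraint(config, field: str, constraint, message: str):
    # Read the field and apply the caller-supplied predicate; on failure,
    # raise with the caller-supplied message. Callers wrap this in a
    # try/except and re-raise as a user-facing error, as in the examples above.
    value = getattr(config, field)
    if not constraint(value):
        raise ProtoValueError(message)
    return value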