def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
    """Validate a Corpus proto and hand it back unchanged.

    Checks that the 'contentfiles', 'atomizer' and 'contentfile_separator'
    fields are set, that every entry in the preprocessor pipeline resolves via
    preprocessors.GetPreprocessorFunction(), and, when a greedy multichar
    atomizer is configured, that its token list is non-empty and contains no
    empty strings.

    Args:
      config: A Corpus proto.

    Returns:
      The Corpus proto.

    Raises:
      UserError: If the config is invalid.
    """
    try:
        for required_field in ('contentfiles', 'atomizer',
                               'contentfile_separator'):
            pbutil.AssertFieldIsSet(config, required_field)
        # Resolving each name is the check itself; the returned functions are
        # discarded.
        for preprocessor_name in config.preprocessor:
            preprocessors.GetPreprocessorFunction(preprocessor_name)
        if config.HasField('greedy_multichar_atomizer'):
            multichar_tokens = config.greedy_multichar_atomizer.tokens
            if not multichar_tokens:
                raise errors.UserError(
                    'GreedyMulticharAtomizer.tokens is empty')
            if any(not token for token in multichar_tokens):
                raise errors.UserError(
                    'Empty string found in GreedyMulticharAtomizer.tokens is empty'
                )
        return config
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
def GetPreprocessorFunction(name: str) -> public.PreprocessorFunction:
    """Lookup a preprocess function by name.

    A preprocessor is a function which takes a single argument 'text' of type
    str, and returns a str. The name is the fully qualified name of the python
    function which implements it, in the form <module>:<name>. For example,
    the name 'deeplearning.clgen.preprocessors.cxx:Compile' will return the
    function 'Compile' in the module 'deeplearning.clgen.preprocessors.cxx'.

    Args:
      name: The name of the preprocessor to get.

    Returns:
      The python preprocessor function.

    Raises:
      UserError: If the name is not of the form <module>:<name>, if the
        requested name cannot be found, or if it is not a @clgen_preprocessor
        decorated function.
    """
    components = name.split(':')
    # Reject empty halves up front: importlib.import_module('') raises a
    # ValueError which would otherwise escape as a non-UserError.
    if len(components) != 2 or not all(components):
        raise errors.UserError(f'Invalid preprocessor name {name}')
    module_name, function_name = components
    try:
        module = importlib.import_module(module_name)
        function_ = getattr(module, function_name)
    except (ModuleNotFoundError, AttributeError) as e:
        raise errors.UserError(f'Preprocessor {name} not found.') from e
    # getattr() with a default instead of function_.__dict__: objects such as
    # builtins have no __dict__ and would raise AttributeError rather than the
    # documented UserError.
    if not getattr(function_, 'is_clgen_preprocessor', False):
        raise errors.UserError(
            f'Preprocessor {name} not decorated with @clgen_preprocessor')
    return function_
def __init__(self, config: sampler_pb2.SymmetricalTokenDepth):
    """Read and validate the depth-increase / depth-decrease token pair.

    Args:
      config: A SymmetricalTokenDepth proto.

    Raises:
      UserError: If either token fails its non-empty constraint, or if the two
        tokens are equal.
    """

    def _IsNonEmpty(s):
        return len(s) > 0

    try:
        self.left_token = pbutil.AssertFieldConstraint(
            config,
            'depth_increase_token',
            _IsNonEmpty,
            'SymmetricalTokenDepth.depth_increase_token must be a string',
        )
        self.right_token = pbutil.AssertFieldConstraint(
            config,
            'depth_decrease_token',
            _IsNonEmpty,
            'SymmetricalTokenDepth.depth_decrease_token must be a string',
        )
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
    tokens_are_distinct = self.left_token != self.right_token
    if not tokens_are_distinct:
        raise errors.UserError(
            'SymmetricalTokenDepth tokens must be different')
def _ImportPreprocessorFromModule(module_name: str, function_name: str):
    """Import a preprocessor from a fully qualified module name, e.g. 'foo.bar'.

    Args:
      module_name: The dotted name of the module to import.
      function_name: The name of the attribute to fetch from the module.

    Returns:
      The attribute named function_name of the imported module.

    Raises:
      UserError: If the module cannot be imported, if it has no attribute
        named function_name, or if that attribute does not carry a truthy
        'is_clgen_preprocessor' marker.
    """
    # importlib.import_module('') raises ValueError, which the handler below
    # does not cover; report it the same way as any other missing module.
    if not module_name:
        raise errors.UserError(f'Module {module_name} not found.')
    try:
        module = importlib.import_module(module_name)
    except (ModuleNotFoundError, AttributeError) as e:
        raise errors.UserError(f'Module {module_name} not found.') from e
    missing = object()
    function_ = getattr(module, function_name, missing)
    if function_ is missing:
        raise errors.UserError(
            f'Function {function_name} not found in module {module_name}')
    # getattr() with a default instead of function_.__dict__: objects such as
    # builtins have no __dict__ and would raise AttributeError rather than
    # UserError.
    if not getattr(function_, 'is_clgen_preprocessor', False):
        raise errors.UserError(
            f'Preprocessor {function_name} not decorated with @clgen_preprocessor'
        )
    return function_
def _ImportPreprocessorFromFile(module_path: pathlib.Path, function_name: str):
    """Import a function from a python file, e.g. '/foo/bar.py'.

    The file is executed as a module named "module" and the attribute named
    function_name is returned from it.

    NOTE(review): unlike _ImportPreprocessorFromModule(), this does not check
    the 'is_clgen_preprocessor' marker. Confirm whether that is intentional.

    Args:
      module_path: Path of the python file to load.
      function_name: The name of the attribute to fetch from the loaded module.

    Returns:
      The attribute named function_name of the loaded module.

    Raises:
      UserError: If the file does not exist, cannot be loaded as a python
        module, or has no attribute named function_name.
    """
    if not module_path.is_file():
        raise errors.UserError(f"File not found: {module_path}")
    spec = importlib_util.spec_from_file_location("module", str(module_path))
    # spec_from_file_location() returns None when no loader recognises the
    # file (e.g. an unexpected suffix); without this guard the next call fails
    # with an AttributeError instead of a UserError.
    if spec is None or spec.loader is None:
        raise errors.UserError(
            f'Failed to import module {module_path}: not a python module')
    try:
        module = importlib_util.module_from_spec(spec)
        spec.loader.exec_module(module)
    except (ImportError, SyntaxError) as e:
        raise errors.UserError(
            f'Failed to import module {module_path}: {e}') from e
    if not hasattr(module, function_name):
        raise errors.UserError(
            f'Function {function_name} not found in module {module_path}')
    return getattr(module, function_name)
def GetPreprocessorFunction(name: str) -> public.PreprocessorFunction:
    """Lookup a preprocess function by name.

    A preprocessor is a function which takes a single argument 'text' of type
    str, and returns a str. The name is in the form <module>:<name>, where
    <name> is the name of a python function, and <module> is either a fully
    qualified module name, or an absolute path to the module file. For
    example, the name 'deeplearning.clgen.preprocessors.cxx:Compile' will
    return the function 'Compile' in the module
    'deeplearning.clgen.preprocessors.cxx'. The name
    '/tmp/my_preprocessors.py:Transform' will return the function Transform()
    in the module defined at '/tmp/my_preprocessors.py'.

    Args:
      name: The name of the preprocessor to get.

    Returns:
      The python preprocessor function.

    Raises:
      UserError: If the name is not of the form <module>:<name>, if the
        requested name cannot be found, or if it is not a @clgen_preprocessor
        decorated function.
    """
    components = name.split(':')
    # An empty <module> or <name> half is malformed; rejecting it here also
    # avoids indexing into an empty module string below.
    if len(components) != 2 or not all(components):
        raise errors.UserError(f'Invalid preprocessor name {name}')
    module_name, function_name = components
    # A leading '/' marks an absolute file path rather than a dotted module.
    if module_name.startswith('/'):
        return _ImportPreprocessorFromFile(pathlib.Path(module_name),
                                           function_name)
    return _ImportPreprocessorFromModule(module_name, function_name)
def __init__(self, config: sampler_pb2.MaxTokenLength):
    """Read and validate the maximum sample length.

    The result of the field constraint check on 'maximum_tokens_in_sample' is
    stored as self.max_len (presumably the field value; confirm against
    pbutil.AssertFieldConstraint).

    Args:
      config: A MaxTokenLength proto.

    Raises:
      UserError: If maximum_tokens_in_sample is not greater than 1.
    """
    try:
        # The message now states the bound the predicate actually enforces
        # (x > 1); it previously claimed "> 0".
        self.max_len = pbutil.AssertFieldConstraint(
            config, 'maximum_tokens_in_sample', lambda x: x > 1,
            'MaxTokenLength.maximum_tokens_in_sample must be > 1')
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e) from e
def AssertConfigIsValid(config: sampler_pb2.Sampler) -> sampler_pb2.Sampler:
    """Check a sampler configuration for invalid values.

    The start text must have non-zero length, and batch_size, sequence_length
    and temperature_micros must each be greater than zero.

    Args:
      config: A sampler configuration proto.

    Returns:
      The sampler configuration proto.

    Raises:
      UserError: If there are configuration errors.
    """
    # Checked in this order, after start_text; the first failure wins.
    strictly_positive_fields = (
        ('batch_size', 'Sampler.batch_size must be > 0'),
        ('sequence_length', 'Sampler.sequence_length must be > 0'),
        ('temperature_micros', 'Sampler.temperature_micros must be > 0'),
    )
    try:
        pbutil.AssertFieldConstraint(config, 'start_text', len,
                                     'Sampler.start_text must be a string')
        for field_name, failure_message in strictly_positive_fields:
            pbutil.AssertFieldConstraint(config, field_name, lambda x: x > 0,
                                         failure_message)
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
    return config
def Sample(
    self,
    sampler: samplers.Sampler,
    sample_observers: typing.List[sample_observers_lib.SampleObserver],
    seed: int = None,
) -> None:
    """Sample a model.

    Results are delivered through the observer pattern and nothing is
    returned: to get at the samples, implement a SampleObserver and pass it
    in. Batches keep being sampled for as long as _SampleBatch() returns a
    truthy value.

    Train() is always called before sampling starts, so a call to Sample() is
    equivalent to calling Train() then Sample().

    Args:
      sampler: The sampler to sample using.
      sample_observers: A list of SampleObserver objects that are notified of
        new generated samples.
      seed: A numeric value to seed the RNG with. If not present, the RNG is
        seeded randomly.

    Raises:
      UserError: If called with no sample observers.
      UnableToAcquireLockError: If the model is locked (i.e. there is another
        process currently modifying the model).
      InvalidStartText: If the sampler start text cannot be encoded.
      InvalidSymtokTokens: If the sampler symmetrical depth tokens cannot be
        encoded.
    """
    if not sample_observers:
        raise errors.UserError("Cannot sample without any observers")

    sample_start_time = labdate.MillisecondsTimestamp()
    self.Train()

    log_dir = self.cache.path / "logs"
    with logutil.TeeLogsToFile(f"sampler_{sampler.hash}", log_dir):
        app.Log(1, "Sampling: '%s'", sampler.start_text)

        atomizer = self.corpus.atomizer
        sampler.Specialize(atomizer)
        self.backend.InitSampling(sampler, seed)
        for observer in sample_observers:
            observer.Specialize(self, sampler)

        # Starts at one because the call that finally returns falsy is
        # counted as a batch too.
        batch_count = 1
        while self._SampleBatch(sampler, atomizer, sample_observers):
            batch_count += 1

        elapsed_ms = labdate.MillisecondsTimestamp() - sample_start_time
        ms_per_batch = int(elapsed_ms / max(batch_count, 1))
        app.Log(
            1,
            "Produced %s sample batches at a rate of %s ms / batch.",
            humanize.Commas(batch_count),
            humanize.Commas(ms_per_batch),
        )
def __init__(self, config: clgen_pb2.Instance):
    """Instantiate an instance.

    Args:
      config: An Instance proto.

    Raises:
      UserError: If the instance proto contains invalid values, is missing a
        model or sampler fields.
    """
    try:
        for required_field in ('pretrained_model', 'sampler'):
            pbutil.AssertFieldIsSet(config, required_field)
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)

    # Stays None unless the proto names a working directory.
    working_dir = None
    if config.HasField('working_dir'):
        expanded = os.path.expandvars(config.working_dir)
        working_dir = pathlib.Path(expanded).expanduser().absolute()
    self.working_dir = working_dir

    # Enter a session so that the cache paths are set relative to any
    # requested working directory.
    with self.Session():
        model_path = pathlib.Path(config.pretrained_model)
        self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
            model_path)
        self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)
def ResolveContentId(config: corpus_pb2.Corpus,
                     hc: hashcache.HashCache) -> str:
    """Compute the hash of the input contentfiles.

    This function resolves the unique sha1 checksum of a set of content files.

    Args:
      config: The corpus config proto.
      hc: A hashcache database instance, used for resolving directory hashes.

    Returns:
      A hex encoded sha1 string.

    Raises:
      UserError: If the local directory to hash does not exist.
      NotImplementedError: If neither local_directory nor local_tar_archive is
        set on the config.
    """
    # We can take a massive shortcut if the content ID is already set in the
    # config proto.
    if config.HasField('content_id'):
        return config.content_id

    start_time = time.time()
    if config.HasField('local_directory'):
        # Expand the configured path once and use it for both the hash lookup
        # and the sidecar file. The sidecar location was previously derived
        # from the unexpanded config value, so it could disagree with the
        # directory actually being hashed.
        local_directory = ExpandConfigPath(
            config.local_directory,
            path_prefix=FLAGS.clgen_local_path_prefix)

        # After the first time we compute the hash of a directory, we write it
        # into a file. This is a shortcut to work around the fact that
        # computing the directory checksum is O(n) with respect to the number
        # of files in the directory (even if the directory is already cached
        # by the hash cache). This means that it is the responsibility of the
        # user to delete this cached file if the directory is changed.
        hash_file_path = pathlib.Path(str(local_directory) + '.sha1.txt')
        if hash_file_path.is_file():
            logging.info("Reading directory hash: '%s'.", hash_file_path)
            with open(hash_file_path) as f:
                content_id = f.read().rstrip()
        else:
            # No hash file, so compute the directory hash and create it.
            try:
                content_id = hc.GetHash(local_directory)
            except FileNotFoundError as e:
                raise errors.UserError(e) from e
            # Create the hash file next to the directory so that next time we
            # don't need to reference the hash cache.
            with open(hash_file_path, 'w') as f:
                print(content_id, file=f)
            logging.info("Wrote directory hash: '%s'.", hash_file_path)
    elif config.HasField('local_tar_archive'):
        # This is not an efficient means of getting the hash, as it requires
        # always unpacking the archive and reading the entire contents. It
        # would be nicer to maintain a cache which maps the mtime of tarballs
        # to their content ID, similar to how local_directory is implemented.
        content_id = GetHashOfArchiveContents(
            ExpandConfigPath(config.local_tar_archive,
                             path_prefix=FLAGS.clgen_local_path_prefix))
    else:
        raise NotImplementedError(
            'Unsupported Corpus.contentfiles field value')
    logging.debug('Resolved Content ID %s in %s ms.', content_id,
                  humanize.intcomma(int((time.time() - start_time) * 1000)))
    return content_id
def __init__(self, config: clgen_pb2.Instance, dashboard_opts=None):
    """Instantiate an instance.

    Args:
      config: An Instance proto.
      dashboard_opts: Optional mapping of keyword arguments forwarded to
        dashboard.Launch(). Defaults to no arguments.

    Raises:
      UserError: If the instance proto contains invalid values, is missing a
        model or sampler fields.
    """
    try:
        pbutil.AssertFieldIsSet(config, "model_specification")
        pbutil.AssertFieldIsSet(config, "sampler")
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e) from e

    self.config = config
    self.working_dir = None
    if config.HasField("working_dir"):
        self.working_dir: pathlib.Path = pathlib.Path(
            os.path.expandvars(
                config.working_dir)).expanduser().absolute()

    # Enter a session so that the cache paths are set relative to any
    # requested working directory.
    with self.Session():
        # Either a full model config or a path to a pretrained model is used.
        if config.HasField("model"):
            self.model: models.Model = models.Model(config.model)
        else:
            self.model: pretrained.PreTrainedModel = pretrained.PreTrainedModel(
                pathlib.Path(config.pretrained_model))
        self.sampler: samplers.Sampler = samplers.Sampler(config.sampler)

    # None replaces the former mutable `{}` default; behaviour is unchanged
    # for callers that omit the argument.
    self.dashboard = dashboard.Launch(**(dashboard_opts or {}))
def GetTrainingCorpus(
    corpus: "corpuses.Corpus", training_opts: model_pb2.TrainingOptions
) -> typing.Tuple[np.ndarray, np.ndarray, int]:
    """Get the corpus to train over.

    The encoded corpus is clipped to a whole number of steps and reshaped to
    [batch_size, steps_per_epoch * sequence_length]. The y array is the same
    slice shifted forward by one token.

    Args:
      corpus: A Corpus instance.
      training_opts: A TrainingOptions proto.

    Returns:
      An X, y pair of data for an epoch, and the number of steps in the epoch.

    Raises:
      UserError: If batch_size and sequence_length are too large for the
        corpus, yielding no batches.
    """
    started_at = time.time()
    encoded_corpus = corpus.GetTrainingData(
        shuffle=training_opts.shuffle_corpus_contentfiles_between_epochs)
    corpus_length = len(encoded_corpus)

    batch_size = training_opts.batch_size
    sequence_length = training_opts.sequence_length
    tokens_per_step = batch_size * sequence_length
    # One token is held back so that y can be shifted forward by one.
    steps_per_epoch = (corpus_length - 1) // tokens_per_step
    if not steps_per_epoch:
        raise errors.UserError(
            f"Requested batch size ({training_opts.batch_size}) and "
            f"sequence length ({training_opts.sequence_length}) are too large for "
            f"corpus of size {corpus_length}.")

    clipped_corpus_length = steps_per_epoch * tokens_per_step
    epoch_shape = [batch_size, steps_per_epoch * sequence_length]
    x = np.reshape(encoded_corpus[:clipped_corpus_length], epoch_shape)
    y = np.reshape(encoded_corpus[1:clipped_corpus_length + 1], epoch_shape)

    elapsed_ms = int((time.time() - started_at) * 1000)
    app.Log(
        1,
        "Encoded corpus of %s tokens (clipped last %s tokens) in %s ms.",
        humanize.Commas(clipped_corpus_length),
        humanize.Commas(corpus_length - clipped_corpus_length),
        humanize.Commas(elapsed_ms),
    )
    return x, y, steps_per_epoch
def AssertConfigIsValid(config: corpus_pb2.Corpus) -> corpus_pb2.Corpus:
    """Validate a Corpus proto and hand it back unchanged.

    A config with 'pre_encoded_corpus_url' set is accepted without further
    checks. Otherwise the 'contentfiles', 'atomizer' and
    'contentfile_separator' fields must be set, every preprocessor name must
    resolve, and a configured greedy multichar atomizer must have a non-empty
    token list with no empty strings.

    Args:
      config: A Corpus proto.

    Returns:
      The Corpus proto.

    Raises:
      UserError: If the config is invalid.
    """
    # Early-exit to support corpuses derived from databases of pre-encoded
    # content files.
    # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor after splitting
    # Corpus class.
    if config.HasField("pre_encoded_corpus_url"):
        return config

    try:
        for required_field in ("contentfiles", "atomizer",
                               "contentfile_separator"):
            pbutil.AssertFieldIsSet(config, required_field)
        # Check that the preprocessor pipeline resolves to preprocessor
        # functions. Only the lookup matters; the result is discarded.
        for preprocessor_name in config.preprocessor:
            preprocessors.GetPreprocessorFunction(preprocessor_name)
        if config.HasField("greedy_multichar_atomizer"):
            multichar_tokens = config.greedy_multichar_atomizer.tokens
            if not multichar_tokens:
                raise errors.UserError(
                    "GreedyMulticharAtomizer.tokens is empty")
            if not all(multichar_tokens):
                raise errors.UserError(
                    "Empty string found in GreedyMulticharAtomizer.tokens is empty"
                )
    except pbutil.ProtoValueError as e:
        raise errors.UserError(e)
    return config
def GetHashOfArchiveContents(archive: pathlib.Path) -> str:
    """Compute the checksum of the unpacked contents of an archive.

    The archive is extracted with `tar` into a temporary directory, and the
    sha1 directory hash of that directory is returned.

    Args:
      archive: Path of the archive.

    Returns:
      Checksum of the archive.

    Raises:
      UserError: If the requested archive does not exist, or cannot be
        unpacked.
    """
    if not archive.is_file():
        raise errors.UserError(f"Archive not found: '{archive}'")

    with tempfile.TemporaryDirectory(prefix="clgen_corpus_") as unpack_dir:
        try:
            subprocess.check_call(
                ["tar", "-xf", str(archive), "-C", unpack_dir])
        except subprocess.CalledProcessError:
            raise errors.UserError(f"Archive unpack failed: '{archive}'")
        # Hash while the temporary directory still exists.
        return checksumdir.dirhash(unpack_dir, "sha1")
def CreateBatches(self) -> None:
    """Split the encoded corpus into a list of (x, y) DataBatch objects.

    Resets the batch cursor self.i to zero. The training data is fetched when
    no encoded corpus is held yet, or when shuffling between epochs is
    enabled. The corpus is clipped to a whole number of batches; y is x
    shifted by one token, with the final target wrapping around to the first
    token.

    Raises:
      UserError: If the corpus is too small to yield a single batch.
    """
    started_at = time.time()

    self.i = 0
    opts = self.training_opts
    must_fetch = (self.encoded_corpus is None or
                  opts.shuffle_corpus_contentfiles_between_epochs)
    if must_fetch:
        self.encoded_corpus = self.corpus.GetTrainingData(
            shuffle=opts.shuffle_corpus_contentfiles_between_epochs)

    batch_size = opts.batch_size
    sequence_length = opts.sequence_length
    tokens_per_batch = batch_size * sequence_length

    # Set corpus size and number of batches.
    self.num_batches = int(len(self.encoded_corpus) / tokens_per_batch)
    if self.num_batches == 0:
        raise errors.UserError(
            "Not enough data. Use a smaller sequence_length and batch_size"
        )

    # Clip to a whole number of batches.
    clipped_corpus_length = self.num_batches * tokens_per_batch
    inputs = self.encoded_corpus[:clipped_corpus_length]
    targets = np.copy(inputs)
    # Targets are the inputs shifted by one, wrapping around at the end.
    targets[:-1] = inputs[1:]
    targets[-1] = inputs[0]

    input_chunks = np.split(
        inputs.reshape(batch_size, -1), self.num_batches, 1)
    target_chunks = np.split(
        targets.reshape(batch_size, -1), self.num_batches, 1)
    batches = []
    for x_chunk, y_chunk in zip(input_chunks, target_chunks):
        batches.append(DataBatch(x_chunk, y_chunk))
    self.batches = batches

    elapsed_ms = int((time.time() - started_at) * 1000)
    app.Log(
        1,
        "Encoded corpus of %s tokens (clipped last %s tokens) in %s ms.",
        humanize.Commas(clipped_corpus_length),
        humanize.Commas(len(self.encoded_corpus) - clipped_corpus_length),
        humanize.Commas(elapsed_ms),
    )
def FromText(cls, text: str, atoms: typing.Set[str]) -> "GreedyAtomizer":
    """Instantiate an atomizer from a corpus text.

    An atomizer over the full set of atoms (with determine_chars=True) is
    used to tokenize the text. The returned atomizer's vocabulary is then
    restricted to the distinct tokens that tokenization produced, indexed in
    sorted order.

    Args:
      text: Text corpus
      atoms: A set of multi-character tokens.

    Returns:
      An atomizer instance.

    Raises:
      UserError: If atoms is empty.
    """
    if not atoms:
        raise errors.UserError("No atoms specified")

    # Instantiate a greedy atomizer using the full vocabulary.
    full_vocab = {atom: index for index, atom in enumerate(atoms)}
    c = GreedyAtomizer(full_vocab, determine_chars=True)
    # Derive the subset of the vocabulary required to encode the given text.
    tokens = sorted(set(c.TokenizeString(text)))
    vocab_subset = {token: index for index, token in enumerate(tokens)}
    # Return a new atomizer using the subset vocabulary.
    return GreedyAtomizer(vocab_subset)
def __init__(self, config: model_pb2.Model):
    """Instantiate a model.

    Validates the config, instantiates the corpus, and creates the model
    cache entry keyed by the model hash. Inside the cache it makes the
    'checkpoints', 'samples' and 'logs' directories plus 'corpus' and
    'atomizer' symlinks. The config is then reconciled with any cached
    META.pbtxt, and finally the backend is constructed.

    Args:
      config: A Model message.

    Raises:
      TypeError: If the config argument is not a Model proto.
      UserError: In case of an invalid config.
      InternalError: If the cache already holds metadata whose config differs
        from this config, ignoring corpus contentfiles and num_epochs.
    """
    # Error early, so that a cache isn't created.
    if not isinstance(config, model_pb2.Model):
        t = type(config).__name__
        raise TypeError(f"Config must be a Model proto. Received: '{t}'")

    # Validate config options.
    if config.training.sequence_length < 1:
        raise errors.UserError(
            'TrainingOptions.sequence_length must be >= 1')

    # Keep a private copy so later changes to the caller's proto do not leak
    # into this instance.
    self.config = model_pb2.Model()
    self.config.CopyFrom(builders.AssertIsBuildable(config))
    self.corpus = corpuses.Corpus(config.corpus)
    self.hash = self._ComputeHash(self.corpus, self.config)
    self.cache = cache.mkcache('model', self.hash)

    # Create the necessary cache directories.
    (self.cache.path / 'checkpoints').mkdir(exist_ok=True)
    (self.cache.path / 'samples').mkdir(exist_ok=True)
    (self.cache.path / 'logs').mkdir(exist_ok=True)

    # Create symlink to encoded corpus. The link target is the directory
    # holding the encoded database, expressed relative to the cache path; the
    # 'sqlite:///' prefix is stripped from the database URL to get a path.
    symlink = self.cache.path / 'corpus'
    if not symlink.is_symlink():
        os.symlink(
            os.path.relpath(
                pathlib.Path(
                    self.corpus.encoded.url[len('sqlite:///'):]).parent,
                self.cache.path), symlink)

    # Create symlink to the atomizer.
    symlink = self.cache.path / 'atomizer'
    if not symlink.is_symlink():
        os.symlink(
            os.path.relpath(self.corpus.atomizer_path, self.cache.path),
            symlink)

    # Validate metadata against cache.
    if self.cache.get('META.pbtxt'):
        cached_meta = pbutil.FromFile(
            pathlib.Path(self.cache['META.pbtxt']),
            internal_pb2.ModelMeta())
        # Exclude num_epochs and corpus location from metadata comparison.
        config_to_compare = model_pb2.Model()
        config_to_compare.CopyFrom(self.config)
        config_to_compare.corpus.ClearField('contentfiles')
        config_to_compare.training.ClearField('num_epochs')
        # These fields should have already been cleared, but we'll do it again
        # so that metadata comparisons don't fail when the cached meta schema
        # is updated.
        cached_to_compare = model_pb2.Model()
        cached_to_compare.CopyFrom(cached_meta.config)
        cached_to_compare.corpus.ClearField('contentfiles')
        cached_to_compare.training.ClearField('num_epochs')
        if config_to_compare != cached_to_compare:
            raise errors.InternalError('Metadata mismatch')
        self.meta = cached_meta
    else:
        # No cached metadata yet: record this config and write it out.
        self.meta = internal_pb2.ModelMeta()
        self.meta.config.CopyFrom(self.config)
        self._WriteMetafile()

    # Dispatch on the configured backend enum. A value outside this table
    # raises KeyError.
    self.backend = {
        model_pb2.NetworkArchitecture.TENSORFLOW:
        tensorflow_backend.TensorFlowBackend,
        model_pb2.NetworkArchitecture.KERAS:
        keras_backend.KerasBackend,
    }[config.architecture.backend](self.config, self.cache, self.corpus)
def ResolveContentId(config: corpus_pb2.Corpus,
                     hc: typing.Optional[hashcache.HashCache] = None) -> str:
    """Compute the hash of the input contentfiles.

    Resolves the unique sha1 checksum identifying a set of content files.

    Args:
      config: The corpus config proto.
      hc: A hashcache database instance, used for resolving directory hashes.
        If the corpus has pre_encoded_corpus_url field set, this may be
        omitted.

    Returns:
      A hex encoded sha1 string.
    """
    # We can take a massive shortcut if the content ID is already set in the
    # config proto.
    # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor these two
    # shortcuts after splitting out Corpus class.
    if config.HasField("content_id"):
        return config.content_id
    if config.HasField("pre_encoded_corpus_url"):
        return crypto.sha1_str(config.pre_encoded_corpus_url)

    began = time.time()
    if config.HasField("local_directory"):
        local_directory = ExpandConfigPath(
            config.local_directory,
            path_prefix=FLAGS.clgen_local_path_prefix)
        # After the first time we compute the hash of a directory, we write it
        # into a file. This is a shortcut to work around the fact that
        # computing the directory checksum is O(n) with respect to the number
        # of files in the directory (even if the directory is already cached
        # by the hash cache). This means that it is the responsibility of the
        # user to delete this cached file if the directory is changed.
        hash_file_path = pathlib.Path(f"{local_directory}.sha1.txt")
        if not hash_file_path.is_file():
            # No hash file, so compute the directory hash and create it.
            try:
                content_id = hc.GetHash(local_directory)
            except FileNotFoundError as e:
                raise errors.UserError(e)
            # Create the hash file so that next time we don't need to
            # reference the hash cache.
            with open(hash_file_path, "w") as sidecar:
                sidecar.write(f"{content_id}\n")
            app.Log(1, "Wrote directory hash: '%s'.", hash_file_path)
        else:
            app.Log(1, "Reading directory hash: '%s'.", hash_file_path)
            with open(hash_file_path) as sidecar:
                content_id = sidecar.read().rstrip()
    elif config.HasField("local_tar_archive"):
        # This is not an efficient means of getting the hash, as it requires
        # always unpacking the archive and reading the entire contents. It
        # would be nicer to maintain a cache which maps the mtime of tarballs
        # to their content ID, similar to how local_directory is implemented.
        archive_path = ExpandConfigPath(
            config.local_tar_archive,
            path_prefix=FLAGS.clgen_local_path_prefix)
        content_id = GetHashOfArchiveContents(archive_path)
    else:
        raise NotImplementedError(
            "Unsupported Corpus.contentfiles field value")

    elapsed_ms = int((time.time() - began) * 1000)
    app.Log(
        2,
        "Resolved Content ID %s in %s ms.",
        content_id,
        humanize.Commas(elapsed_ms),
    )
    return content_id
def __init__(self, config: corpus_pb2.Corpus):
    """Instantiate a corpus from a proto config.

    If this is a new corpus, a number of files will be created, which may
    take some time.

    Sets up, in order: a validated private copy of the config, the content
    ID, the preprocessed database handle (with a 'contentfiles' symlink next
    to it), the encoded database handle, the atomizer path, a 'preprocessed'
    symlink next to the encoded database, and the corpus cache entry.

    Args:
      config: A Corpus message.

    Raises:
      TypeError: If the config argument is not a Corpus proto.
      UserError: In case the corpus is not found, or config contains invalid
        options.
      EmptyCorpusException: In case the corpus contains no data.
    """
    if not isinstance(config, corpus_pb2.Corpus):
        t = type(config).__name__
        raise TypeError(f"Config must be a Corpus proto. Received: '{t}'")

    # Make a local copy of the configuration.
    self.config = corpus_pb2.Corpus()
    self.config.CopyFrom(AssertConfigIsValid(config))
    self._atomizer = None
    self._created = False
    self.dashboard_db = dashboard_db.GetDatabase()
    self._dashboard_db_id: typing.Optional[int] = None  # Set in Create()

    # An in-memory cache of the encoded contentfiles indices arrays.
    # Set and used in GetTrainingData().
    self._indices_arrays: typing.Optional[typing.List[np.array]] = None

    cache.cachepath("corpus").mkdir(parents=True, exist_ok=True)
    hc = hashcache.HashCache(cache.cachepath("hashcache.db"), "sha1")
    self.content_id = ResolveContentId(self.config, hc)

    # Database of pre-processed files.
    preprocessed_id = ResolvePreprocessedId(self.content_id, self.config)
    cache.cachepath("corpus", "preprocessed",
                    preprocessed_id).mkdir(exist_ok=True, parents=True)
    preprocessed_db_path = cache.cachepath("corpus", "preprocessed",
                                           preprocessed_id,
                                           "preprocessed.db")
    # A corpus given only by content_id has no contentfiles to rebuild from,
    # so the preprocessed database must already exist.
    if (self.config.HasField("content_id") and
            not preprocessed_db_path.is_file()):
        raise errors.UserError(
            f"Content ID not found: '{self.content_id}'")
    self.preprocessed = preprocessed.PreprocessedContentFiles(
        f"sqlite:///{preprocessed_db_path}")

    # Create symlink to contentfiles. The 'sqlite:///' prefix is stripped
    # from the database URL to recover the filesystem path.
    symlink = (
        pathlib.Path(self.preprocessed.url[len("sqlite:///"):]).parent /
        "contentfiles")
    if not symlink.is_symlink():
        if config.HasField("local_directory"):
            os.symlink(
                str(
                    ExpandConfigPath(
                        config.local_directory,
                        path_prefix=FLAGS.clgen_local_path_prefix)),
                symlink,
            )
        elif config.HasField("local_tar_archive"):
            os.symlink(
                str(
                    ExpandConfigPath(
                        config.local_tar_archive,
                        path_prefix=FLAGS.clgen_local_path_prefix,
                    )),
                symlink,
            )

    # Data of encoded pre-preprocessed files.
    encoded_id = ResolveEncodedId(self.content_id, self.config)
    cache.cachepath("corpus", "encoded", encoded_id).mkdir(exist_ok=True,
                                                           parents=True)
    db_path = cache.cachepath("corpus", "encoded", encoded_id, "encoded.db")
    # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this conditional
    # logic by making Corpus an abstract class and creating concrete subclasses
    # for the different types of corpus.
    if self.config.HasField("pre_encoded_corpus_url"):
        self.encoded = encoded.EncodedContentFiles(
            config.pre_encoded_corpus_url)
    else:
        self.encoded = encoded.EncodedContentFiles(f"sqlite:///{db_path}")
    self.atomizer_path = cache.cachepath("corpus", "encoded", encoded_id,
                                         "atomizer.pkl")

    # Create symlink to preprocessed files.
    # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor this conditional
    # logic after splitting Corpus class.
    if not self.config.HasField("pre_encoded_corpus_url"):
        symlink = (
            pathlib.Path(self.encoded.url[len("sqlite:///"):]).parent /
            "preprocessed")
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(
                    pathlib.Path(
                        self.preprocessed.url[len("sqlite:///"):]).parent,
                    pathlib.Path(
                        self.encoded.url[len("sqlite:///"):]).parent,
                ),
                symlink,
            )

    # The corpus hash is the encoded ID.
    self.hash = encoded_id
    self.cache = cache.mkcache("corpus", "encoded", encoded_id)
def __init__(self, config: corpus_pb2.Corpus):
    """Instantiate a corpus from a proto config.

    If this is a new corpus, a number of files will be created, which may
    take some time.

    Args:
      config: A Corpus message.

    Raises:
      TypeError: If the config argument is not a Corpus proto.
      UserError: In case the corpus is not found, or config contains invalid
        options.
      EmptyCorpusException: In case the corpus contains no data.
    """
    if not isinstance(config, corpus_pb2.Corpus):
        t = type(config).__name__
        raise TypeError(f"Config must be a Corpus proto. Received: '{t}'")

    # Work on a private, validated copy of the configuration.
    validated_config = AssertConfigIsValid(config)
    self.config = corpus_pb2.Corpus()
    self.config.CopyFrom(validated_config)
    self._atomizer = None
    self._created = False

    cache.cachepath('corpus').mkdir(parents=True, exist_ok=True)
    hc = hashcache.HashCache(cache.cachepath('hashcache.db'), 'sha1')
    self.content_id = ResolveContentId(self.config, hc)

    # Database of pre-processed files.
    preprocessed_id = ResolvePreprocessedId(self.content_id, self.config)
    preprocessed_dir = cache.cachepath('corpus', 'preprocessed',
                                       preprocessed_id)
    preprocessed_dir.mkdir(exist_ok=True, parents=True)
    preprocessed_db_path = cache.cachepath('corpus', 'preprocessed',
                                           preprocessed_id,
                                           'preprocessed.db')
    if (self.config.HasField('content_id') and
            not preprocessed_db_path.is_file()):
        raise errors.UserError(
            f"Content ID not found: '{self.content_id}'")
    self.preprocessed = preprocessed.PreprocessedContentFiles(
        preprocessed_db_path)

    # Create symlink to contentfiles, pointing at whichever source is set.
    contentfiles_link = self.preprocessed.database_path.parent / 'contentfiles'
    if not contentfiles_link.is_symlink():
        if config.HasField('local_directory'):
            contentfiles_source = config.local_directory
        elif config.HasField('local_tar_archive'):
            contentfiles_source = config.local_tar_archive
        else:
            contentfiles_source = None
        if contentfiles_source is not None:
            expanded_source = ExpandConfigPath(
                contentfiles_source,
                path_prefix=FLAGS.clgen_local_path_prefix)
            os.symlink(str(expanded_source), contentfiles_link)

    # Data of encoded pre-preprocessed files.
    encoded_id = ResolveEncodedId(self.content_id, self.config)
    encoded_dir = cache.cachepath('corpus', 'encoded', encoded_id)
    encoded_dir.mkdir(exist_ok=True, parents=True)
    encoded_db_path = cache.cachepath('corpus', 'encoded', encoded_id,
                                      'encoded.db')
    self.encoded = encoded.EncodedContentFiles(encoded_db_path)
    self.atomizer_path = cache.cachepath('corpus', 'encoded', encoded_id,
                                         'atomizer.pkl')

    # Create symlink to preprocessed files, relative to the encoded
    # directory.
    preprocessed_link = self.encoded.database_path.parent / 'preprocessed'
    if not preprocessed_link.is_symlink():
        relative_target = os.path.relpath(
            self.preprocessed.database_path.parent,
            self.encoded.database_path.parent)
        os.symlink(relative_target, preprocessed_link)

    self.hash = encoded_id
    self.cache = cache.mkcache('corpus', 'encoded', encoded_id)
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
    """Assert that a model configuration is buildable.

    Checks run in a fixed order and stop at the first failure: required
    top-level fields, architecture fields, training options, then the fields
    of whichever optimizer is configured.

    Args:
      config: A model proto.

    Returns:
      The input model proto, unmodified.

    Raises:
      UserError: If the model is not buildable.
      InternalError: If the value of the training.optimizer field is not
        understood.
    """

    def _Positive(x):
        return 0 < x

    def _NonNegative(x):
        return 0 <= x

    def _MicrosFraction(x):
        return 0 <= x <= 1000000

    # Any change to the Model proto schema will require a change to this
    # function.
    try:
        for top_level_field in ('corpus', 'architecture', 'training'):
            pbutil.AssertFieldIsSet(config, top_level_field)

        architecture = config.architecture
        pbutil.AssertFieldIsSet(architecture, 'backend')
        pbutil.AssertFieldIsSet(architecture, 'neuron_type')
        # embedding_size is only checked for the Keras backend.
        if architecture.backend == model_pb2.NetworkArchitecture.KERAS:
            pbutil.AssertFieldConstraint(
                architecture, 'embedding_size', _Positive,
                'NetworkArchitecture.embedding_size must be > 0')
        architecture_constraints = (
            ('neurons_per_layer', _Positive,
             'NetworkArchitecture.neurons_per_layer must be > 0'),
            ('num_layers', _Positive,
             'NetworkArchitecture.num_layers must be > 0'),
            ('post_layer_dropout_micros', _MicrosFraction,
             'NetworkArchitecture.post_layer_dropout_micros '
             'must be >= 0 and <= 1000000'),
        )
        for field_name, predicate, message in architecture_constraints:
            pbutil.AssertFieldConstraint(architecture, field_name, predicate,
                                         message)

        training = config.training
        pbutil.AssertFieldConstraint(
            training, 'num_epochs', _Positive,
            'TrainingOptions.num_epochs must be > 0')
        pbutil.AssertFieldIsSet(
            training, 'shuffle_corpus_contentfiles_between_epochs')
        pbutil.AssertFieldConstraint(
            training, 'batch_size', _Positive,
            'TrainingOptions.batch_size must be > 0')
        pbutil.AssertFieldIsSet(training, 'optimizer')

        if training.HasField('adam_optimizer'):
            optimizer = training.adam_optimizer
            optimizer_constraints = (
                ('initial_learning_rate_micros', _NonNegative,
                 'AdamOptimizer.initial_learning_rate_micros must be >= 0'),
                ('learning_rate_decay_per_epoch_micros', _NonNegative,
                 'AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0'),
                ('beta_1_micros', _MicrosFraction,
                 'AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000'),
                ('beta_2_micros', _MicrosFraction,
                 'AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000'),
                ('normalized_gradient_clip_micros', _NonNegative,
                 'AdamOptimizer.normalized_gradient_clip_micros must be >= 0'),
            )
        elif training.HasField('rmsprop_optimizer'):
            optimizer = training.rmsprop_optimizer
            optimizer_constraints = (
                ('initial_learning_rate_micros', _NonNegative,
                 'RmsPropOptimizer.initial_learning_rate_micros must be >= 0'),
                ('learning_rate_decay_per_epoch_micros', _NonNegative,
                 'RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0'),
            )
        else:
            raise errors.InternalError(
                "Unrecognized value: 'TrainingOptions.optimizer'")
        for field_name, predicate, message in optimizer_constraints:
            pbutil.AssertFieldConstraint(optimizer, field_name, predicate,
                                         message)
    except pbutil.ProtoValueError as e:
        raise errors.UserError(str(e))
    return config