def Import(self, session: sqlutil.Session, config: corpus_pb2.Corpus) -> None:
  """Preprocess the content files that are not yet in the database.

  Relpaths that already appear as PreprocessedContentFile.input_relpath rows
  are skipped. The remaining files are preprocessed by a multiprocessing pool
  and each result is added to the session. The session is committed at most
  once every ten seconds while results arrive; results added after the last
  periodic commit are left pending in the session.

  Args:
    session: The database session that results are added to.
    config: The corpus config proto.
  """
  with self.GetContentFileRoot(config) as contentfile_root:
    relpaths = set(self.GetImportRelpaths(contentfile_root))
    done = {x[0] for x in session.query(PreprocessedContentFile.input_relpath)}
    todo = relpaths - done
    app.Log(1, 'Preprocessing %s of %s content files',
            humanize.Commas(len(todo)), humanize.Commas(len(relpaths)))
    jobs = [
        internal_pb2.PreprocessorWorker(
            contentfile_root=str(contentfile_root),
            relpath=t,
            preprocessors=config.preprocessor) for t in todo
    ]
    bar = progressbar.ProgressBar(max_value=len(jobs))
    # Use the pool as a context manager so that the worker processes are
    # reclaimed when the import finishes or fails part-way through.
    with multiprocessing.Pool() as pool:
      last_commit = time.time()
      wall_time_start = time.time()
      for preprocessed_cf in bar(pool.imap_unordered(PreprocessorWorker, jobs)):
        wall_time_end = time.time()
        # Wall time is the gap between consecutive results arriving, not the
        # time that a single worker spent on the file.
        preprocessed_cf.wall_time_ms = (int(
            (wall_time_end - wall_time_start) * 1000))
        wall_time_start = wall_time_end
        session.add(preprocessed_cf)
        if wall_time_end - last_commit > 10:
          session.commit()
          last_commit = wall_time_end
def Create(self) -> None:
  """Build the pre-processed and the encoded corpus databases.

  Each database is populated while holding a 'LOCK' file that sits in the
  directory of that database's sqlite file.

  Raises:
    EmptyCorpusException: If there are no content files, or no successfully
      pre-processed files.
  """

  def _LockPath(url: str) -> pathlib.Path:
    """Return the lock file path next to the database at the given URL."""
    return pathlib.Path(url[len('sqlite:///'):]).parent / 'LOCK'

  self._created = True
  app.Log(1, 'Content ID: %s', self.content_id)

  # Nothing to do for already-encoded databases.
  # TODO(github.com/ChrisCummins/phd/issues/46): Refactor this after splitting
  # out Corpus class.
  if self.config.HasField('pre_encoded_corpus_url'):
    return

  with lockfile.LockFile(_LockPath(self.preprocessed.url)):
    self.preprocessed.Create(self.config)

  if not self.preprocessed.size:
    raise errors.EmptyCorpusException(
        f"Pre-processed corpus contains no files: '{self.preprocessed.url}'")

  with lockfile.LockFile(_LockPath(self.encoded.url)):
    # The timer covers the lookup of self.atomizer.
    start_time = time.time()
    atomizer = self.atomizer
    app.Log(1, '%s: %s tokens in %s ms',
            type(atomizer).__name__, humanize.Commas(atomizer.vocab_size),
            humanize.Commas(int((time.time() - start_time) * 1000)))
    self.encoded.Create(self.preprocessed, atomizer,
                        self.config.contentfile_separator)
def Create(self, p: preprocessed.PreprocessedContentFiles,
           atomizer: atomizers.AtomizerBase,
           contentfile_separator: str) -> bool:
  """Populate the encoded contentfiles database.

  Args:
    p: A PreprocessedContentFiles database.
    atomizer: An AtomizerBase instance.
    contentfile_separator: The contentfile separator.

  Returns:
    True if work was done, else False.

  Raises:
    EmptyCorpusException: If the PreprocessedContentFiles database has no
      files.
  """
  with self.Session() as session:
    work_done = not self.IsDone(session)
    if work_done:
      self.Import(session, p, atomizer, contentfile_separator)
      self.SetDone(session)
      session.commit()

    # Logging output.
    num_files = session.query(EncodedContentFile).count()
    token_count, total_walltime, total_time, = session.query(
        func.sum(EncodedContentFile.tokencount),
        func.sum(EncodedContentFile.wall_time_ms),
        func.sum(EncodedContentFile.encoding_time_ms),
    ).first()
    # The sums are None when the table has no rows, and the wall time can
    # legitimately sum to zero. Substitute neutral values so that the summary
    # logging cannot fail with a TypeError or ZeroDivisionError.
    app.Log(1, 'Encoded %s files in %s ms (%.2fx speedup).',
            humanize.Commas(num_files), humanize.Commas(total_walltime or 0),
            (total_time or 1) / (total_walltime or 1))
    app.Log(1, 'Encoded corpus: %s tokens, %s files.',
            humanize.Commas(token_count or 0), humanize.Commas(num_files))
  return work_done
def Sample(self,
           sampler: samplers.Sampler,
           sample_observers: typing.List[sample_observers_lib.SampleObserver],
           seed: typing.Optional[int] = None) -> None:
  """Sample a model.

  This method uses the observer model, returning nothing. To access the
  samples produced, implement a SampleObserver and pass it in as an argument.
  Sampling continues indefinitely until one of the sample observers returns
  False when notified of a new sample.

  If the model is not already trained, calling Sample() first trains the
  model. Thus a call to Sample() is equivalent to calling Train() then
  Sample().

  Args:
    sampler: The sampler to sample using.
    sample_observers: A list of SampleObserver objects that are notified of
      new generated samples.
    seed: A numeric value to seed the RNG with. If not present, the RNG is
      seeded randomly.

  Raises:
    UserError: If called with no sample observers.
    UnableToAcquireLockError: If the model is locked (i.e. there is another
      process currently modifying the model).
    InvalidStartText: If the sampler start text cannot be encoded.
    InvalidSymtokTokens: If the sampler symmetrical depth tokens cannot be
      encoded.
  """
  if not sample_observers:
    raise errors.UserError("Cannot sample without any observers")

  # The timer starts before training, so the reported rate includes any time
  # spent in Train().
  sample_start_time = labdate.MillisecondsTimestamp()

  self.Train()

  with logutil.TeeLogsToFile(f'sampler_{sampler.hash}',
                             self.cache.path / 'logs'):
    app.Log(1, "Sampling: '%s'", sampler.start_text)

    atomizer = self.corpus.atomizer
    sampler.Specialize(atomizer)
    self.backend.InitSampling(sampler, seed)
    for obs in sample_observers:
      obs.Specialize(self, sampler)

    # The first batch is always produced, so counting starts at one and the
    # divisor below can never be zero.
    batch_count = 1
    while self._SampleBatch(sampler, atomizer, sample_observers):
      batch_count += 1

    time_now = labdate.MillisecondsTimestamp()
    app.Log(1, 'Produced %s sample batches at a rate of %s ms / batch.',
            humanize.Commas(batch_count),
            humanize.Commas(int((time_now - sample_start_time) / batch_count)))
def Create(self, config: corpus_pb2.Corpus):
  """Populate the pre-processed content files database and log a summary.

  The import only runs if the database is not already marked as done. The
  summary statistics are logged on every call.

  Args:
    config: The corpus config proto.
  """
  with self.Session() as session:
    if not self.IsDone(session):
      self.Import(session, config)
      self.SetDone(session)
      session.commit()

    # Logging output.
    num_input_files = session.query(PreprocessedContentFile).count()
    num_files = session.query(PreprocessedContentFile).filter(
        PreprocessedContentFile.preprocessing_succeeded == True).count()
    input_chars, input_lines, total_walltime, total_time, = session.query(
        func.sum(PreprocessedContentFile.charcount),
        func.sum(PreprocessedContentFile.linecount),
        func.sum(PreprocessedContentFile.wall_time_ms),
        func.sum(PreprocessedContentFile.preprocess_time_ms),
    ).first()
    char_count, line_count = session.query(
        func.sum(PreprocessedContentFile.charcount),
        func.sum(PreprocessedContentFile.linecount),
    ).filter(PreprocessedContentFile.preprocessing_succeeded == True).first()
    # Every sum is None when no rows match, e.g. for an empty corpus or when
    # no file was preprocessed successfully. Fall back to neutral values for
    # all of them so that the summary is always logged.
    app.Log(1, 'Content files: %s chars, %s lines, %s files.',
            humanize.Commas(input_chars or 0),
            humanize.Commas(input_lines or 0),
            humanize.Commas(num_input_files))
    app.Log(1, 'Pre-processed %s files in %s (%.2fx speedup).',
            humanize.Commas(num_input_files),
            humanize.Duration((total_walltime or 0) / 1000),
            (total_time or 1) / (total_walltime or 1))
    app.Log(1, 'Pre-processing discard rate: %.1f%% (%s files).',
            (1 - (num_files / max(num_input_files, 1))) * 100,
            humanize.Commas(num_input_files - num_files))
    app.Log(1, 'Pre-processed corpus: %s chars, %s lines, %s files.',
            humanize.Commas(char_count or 0), humanize.Commas(line_count or 0),
            humanize.Commas(num_files))
def ExportCommitsThatTouchFiles(commits_in_order: typing.List[git.Commit],
                                destiantion: git.Repo,
                                files_of_interest: typing.Set[str]) -> int:
  """Filter and apply the commits that touch the given files of interest.

  The commits are applied in the order provided.

  Args:
    commits_in_order: The candidate commits, in the order to process them.
    destiantion: The repo passed to MaybeExportCommitSubset() for each commit.
      The misspelled name is kept so that keyword callers continue to work.
    files_of_interest: The files passed to MaybeExportCommitSubset().

  Returns:
    The number of commits for which MaybeExportCommitSubset() returned a
    truthy value.
  """
  exported_commit_count = 0
  num_commits = len(commits_in_order)
  total_commit_count = humanize.Commas(num_commits)
  for i, commit in enumerate(commits_in_order, start=1):
    app.Log(1, 'Processing commit %s of %s (%.2f%%) %s', humanize.Commas(i),
            total_commit_count, (i / num_commits) * 100, commit)
    if MaybeExportCommitSubset(commit, destiantion, files_of_interest):
      exported_commit_count += 1
  return exported_commit_count
def ExportToRepo(self, repo: git.Repo, targets: typing.List[str],
                 src_files: typing.List[str], extra_files: typing.List[str],
                 file_move_mapping: typing.Dict[str, str]) -> None:
  """Export the requested targets to the destination directory.

  The git history of the source files is exported first. A requirements file
  is then created and the extra files are copied and moved into place. If that
  leaves the destination repo dirty, the changes are committed.

  Args:
    repo: The destination repo.
    targets: The targets passed to CreatePythonRequirementsFileForTargets().
    src_files: The files whose git history is exported.
    extra_files: The files passed to CopyFilesToDestination().
    file_move_mapping: The mapping passed to MoveFilesToDestination().
  """
  # The timestamp for the export.
  timestamp = datetime.datetime.utcnow()

  # Export the git history.
  app.Log(1, 'Exporting git history for %s files',
          humanize.Commas(len(src_files)))
  for src_file in src_files:
    print(src_file)
  source_tree.ExportGitHistoryForFiles(self.git_repo, repo, src_files)

  # Make manual adjustments.
  workspace_root = pathlib.Path(repo.working_tree_dir)
  exported_workspace = bazelutil.Workspace(workspace_root)
  self.CreatePythonRequirementsFileForTargets(exported_workspace, targets)
  self.CopyFilesToDestination(exported_workspace, extra_files)
  self.MoveFilesToDestination(exported_workspace, file_move_mapping)

  # Only commit when the manual adjustments changed something.
  if repo.is_dirty(untracked_files=True):
    app.Log(1, 'Creating automated subtree export commit')
    repo.git.add('.')
    author = git.Actor(name='[Git export bot]', email='/dev/null')
    message = f'Automated subtree export at {timestamp.isoformat()}'
    repo.index.commit(message,
                      author=author,
                      committer=author,
                      skip_hooks=True)
def GetContentFileRoot(self, config: corpus_pb2.Corpus) -> pathlib.Path:
  """Yield the path of the directory containing content files.

  For a local directory corpus, the expanded directory path is yielded as-is.
  For a local tar archive corpus, the archive is unpacked into a temporary
  directory which is yielded and then removed once the generator is resumed.

  NOTE(review): this is a generator that yields exactly once; presumably it is
  wrapped as a context manager by a decorator outside of this view - confirm.

  Args:
    config: The corpus config proto.

  Yields:
    The path of a directory containing content files.

  Raises:
    NotImplementedError: If the config has neither a local_directory nor a
      local_tar_archive field.
  """
  if config.HasField('local_directory'):
    yield pathlib.Path(ExpandConfigPath(config.local_directory))
    return

  if not config.HasField('local_tar_archive'):
    raise NotImplementedError

  with tempfile.TemporaryDirectory(prefix='clgen_corpus_') as d:
    start_time = time.time()
    subprocess.check_call([
        'tar',
        '-xf',
        str(ExpandConfigPath(config.local_tar_archive)),
        '-C',
        d,
    ])
    app.Log(1, 'Unpacked %s in %s ms',
            ExpandConfigPath(config.local_tar_archive).name,
            humanize.Commas(int((time.time() - start_time) * 1000)))
    # The temporary directory only exists while the generator is suspended
    # at this yield.
    yield pathlib.Path(d)
def Import(self, session: sqlutil.Session,
           preprocessed_db: preprocessed.PreprocessedContentFiles,
           atomizer: atomizers.AtomizerBase,
           contentfile_separator: str) -> None:
  """Encode the successfully pre-processed files that are not yet encoded.

  Files are encoded by a multiprocessing pool and each result is added to the
  session. The session is committed at most once every ten seconds while
  results arrive; results added after the last periodic commit are left
  pending in the session.

  Args:
    session: The session of the encoded database that results are added to.
    preprocessed_db: The database of pre-processed content files to read.
    atomizer: The atomizer that is pickled and sent to the encoder workers.
    contentfile_separator: The contentfile separator sent to the workers.

  Raises:
    EmptyCorpusException: If there are no pre-processed files to encode.
  """
  with preprocessed_db.Session() as p_session:
    # Flatten the result rows into plain IDs before handing them to in_().
    encoded_ids = [row[0] for row in session.query(EncodedContentFile.id)]
    query = p_session.query(preprocessed.PreprocessedContentFile).filter(
        preprocessed.PreprocessedContentFile.preprocessing_succeeded == True,
        ~preprocessed.PreprocessedContentFile.id.in_(encoded_ids))
    # The atomizer is the same for every job, so serialize it only once.
    pickled_atomizer = pickle.dumps(atomizer)
    jobs = [
        internal_pb2.EncoderWorker(
            id=x.id,
            text=x.text,
            contentfile_separator=contentfile_separator,
            pickled_atomizer=pickled_atomizer) for x in query
    ]
    if not jobs:
      raise errors.EmptyCorpusException(
          "Pre-processed corpus contains no files: "
          f"'{preprocessed_db.url}'")

    app.Log(
        1, 'Encoding %s of %s preprocessed files',
        humanize.Commas(query.count()),
        humanize.Commas(
            p_session.query(preprocessed.PreprocessedContentFile).filter(
                preprocessed.PreprocessedContentFile.preprocessing_succeeded
                == True).count()))
    bar = progressbar.ProgressBar(max_value=len(jobs))
    # Use the pool as a context manager so that the worker processes are
    # reclaimed when the import finishes or fails part-way through.
    with multiprocessing.Pool() as pool:
      last_commit = time.time()
      wall_time_start = time.time()
      for encoded_cf in bar(pool.imap_unordered(EncoderWorker, jobs)):
        wall_time_end = time.time()
        # TODO(cec): Remove the if check once EncoderWorker no longer returns
        # None on atomizer encode error.
        if encoded_cf:
          encoded_cf.wall_time_ms = int(
              (wall_time_end - wall_time_start) * 1000)
          session.add(encoded_cf)
        wall_time_start = wall_time_end
        if wall_time_end - last_commit > 10:
          session.commit()
          last_commit = wall_time_end
def GetTrainingCorpus(
    corpus: 'corpuses.Corpus', training_opts: model_pb2.TrainingOptions
) -> typing.Tuple[np.ndarray, np.ndarray, int]:
  """Get the corpus to train over.

  The encoded corpus is clipped to a whole number of steps. The targets are
  the inputs shifted along by one token.

  Args:
    corpus: A Corpus instance.
    training_opts: A TrainingOptions proto.

  Returns:
    An X, y pair of data for an epoch, and the number of steps in the epoch.

  Raises:
    UserError: If batch_size and sequence_length are too large for the corpus,
      yielding no batches.
  """
  start_time = time.time()
  encoded_corpus = corpus.GetTrainingData(
      shuffle=training_opts.shuffle_corpus_contentfiles_between_epochs)
  corpus_length = len(encoded_corpus)

  batch_size = training_opts.batch_size
  sequence_length = training_opts.sequence_length
  tokens_per_step = batch_size * sequence_length
  # One token is held back because y is x shifted by one position.
  steps_per_epoch = (corpus_length - 1) // tokens_per_step
  if steps_per_epoch == 0:
    raise errors.UserError(
        f'Requested batch size ({training_opts.batch_size}) and '
        f'sequence length ({training_opts.sequence_length}) are too large for '
        f'corpus of size {corpus_length}.')

  clipped_corpus_length = steps_per_epoch * batch_size * sequence_length
  epoch_shape = [batch_size, steps_per_epoch * sequence_length]
  x = np.reshape(encoded_corpus[:clipped_corpus_length], epoch_shape)
  y = np.reshape(encoded_corpus[1:clipped_corpus_length + 1], epoch_shape)

  app.Log(1, 'Encoded corpus of %s tokens (clipped last %s tokens) in %s ms.',
          humanize.Commas(clipped_corpus_length),
          humanize.Commas(corpus_length - clipped_corpus_length),
          humanize.Commas(int((time.time() - start_time) * 1000)))
  return x, y, steps_per_epoch
def CreateBatches(self) -> None:
  """Split the encoded corpus into a list of (x, y) training batches.

  Resets the batch index self.i to zero, and sets self.num_batches and
  self.batches. The encoded corpus is fetched when it has not been fetched
  before, or when shuffling between epochs is enabled.

  Raises:
    UserError: If the corpus is too small to produce a single batch.
  """
  start_time = time.time()

  # generate a kernel corpus
  self.i = 0
  opts = self.training_opts
  if (self.encoded_corpus is None or
      opts.shuffle_corpus_contentfiles_between_epochs):
    self.encoded_corpus = self.corpus.GetTrainingData(
        shuffle=opts.shuffle_corpus_contentfiles_between_epochs)

  batch_size = opts.batch_size
  sequence_length = opts.sequence_length
  tokens_per_batch = batch_size * sequence_length

  # set corpus size and number of batches
  self.num_batches = int(len(self.encoded_corpus) / tokens_per_batch)
  if self.num_batches == 0:
    raise errors.UserError(
        "Not enough data. Use a smaller sequence_length and batch_size")

  # split into batches
  clipped_corpus_length = self.num_batches * batch_size * sequence_length
  xdata = self.encoded_corpus[:clipped_corpus_length]
  # The targets are the inputs shifted by one token, with the final target
  # wrapping around to the first input.
  ydata = np.copy(xdata)
  ydata[:-1] = xdata[1:]
  ydata[-1] = xdata[0]

  x_batches = np.split(xdata.reshape(batch_size, -1), self.num_batches, 1)
  y_batches = np.split(ydata.reshape(batch_size, -1), self.num_batches, 1)
  self.batches = [DataBatch(x, y) for x, y in zip(x_batches, y_batches)]

  app.Log(1, 'Encoded corpus of %s tokens (clipped last %s tokens) in %s ms.',
          humanize.Commas(clipped_corpus_length),
          humanize.Commas(len(self.encoded_corpus) - clipped_corpus_length),
          humanize.Commas(int((time.time() - start_time) * 1000)))
def Preprocess(contentfiles: pathlib.Path, outdir: pathlib.Path,
               preprocessor_names):
  """Preprocess a directory of content files into an output directory.

  Files whose name already exists in the output directory are skipped. The
  text of every successfully preprocessed file is written to a file of the
  same name in the output directory.

  Args:
    contentfiles: The directory containing the files to preprocess.
    outdir: The directory that preprocessed files are written to.
    preprocessor_names: The names of the preprocessors to run.
  """
  # Error early if preprocessors are bad.
  for preprocessor_name in preprocessor_names:
    preprocessors.GetPreprocessorFunction(preprocessor_name)

  # This is basically the same code as:
  # deeplearning.clgen.corpuses.preprocessed.PreprocessedContentFiles:Import()
  # Only it's writing the results of preprocessing to files rather than to a
  # database. Consider refactoring.
  relpaths = {f.name for f in contentfiles.iterdir()}
  done = {f.name for f in outdir.iterdir()}
  todo = relpaths - done
  app.Log(1, 'Preprocessing %s of %s content files',
          humanize.Commas(len(todo)), humanize.Commas(len(relpaths)))
  jobs = [
      internal_pb2.PreprocessorWorker(contentfile_root=str(contentfiles),
                                      relpath=t,
                                      preprocessors=preprocessor_names)
      for t in todo
  ]
  bar = progressbar.ProgressBar(max_value=len(jobs))
  succeeded_count = 0
  # Use the pool as a context manager so that the worker processes are
  # reclaimed when preprocessing finishes or fails part-way through.
  with multiprocessing.Pool() as pool:
    wall_time_start = time.time()
    workers = pool.imap_unordered(preprocessed.PreprocessorWorker, jobs)
    for preprocessed_cf in bar(workers):
      wall_time_end = time.time()
      preprocessed_cf.wall_time_ms = (int(
          (wall_time_end - wall_time_start) * 1000))
      wall_time_start = wall_time_end
      if preprocessed_cf.preprocessing_succeeded:
        succeeded_count += 1
        with open(outdir / preprocessed_cf.input_relpath, 'w') as f:
          f.write(preprocessed_cf.text)

  # Clamp the divisor to at least one so that an empty todo set reports 0%
  # rather than raising ZeroDivisionError.
  app.Log(1, "Successfully preprocessed %s of %s files (%.2f %%)",
          humanize.Commas(succeeded_count), humanize.Commas(len(todo)),
          (succeeded_count / max(len(todo), 1)) * 100)
def _DoHash(self, absolute_path: pathlib.Path, last_modified: int,
            hash_fn: typing.Callable[[pathlib.Path], str]) -> str:
  """Return the hash of a path, using the cached value if it is still valid.

  A cached record is reused only when its last_modified value equals the one
  given. Otherwise the stale record (if any) is deleted, the hash is
  recomputed with hash_fn, and a new record is committed.

  Args:
    absolute_path: The path to hash. Its string form is the cache key.
    last_modified: The modification value to compare against the cache.
    hash_fn: The function used to compute the hash of the path.

  Returns:
    The hash of the path.
  """
  path_key = str(absolute_path)
  with self.Session() as session:
    cached_entry = session.query(HashCacheRecord).filter(
        HashCacheRecord.absolute_path == path_key).first()

    if cached_entry:
      if cached_entry.last_modified == last_modified:
        app.Log(2, "Cache hit: '%s'", absolute_path)
        return cached_entry.hash
      # The record is out of date, so drop it before inserting a new one.
      app.Log(2, "Cache miss: '%s'", absolute_path)
      session.delete(cached_entry)

    start_time = time.time()
    checksum = hash_fn(absolute_path)
    app.Log(2, "New cache entry '%s' in %s ms.", absolute_path,
            humanize.Commas(int((time.time() - start_time) * 1000)))

    new_entry = HashCacheRecord(absolute_path=path_key,
                                last_modified=last_modified,
                                hash=checksum)
    session.add(new_entry)
    session.commit()
    return new_entry.hash
def Train(self, **kwargs) -> 'Model':
  """Train the model.

  The corpus is created first, then the backend is trained while holding the
  training lock. A summary of the training telemetry is logged afterwards.

  Args:
    **kwargs: Forwarded to the backend's Train() method.

  Returns:
    The model instance.

  Raises:
    UnableToAcquireLockError: If the model is locked (i.e. there is another
      process currently modifying the model).
  """
  self.corpus.Create()
  with self.training_lock.acquire():
    self.backend.Train(self.corpus, **kwargs)

  # Only consider the telemetry of the configured number of epochs.
  all_telemetry = self.TrainingTelemetry()
  telemetry_logs = all_telemetry[:self.config.training.num_epochs]
  final_loss = telemetry_logs[-1].loss
  total_time_ms = 0
  for epoch_telemetry in telemetry_logs:
    total_time_ms += epoch_telemetry.epoch_wall_time_ms

  app.Log(1, 'Trained model for %d epochs in %s ms (%s). '
          'Training loss: %f.', self.config.training.num_epochs,
          humanize.Commas(total_time_ms),
          humanize.Duration(total_time_ms / 1000), final_loss)
  return self
def ExportGitHistoryForFiles(source: git.Repo,
                             destination: git.Repo,
                             files_of_interest: typing.Set[str],
                             head_ref: str = 'HEAD',
                             resume_export: bool = True) -> int:
  """Apply the parts of the git history from the given source repo.

  The source repo is added to the destination as a temporary remote and
  fetched. The commits selected by GetCommitsInOrder() are then handed to
  ExportCommitsThatTouchFiles() along with the files of interest.

  Args:
    source: The repo to export history from.
    destination: The repo to export history to. It must not be dirty.
    files_of_interest: The files whose history is exported.
    head_ref: The ref passed to GetCommitsInOrder() as the head.
    resume_export: If True, the commit reported by
      MaybeGetHexShaOfLastExportedCommit() is passed to GetCommitsInOrder()
      as the tail.

  Returns:
    The value returned by ExportCommitsThatTouchFiles(), or 0 if there are no
    commits to export.

  Raises:
    OSError: If the destination repo is dirty.
  """
  if destination.is_dirty():
    raise OSError(f"Repo `{destination.working_tree_dir}` is dirty")

  with TemporaryGitRemote(destination, source.working_tree_dir) as remote:
    destination.remote(remote).fetch()

    tail = None
    if resume_export:
      tail = MaybeGetHexShaOfLastExportedCommit(destination)

    commits_in_order = GetCommitsInOrder(source,
                                         head_ref=head_ref,
                                         tail_ref=tail)
    if not commits_in_order:
      app.Log(1, 'Nothing to export!')
      return 0

    app.Log(1, 'Exporting history from %s commits',
            humanize.Commas(len(commits_in_order)))
    return ExportCommitsThatTouchFiles(commits_in_order, destination,
                                       files_of_interest)
def ResolveContentId(config: corpus_pb2.Corpus,
                     hc: typing.Optional[hashcache.HashCache] = None) -> str:
  """Resolve the content ID of the input contentfiles.

  The ID is taken from the first of these that applies: the content_id field
  of the config, the sha1 of the pre_encoded_corpus_url field, the hash of
  the local directory, or the hash of the local tar archive's contents.

  Args:
    config: The corpus config proto.
    hc: A hashcache database instance, used for resolving directory hashes.
      If the corpus has pre_encoded_corpus_url field set, this may be omitted.

  Returns:
    The content ID string.

  Raises:
    UserError: If the hash cache raises FileNotFoundError for the directory.
    NotImplementedError: If the config has none of the supported fields.
  """
  # We can take a massive shortcut if the content ID is already set in the
  # config proto.
  # TODO(github.com/ChrisCummins/phd/issues/46): Refactor this after splitting
  # out Corpus class.
  if config.HasField('content_id'):
    return config.content_id
  if config.HasField('pre_encoded_corpus_url'):
    return crypto.sha1_str(config.pre_encoded_corpus_url)

  start_time = time.time()
  if config.HasField('local_directory'):
    local_directory = ExpandConfigPath(
        config.local_directory, path_prefix=FLAGS.clgen_local_path_prefix)

    # After the first time we compute the hash of a directory, we write it
    # into a file. This is a shortcut to work around the fact that computing
    # the directory checksum is O(n) with respect to the number of files in
    # the directory (even if the directory is already cached by the hash
    # cache). This means that it is the responsibility of the user to delete
    # this cached file if the directory is changed.
    hash_file_path = pathlib.Path(str(local_directory) + '.sha1.txt')
    if not hash_file_path.is_file():
      # No hash file, so compute the directory hash and create it.
      try:
        content_id = hc.GetHash(local_directory)
      except FileNotFoundError as e:
        raise errors.UserError(e)
      # Create the hash file next to the directory so that next time we don't
      # need to reference the hash cache.
      with open(hash_file_path, 'w') as f:
        print(content_id, file=f)
      app.Log(1, "Wrote directory hash: '%s'.", hash_file_path)
    else:
      app.Log(1, "Reading directory hash: '%s'.", hash_file_path)
      with open(hash_file_path) as f:
        content_id = f.read().rstrip()
  elif config.HasField('local_tar_archive'):
    # This is not an efficient means of getting the hash, as it requires
    # always unpacking the archive and reading the entire contents. It would
    # be nicer to maintain a cache which maps the mtime of tarballs to their
    # content ID, similar to how local_directory is implemented.
    archive_path = ExpandConfigPath(
        config.local_tar_archive, path_prefix=FLAGS.clgen_local_path_prefix)
    content_id = GetHashOfArchiveContents(archive_path)
  else:
    raise NotImplementedError('Unsupported Corpus.contentfiles field value')

  app.Log(2, 'Resolved Content ID %s in %s ms.', content_id,
          humanize.Commas(int((time.time() - start_time) * 1000)))
  return content_id
def Train(self, corpus, **unused_kwargs) -> 'keras.models.Sequential':
  """Locked training.

  If there are cached epoch checkpoints, the one closest to the target number
  of epochs will be loaded, and the model will be trained for only the
  remaining number of epochs, if any. This means that calling this function
  twice will only actually train the model the first time, and all subsequent
  calls will be no-ops.

  This method must only be called when the model is locked.

  Args:
    corpus: The corpus to train on. It is passed to
      data_generators.AutoGenerator, and its encoded.token_count is used to
      compute the number of steps per epoch.
    **unused_kwargs: Ignored.

  Returns:
    The trained Keras model.
  """
  del unused_kwargs

  model = builders.BuildKerasModel(self.config, self.atomizer.vocab_size)
  # The model architecture is written to the cache on every call, including
  # calls that go on to load existing weights.
  with open(self.cache.keypath('model.yaml'), 'w') as f:
    f.write(model.to_yaml())
  model.compile(
      loss='categorical_crossentropy',
      optimizer=builders.BuildOptimizer(self.config))

  # Print a model summary.
  buf = io.StringIO()
  model.summary(print_fn=lambda x: buf.write(x + '\n'))
  app.Log(1, 'Model summary:\n%s', buf.getvalue())

  # TODO(cec): Add an atomizer.CreateVocabularyFile() method, with frequency
  # counts for a given corpus.
  def Escape(token: str) -> str:
    """Make a token visible and printable."""
    if token == '\t':
      return '\\t'
    elif token == '\n':
      return '\\n'
    elif not token.strip():
      # Quote whitespace-only (and empty) tokens so that they are visible.
      return f"'{token}'"
    else:
      return token

  # Write the embeddings metadata once: one escaped token per line, ordered
  # by the keys of the atomizer's decoder mapping.
  if not (self.cache.path / 'embeddings' / 'metadata.tsv').is_file():
    with open(self.cache.path / 'embeddings' / 'metadata.tsv', 'w') as f:
      for _, token in sorted(
          self.atomizer.decoder.items(), key=lambda x: x[0]):
        f.write(Escape(token) + '\n')

  target_num_epochs = self.config.training.num_epochs
  starting_epoch = 0

  epoch_checkpoints = self.epoch_checkpoints
  if len(epoch_checkpoints) >= target_num_epochs:
    # We have already trained a model to at least this number of epochs, so
    # simply load the weights from that epoch and call it a day.
    app.Log(1, 'Loading weights from %s',
            epoch_checkpoints[target_num_epochs - 1])
    model.load_weights(epoch_checkpoints[target_num_epochs - 1])
    return model

  # Now entering the point at which training is inevitable.
  with logutil.TeeLogsToFile('train', self.cache.path / 'logs'):
    # Deferred importing of Keras so that we don't have to activate the
    # TensorFlow backend every time we import this module.
    import keras

    if epoch_checkpoints:
      # We have already trained a model at least part of the way to our target
      # number of epochs, so load the most recent one.
      starting_epoch = len(epoch_checkpoints)
      app.Log(1, 'Resuming training from epoch %d.', starting_epoch)
      model.load_weights(epoch_checkpoints[-1])

    # Callbacks: save a checkpoint after every epoch, write TensorBoard logs
    # with embeddings metadata, and record training telemetry.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            str(self.cache.path / 'checkpoints' / '{epoch:03d}.hdf5'),
            verbose=1,
            mode="min",
            save_best_only=False),
        keras.callbacks.TensorBoard(
            str(self.cache.path / 'embeddings'),
            write_graph=True,
            embeddings_freq=1,
            embeddings_metadata={
                'embedding_1':
                str(self.cache.path / 'embeddings' / 'metadata.tsv'),
            }),
        telemetry.TrainingLogger(
            self.cache.path / 'logs').KerasCallback(keras),
    ]

    generator = data_generators.AutoGenerator(corpus, self.config.training)
    # One token is subtracted before dividing the corpus into steps.
    # NOTE(review): presumably because each target is the input shifted by
    # one token - confirm against the data generator.
    steps_per_epoch = (corpus.encoded.token_count - 1) // (
        self.config.training.batch_size *
        self.config.training.sequence_length)
    app.Log(
        1, 'Step counts: %s per epoch, %s left to do, %s total',
        humanize.Commas(steps_per_epoch),
        humanize.Commas(
            (target_num_epochs - starting_epoch) * steps_per_epoch),
        humanize.Commas(target_num_epochs * steps_per_epoch))
    model.fit_generator(
        generator,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        initial_epoch=starting_epoch,
        epochs=target_num_epochs)
  return model
def InitTfGraph(self,
                sampler: typing.Optional[samplers.Sampler] = None) -> 'tf':
  """Instantiate a TensorFlow graph for training or inference.

  The tensorflow graph is different for training and inference, so must be
  reset when switching between modes.

  The graph tensors and ops are stored as attributes of this object: cell,
  input_data, targets, initial_state, temperature, seed_length, lengths,
  final_state, generated, logits, loss, learning_rate, epoch and train_op.

  Args:
    sampler: If set, initialize the model for inference using the given
      sampler. If not set, initialize model for training.

  Returns:
    The imported TensorFlow module.

  Raises:
    NotImplementedError: If the configured neuron type is not one of LSTM,
      GRU, or RNN.
  """
  start_time = time.time()

  # Quiet tensorflow.
  # See: https://github.com/tensorflow/tensorflow/issues/1258
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

  # Deferred importing of TensorFlow.
  import tensorflow as tf
  import tensorflow.contrib.seq2seq as seq2seq
  from tensorflow.contrib import rnn
  from deeplearning.clgen.models import helper

  # Map the configured neuron type to the class used to build each layer.
  cell_type = {
      model_pb2.NetworkArchitecture.LSTM: rnn.LSTMBlockCell,
      model_pb2.NetworkArchitecture.GRU: rnn.GRUBlockCellV2,
      model_pb2.NetworkArchitecture.RNN: rnn.BasicRNNCell,
  }.get(self.config.architecture.neuron_type, None)
  if cell_type is None:
    raise NotImplementedError

  # Reset the graph when switching between training and inference.
  tf.reset_default_graph()

  # The sampler dictates the input dimensions for inference; the training
  # options dictate them for training.
  if sampler:
    sequence_length = sampler.sequence_length
    batch_size = sampler.batch_size
  else:
    sequence_length = self.config.training.sequence_length
    batch_size = self.config.training.batch_size
  vocab_size = self.atomizer.vocab_size

  # One cell per layer, stacked into a single multi-layer cell.
  cells_lst = []
  for _ in range(self.config.architecture.num_layers):
    cells_lst.append(
        cell_type(self.config.architecture.neurons_per_layer))
  self.cell = cell = rnn.MultiRNNCell(cells_lst, state_is_tuple=True)

  self.input_data = tf.placeholder(tf.int32, [batch_size, sequence_length])
  self.targets = tf.placeholder(tf.int32, [batch_size, sequence_length])
  self.initial_state = self.cell.zero_state(batch_size, tf.float32)
  self.temperature = tf.Variable(1.0, trainable=False)
  self.seed_length = tf.Variable(32, trainable=False)

  # For inference the sequence lengths are fed in; for training every
  # sequence has the full sequence length.
  if sampler:
    self.lengths = tf.placeholder(tf.int32, [batch_size])
  else:
    self.lengths = tf.fill([batch_size], sequence_length)

  scope_name = 'rnnlm'
  with tf.variable_scope(scope_name):
    with tf.device('/cpu:0'):
      embedding = tf.get_variable(
          'embedding',
          [vocab_size, self.config.architecture.neurons_per_layer])
      inputs = tf.nn.embedding_lookup(embedding, self.input_data)

  if sampler:
    decode_helper = helper.CustomInferenceHelper(
        inputs, self.lengths, self.seed_length, embedding, self.temperature)
  else:
    decode_helper = seq2seq.TrainingHelper(inputs,
                                           self.lengths,
                                           time_major=False)

  # The decoder projects the cell outputs through a dense layer of
  # vocab_size units.
  decoder = seq2seq.BasicDecoder(cell, decode_helper, self.initial_state,
                                 tf.layers.Dense(vocab_size))
  outputs, self.final_state, _ = seq2seq.dynamic_decode(
      decoder,
      output_time_major=False,
      impute_finished=True,
      swap_memory=True,
      scope=scope_name)

  self.generated = outputs.sample_id
  self.logits = outputs.rnn_output

  # Every position in every sequence is weighted equally in the loss.
  sequence_weigths = tf.ones([batch_size, sequence_length])
  self.loss = seq2seq.sequence_loss(self.logits, self.targets,
                                    sequence_weigths)

  self.learning_rate = tf.Variable(0.0, trainable=False)
  self.epoch = tf.Variable(0, trainable=False)
  trainable_variables = tf.trainable_variables()

  # The loss and training op are built in both modes, not only for training.
  # NOTE(review): aggregation_method=2 is presumably
  # tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N - confirm.
  # TODO(cec): Support non-adam optimizers.
  grads, _ = tf.clip_by_global_norm(
      tf.gradients(self.loss, trainable_variables, aggregation_method=2),
      self.config.training.adam_optimizer.normalized_gradient_clip_micros /
      1e6)
  optimizer = tf.train.AdamOptimizer(self.learning_rate)
  self.train_op = optimizer.apply_gradients(
      zip(grads, trainable_variables))

  if not sampler:
    # Create tensorboard summary writers for training progress.
    tf.summary.scalar('loss', self.loss)
    tf.summary.scalar('learning_rate', self.learning_rate)
    tf.summary.scalar('epoch_num', self.epoch)

  # Total number of elements across all of the trainable variables.
  num_trainable_params = int(
      np.sum([np.prod(v.shape) for v in tf.trainable_variables()]))
  app.Log(
      1, 'Instantiated TensorFlow graph with %s trainable parameters '
      'in %s ms.', humanize.Commas(num_trainable_params),
      humanize.Commas(int((time.time() - start_time) * 1000)))

  return tf
def Train(self,
          corpus,
          test_sampler: typing.Optional[samplers.Sampler] = None,
          **unused_kwargs) -> None:
  """Locked training.

  If there are cached epoch checkpoints, the one closest to the target number
  of epochs will be loaded, and the model will be trained for only the
  remaining number of epochs, if any. This means that calling this function
  twice will only actually train the model the first time, and all subsequent
  calls will be no-ops.

  This method must only be called when the model is locked.

  Args:
    corpus: The corpus to train on. It is passed to the batch generator and
      to the end-of-epoch test sampling.
    test_sampler: If set, training stops after each epoch to run
      _EndOfEpochTestSample(), then this method calls itself to continue
      from the checkpoint that was just saved.
    **unused_kwargs: Ignored.
  """
  del unused_kwargs

  if self.is_trained:
    return

  data_generator = data_generators.TensorflowBatchGenerator(
      corpus, self.config.training)
  tf = self.InitTfGraph()

  logger = telemetry.TrainingLogger(self.cache.path / 'logs')

  # Create and merge the tensorboard summary ops.
  merged = tf.summary.merge_all()

  # training options
  # TODO(cec): Enable support for multiple optimizers:
  # Both options are stored in the config as micros, so scale them by 1e-6.
  initial_learning_rate = (
      self.config.training.adam_optimizer.initial_learning_rate_micros / 1e6)
  decay_rate = (self.config.training.adam_optimizer.
                learning_rate_decay_per_epoch_micros / 1e6)

  # resume from prior checkpoint
  ckpt_path, ckpt_paths = None, None
  if (self.cache.path / 'checkpoints' / 'checkpoint').exists():
    checkpoint_state = tf.train.get_checkpoint_state(self.cache.path /
                                                     'checkpoints')
    assert checkpoint_state
    assert checkpoint_state.model_checkpoint_path
    ckpt_path, ckpt_paths = self.GetParamsPath(checkpoint_state)

  with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Keep all checkpoints.
    # NOTE(review): max_to_keep=100 caps this at 100 checkpoints - confirm
    # that this is enough for the largest expected epoch count.
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=100,
                           save_relative_paths=True)

    # restore model from closest checkpoint.
    if ckpt_path:
      app.Log(1, "Restoring checkpoint {}".format(ckpt_path))
      saver.restore(sess, ckpt_path)

    # make sure we don't lose track of other checkpoints
    if ckpt_paths:
      saver.recover_last_checkpoints(ckpt_paths)

    # Offset epoch counts by 1 so that they are in the range [1..n]
    current_epoch = sess.run(self.epoch) + 1
    max_epoch = self.config.training.num_epochs + 1

    # Per-epoch training loop.
    for epoch_num in range(current_epoch, max_epoch):
      logger.EpochBeginCallback()

      # decay and set learning rate
      # decay_rate is treated as a percentage that is compounded once per
      # completed epoch.
      new_learning_rate = initial_learning_rate * (
          (float(100 - decay_rate) / 100.0)**(epoch_num - 1))
      sess.run(tf.assign(self.learning_rate, new_learning_rate))
      sess.run(tf.assign(self.epoch, epoch_num))

      # TODO(cec): refactor data generator to a Python generator.
      data_generator.CreateBatches()

      app.Log(1, 'Epoch %d/%d:', epoch_num, self.config.training.num_epochs)
      state = sess.run(self.initial_state)
      # Per-batch inner loop.
      bar = progressbar.ProgressBar(
          max_value=data_generator.num_batches)
      for i in bar(range(data_generator.num_batches)):
        x, y = data_generator.NextBatch()
        feed = {self.input_data: x, self.targets: y}
        # Feed the final state of the previous batch in as the initial state
        # of this batch.
        for j, (c, h) in enumerate(self.initial_state):
          feed[c], feed[h] = state[j].c, state[j].h
        summary, loss, state, _ = sess.run(
            [merged, self.loss, self.final_state, self.train_op], feed)

        # Periodically write progress to tensorboard.
        # NOTE(review): `step` is only assigned inside this branch, and is
        # read again after the session block when a test sampler is set.
        if i % FLAGS.clgen_tf_backend_tensorboard_summary_step_count == 0:
          step = (epoch_num - 1) * data_generator.num_batches + i
          self.summary_writer.add_summary(summary, step)

      # Log the loss and delta.
      app.Log(1, 'Loss: %.6f.', loss)

      # Save after every epoch.
      start_time = time.time()
      global_step = epoch_num
      checkpoint_prefix = (self.cache.path / 'checkpoints' / 'checkpoint')
      checkpoint_path = saver.save(sess,
                                   checkpoint_prefix,
                                   global_step=global_step)
      app.Log(
          1, 'Saved checkpoint %s in %s ms.', checkpoint_path,
          humanize.Commas(int((time.time() - start_time) * 1000)))
      assert pathlib.Path(
          f'{checkpoint_prefix}-{global_step}.index').is_file()
      assert pathlib.Path(
          f'{checkpoint_prefix}-{global_step}.meta').is_file()

      logger.EpochEndCallback(epoch_num, loss)

      # If we have a sampler that we can use at the end of epochs, then
      # break out of the epoch loop now so that it can be run below.
      # This is confusing logic! Consider a refactor to simplify things.
      if test_sampler:
        break
    else:
      # The epoch loop finished without a break, so there is no test sample
      # to run.
      return

  if test_sampler:
    self._EndOfEpochTestSample(corpus, test_sampler, step)
    # Continue training by re-entering this method.
    self.Train(corpus, test_sampler)