def SampleMaskLMBatchGenerator(self,
                               model_opts              : model_pb2.TrainingOptions,
                               sampler                 : "samplers.Sampler",
                               tokenizer               : tokenizers.TokenizerBase,
                               seed                    : int,
                               sample_batch_size       : int,
                               max_position_embeddings : int,
                               cache_path              : pathlib.Path,
                               feature_encoder         : bool = False,
                               feature_tokenizer       : tokenizers.FeatureTokenizer = None,
                               feature_sequence_length : int = None,
                               ) -> "data_generator.MaskLMBatchGenerator":
  """Initializes data generator for inference."""
  self.cache = cache.mkcache(cache_path, "dataset")
  self.cache.path.mkdir(exist_ok = True, parents = True)

  self.dataset                 = {}
  self.sampler                 = sampler
  self.corpus                  = sampler.sample_corpus
  self.tokenizer               = tokenizer
  self.config                  = model_opts.data_generator
  self.rngen                   = np.random
  self.sample_batch_size       = sample_batch_size
  self.max_position_embeddings = max_position_embeddings
  self.feature_encoder         = feature_encoder
  self.feature_tokenizer       = feature_tokenizer
  self.feature_sequence_length = feature_sequence_length

  self.training_opts                 = model_opts
  self.training_opts.sequence_length = sampler.sequence_length
  self.training_opts.batch_size      = sampler.batch_size
  return self
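
# Usage sketch (illustrative, not part of the module). Because the method
# returns `self`, setup chains through the call. `gen` is assumed to be an
# already-constructed, backend-specific MaskLMDataGenerator instance; the
# remaining names are hypothetical placeholders:
#
#   gen = gen.SampleMaskLMBatchGenerator(
#     model_opts              = my_model_config.training,  # model_pb2.TrainingOptions
#     sampler                 = my_sampler,                # samplers.Sampler
#     tokenizer               = my_tokenizer,
#     seed                    = 0,
#     sample_batch_size       = my_sampler.batch_size,
#     max_position_embeddings = 512,
#     cache_path              = pathlib.Path("/tmp/bert_cache"),
#   )
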
def __init__(self, config: active_learning_pb2.ActiveLearner):
  """Instantiate a model.

  Args:
    config: An ActiveLearner message.

  Raises:
    TypeError: If the config argument is not an ActiveLearner proto.
    UserError: In case of an invalid config.
  """
  # Error early, so that a cache isn't created.
  if not isinstance(config, active_learning_pb2.ActiveLearner):
    t = type(config).__name__
    raise TypeError(f"Config must be an ActiveLearner proto. Received: '{t}'")

  self.config = active_learning_pb2.ActiveLearner()
  # Validate config options.
  self.config.CopyFrom(AssertConfigIsValid(config))

  distrib.lock()
  self.cache = cache.mkcache("active_model")
  distrib.unlock()
  self.downstream_task = downstream_tasks.DownstreamTask.FromTask(
    self.config.downstream_task, self.config.training_corpus
  )

  if environment.WORLD_RANK == 0:
    ## Store current commit
    commit.saveCommit(self.cache.path)

  self.backend = active_committee.QueryByCommittee(self.config, self.cache, self.downstream_task)
  l.logger().info("Initialized {} in {}".format(self.backend, self.cache.path))
  return
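
# Usage sketch (illustrative, not part of the module). The sampler code below
# refers to this class as active_models.Model; assuming a proto text file at
# the hypothetical path below, construction reads:
#
#   config = pbutil.FromFile(pathlib.Path("active_learner.pbtxt"),
#                            active_learning_pb2.ActiveLearner())
#   learner = active_models.Model(config)
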
def TrainMaskLMBatchGenerator(self,
                              corpus                  : "corpuses.Corpus",
                              training_opts           : model_pb2.TrainingOptions,
                              cache_path              : pathlib.Path,
                              num_train_steps         : int  = None,
                              pre_train               : bool = False,
                              feature_encoder         : bool = False,
                              feature_tokenizer       : tokenizers.FeatureTokenizer = None,
                              feature_sequence_length : int  = None,
                              ) -> "data_generator.MaskLMDataGenerator":
  """Initializes data generator for training."""
  self.cache = cache.mkcache(cache_path, "dataset")
  self.cache.path.mkdir(exist_ok = True, parents = True)

  self.dataset       = {}
  self.corpus        = corpus
  self.tokenizer     = corpus.tokenizer
  self.config        = training_opts.data_generator
  self.training_opts = training_opts
  self.rngen         = np.random  # random.Random(training_opts.random_seed)
  self.pre_train     = pre_train

  self.feature_encoder         = feature_encoder
  self.feature_tokenizer       = feature_tokenizer
  self.feature_sequence_length = feature_sequence_length

  if num_train_steps:
    self.num_train_steps = num_train_steps
  else:
    self.num_train_steps = self.training_opts.num_train_steps

  shaped_corpus = self.createCorpus(self.cache.path)
  if self.config.datapoint_time == "pre":
    if self.feature_encoder:
      raise NotImplementedError("Pre masking corpus does not work with feature encoding model.")
    # 'pre' pre-processes/masks training/validation/sampling corpus for the model to use.
    # 'online' stores the raw data and masks them on the fly.
    self.configDataset(shaped_corpus)
  return self
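
# Usage sketch (illustrative, not part of the module). Mirroring the inference
# variant above, training setup threads the corpus and its TrainingOptions
# through this call; `gen`, `my_corpus`, and the paths are hypothetical:
#
#   gen = gen.TrainMaskLMBatchGenerator(
#     corpus        = my_corpus,                  # corpuses.Corpus
#     training_opts = my_model_config.training,   # model_pb2.TrainingOptions
#     cache_path    = pathlib.Path("/tmp/bert_cache"),
#     pre_train     = False,
#   )
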
def __init__(self, config: model_pb2.Model):
  """Instantiate a model.

  Args:
    config: A Model message.

  Raises:
    TypeError: If the config argument is not a Model proto.
    UserError: In case of an invalid config.
  """
  # Error early, so that a cache isn't created.
  if not isinstance(config, model_pb2.Model):
    t = type(config).__name__
    raise TypeError(f"Config must be a Model proto. Received: '{t}'")

  self.config = model_pb2.Model()
  # Validate config options.
  self.config.CopyFrom(builders.AssertIsBuildable(config))
  if FLAGS.num_train_steps:
    self.config.training.num_train_steps = FLAGS.num_train_steps
  if FLAGS.num_pretrain_steps:
    self.config.training.num_pretrain_steps = FLAGS.num_pretrain_steps
  if FLAGS.num_epochs:
    self.config.training.num_epochs = FLAGS.num_epochs

  # Initialize distrib lock path.
  if environment.WORLD_SIZE > 1:
    if environment.WORLD_RANK == 0:
      lock_cache = cache.mkcache("locks")
      lock_cache.path.mkdir(exist_ok = True)
    else:
      while not cache.cachepath("locks").exists():
        time.sleep(0.5)
      lock_cache = cache.mkcache("locks")
    distrib.init(lock_cache.path)

  # Initialize corpuses
  self.corpus = corpuses.Corpus(config.corpus)
  self.pre_train_corpus = None
  if config.HasField("pre_train_corpus"):
    self.pre_train_corpus = corpuses.Corpus(config.pre_train_corpus)

  self.hash = self._ComputeHash(self.pre_train_corpus, self.corpus, self.config)
  self._created = False

  distrib.lock()
  self.cache = cache.mkcache("model", self.hash)
  distrib.unlock()

  if environment.WORLD_RANK == 0:
    # Create the necessary cache directories.
    (self.cache.path / "checkpoints").mkdir(exist_ok = True)
    (self.cache.path / "samples").mkdir(exist_ok = True)
    # Create symlink to encoded corpus.
    symlink = self.cache.path / "corpus"
    if not symlink.is_symlink():
      os.symlink(
        os.path.relpath(
          pathlib.Path(self.corpus.encoded.url[len("sqlite:///"):]).parent,
          self.cache.path,
        ),
        symlink,
      )
    if self.pre_train_corpus:
      symlink = self.cache.path / "pre_train_corpus"
      if not symlink.is_symlink():
        os.symlink(
          os.path.relpath(
            pathlib.Path(self.pre_train_corpus.encoded.url[len("sqlite:///"):]).parent,
            self.cache.path,
          ),
          symlink,
        )
    # Create symlink to the tokenizer and create a backup inside checkpoints.
    symlink = self.cache.path / "tokenizer"
    if not symlink.is_symlink():
      os.symlink(os.path.relpath(self.corpus.tokenizer_path, self.cache.path), symlink)
    if (self.cache.path / "checkpoints" / "backup_tokenizer.pkl").exists():
      shutil.copyfile(
        self.cache.path / "checkpoints" / "backup_tokenizer.pkl",
        self.corpus.tokenizer_path
      )

  # Validate metadata against cache.
  if self.cache.get("META.pbtxt"):
    cached_meta = pbutil.FromFile(
      pathlib.Path(self.cache["META.pbtxt"]), internal_pb2.ModelMeta()
    )
    # Exclude num_epochs and corpus location from metadata comparison.
    config_to_compare = model_pb2.Model()
    config_to_compare.CopyFrom(self.config)
    config_to_compare.corpus.ClearField("contentfiles")
    if config_to_compare.HasField("pre_train_corpus"):
      config_to_compare.pre_train_corpus.ClearField("contentfiles")
    config_to_compare.training.ClearField("num_epochs")
    config_to_compare.training.ClearField("num_train_steps")
    if config_to_compare.HasField("pre_train_corpus"):
      config_to_compare.training.ClearField("num_pretrain_steps")
    config_to_compare.training.ClearField("batch_size")
    if config_to_compare.training.HasField("data_generator"):
      config_to_compare.training.data_generator.ClearField("steps_per_epoch")
      config_to_compare.training.data_generator.ClearField("validation_set")
    # These fields should have already been cleared, but we'll do it again
    # so that metadata comparisons don't fail when the cached meta schema
    # is updated.
    cached_to_compare = model_pb2.Model()
    cached_to_compare.CopyFrom(cached_meta.config)
    cached_to_compare.corpus.ClearField("contentfiles")
    if cached_to_compare.HasField("pre_train_corpus"):
      cached_to_compare.pre_train_corpus.ClearField("contentfiles")
    cached_to_compare.training.ClearField("num_epochs")
    cached_to_compare.training.ClearField("num_train_steps")
    if cached_to_compare.HasField("pre_train_corpus"):
      cached_to_compare.training.ClearField("num_pretrain_steps")
    cached_to_compare.training.ClearField("batch_size")
    if cached_to_compare.training.HasField("data_generator"):
      cached_to_compare.training.data_generator.ClearField("steps_per_epoch")
      cached_to_compare.training.data_generator.ClearField("validation_set")
    if cached_to_compare.training.sequence_length != config_to_compare.training.sequence_length:
      l.logger().warning(
        "Mismatch between pre-trained and current config sequence_length! "
        "This can only be intended in BERT model!"
      )
      cached_to_compare.training.ClearField("sequence_length")
      config_to_compare.training.ClearField("sequence_length")
    if config_to_compare != cached_to_compare:
      raise SystemError("Metadata mismatch: {} \n\n {}".format(config_to_compare, cached_to_compare))
    self.meta = cached_meta
  else:
    self.meta = internal_pb2.ModelMeta()
    self.meta.config.CopyFrom(self.config)
    self._WriteMetafile()

  ## Store current commit
  commit.saveCommit(self.cache.path)

  self.backend = {
    model_pb2.NetworkArchitecture.TENSORFLOW_SEQ  : tf_sequential.tfSequential,
    model_pb2.NetworkArchitecture.KERAS_SEQ       : keras_sequential.kerasSequential,
    model_pb2.NetworkArchitecture.TENSORFLOW_BERT : tf_bert.tfBert,
    model_pb2.NetworkArchitecture.TORCH_BERT      : torch_bert.torchBert,
  }[config.architecture.backend](self.config, self.cache, self.hash)
  l.logger().info("Initialized {} in {}".format(self.backend, self.cache.path))
  return
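
# Usage sketch (illustrative, not part of the module). The table above
# dispatches on config.architecture.backend, e.g. a TORCH_BERT architecture
# selects torch_bert.torchBert. The pbtxt path is a hypothetical placeholder:
#
#   config = pbutil.FromFile(pathlib.Path("model.pbtxt"), model_pb2.Model())
#   model = Model(config)  # model.backend is a torch_bert.torchBert for TORCH_BERT
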
def __init__(self, config: sampler_pb2.Sampler, sample_db_name = "samples.db"):
  """Instantiate a sampler.

  Args:
    config: A Sampler message.

  Raises:
    TypeError: If the config argument is not a Sampler proto.
    UserError: If the config contains invalid values.
  """
  if not isinstance(config, sampler_pb2.Sampler):
    t = type(config).__name__
    raise TypeError(f"Config must be a Sampler proto. Received: '{t}'")

  self.config = sampler_pb2.Sampler()
  self.config.CopyFrom(AssertConfigIsValid(config))
  self.hash = self._ComputeHash(self.config)
  self.terminators = GetTerminationCriteria(self.config.termination_criteria)
  if config.HasField("start_text"):
    self.start_text = self.config.start_text
  else:
    self.start_text = ""
  self.temperature = self.config.temperature_micros / 1e6
  self.batch_size = self.config.batch_size
  self.sequence_length = self.config.sequence_length
  self.sample_db_name = sample_db_name

  # Create the necessary cache directories.
  distrib.lock()
  self.cache = cache.mkcache("sampler", self.hash)
  distrib.unlock()
  self.samples_directory = self.cache.path / "samples"
  if environment.WORLD_RANK == 0:
    self.samples_directory.mkdir(exist_ok = True)
  self.corpus_directory = None
  self.sample_corpus = None
  if self.config.HasField("sample_corpus"):
    self.corpus_directory = self.cache.path / "sample_corpus"
    if environment.WORLD_RANK == 0:
      self.corpus_directory.mkdir(exist_ok = True)
    if self.config.sample_corpus.HasField("corpus"):
      self.sample_corpus = corpuses.Corpus(self.config.sample_corpus.corpus)
      self.sample_corpus.Create()
      self.symlinkSampleCorpus(
        pathlib.Path(self.sample_corpus.encoded.url[len("sqlite:///"):]).parent
      )
      text_data = [
        self.sample_corpus.tokenizer.tokensToString(x)
        for x in self.sample_corpus.GetTrainingData()
      ]
    else:
      self.start_text = self.config.sample_corpus.start_text
      text_data = [self.start_text]
    # Text data is dumped in order to specialize with all different model tokenizers.
    if environment.WORLD_RANK == 0:
      with open(self.cache.path / "sample_corpus" / "text_corpus.pkl", 'wb') as outf:
        pickle.dump(text_data, outf)

  if self.has_active_learning:
    self.active_learner = active_models.Model(config.sample_corpus.corpus_config.active.active_learner)

  if environment.WORLD_RANK == 0:
    meta = internal_pb2.SamplerMeta()
    meta.config.CopyFrom(self.config)
    pbutil.ToFile(meta, path = self.cache.path / "META.pbtxt")
    commit.saveCommit(self.cache.path)

  # Set in Specialize().
  self.encoded_start_text = None
  self.tokenized_start_text = None
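
# Usage sketch (illustrative, not part of the module). `sampler.pbtxt` is a
# hypothetical path; note that temperature_micros = 800000 in the proto yields
# self.temperature = 0.8 through the division above:
#
#   config = pbutil.FromFile(pathlib.Path("sampler.pbtxt"), sampler_pb2.Sampler())
#   sampler = Sampler(config, sample_db_name = "samples.db")
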
def __init__(self, config: typing.Union[corpus_pb2.Corpus, corpus_pb2.PreTrainCorpus]):
  """Instantiate a corpus from a proto config.

  If this is a new corpus, a number of files will be created, which may take
  some time.

  Args:
    config: A Corpus message.

  Raises:
    TypeError: If the config argument is not a Corpus proto.
    UserError: In case the corpus is not found, or config contains invalid
      options.
    EmptyCorpusException: In case the corpus contains no data.
  """
  if not isinstance(config, corpus_pb2.Corpus) and not isinstance(config, corpus_pb2.PreTrainCorpus):
    raise TypeError(f"Config must be a Corpus proto. Received: '{type(config).__name__}'")

  # Make a local copy of the configuration.
  if isinstance(config, corpus_pb2.Corpus):
    self.config = corpus_pb2.Corpus()
    self.pre_train = False
  else:
    self.config = corpus_pb2.PreTrainCorpus()
    self.pre_train = True

  self.config.CopyFrom(AssertConfigIsValid(config))
  self._tokenizer = None
  self._created = False

  # An in-memory cache of the encoded contentfiles indices arrays.
  # Set and used in GetTrainingData().
  self._indices_arrays: typing.Optional[typing.List[np.array]] = None

  if environment.WORLD_RANK == 0:
    cache.cachepath("corpus").mkdir(parents = True, exist_ok = True)
  distrib.barrier()
  self.content_id = ResolveContentId(self.config)
  # Database of pre-processed files.
  preprocessed_id = ResolvePreprocessedId(self.content_id, self.config)
  if environment.WORLD_RANK == 0:
    cache.cachepath("corpus", "preprocessed", preprocessed_id).mkdir(exist_ok = True, parents = True)
  distrib.barrier()
  preprocessed_db_path = cache.cachepath("corpus", "preprocessed", preprocessed_id, "preprocessed.db")

  if self.config.HasField("content_id") and not preprocessed_db_path.is_file():
    raise ValueError(f"Content ID not found: '{self.content_id}'")
  self.preprocessed = preprocessed.PreprocessedContentFiles(
    f"sqlite:///{preprocessed_db_path}"
  )

  # Create symlink to contentfiles.
  if environment.WORLD_RANK == 0:
    symlink = pathlib.Path(self.preprocessed.url[len("sqlite:///"):]).parent / "contentfiles"
    if not symlink.is_symlink():
      if config.HasField("local_directory"):
        os.symlink(
          str(ExpandConfigPath(config.local_directory, path_prefix = FLAGS.clgen_local_path_prefix)),
          symlink,
        )
      elif config.HasField("local_tar_archive"):
        os.symlink(
          str(ExpandConfigPath(config.local_tar_archive, path_prefix = FLAGS.clgen_local_path_prefix)),
          symlink,
        )
      elif config.HasField("bq_database"):
        os.symlink(
          str(ExpandConfigPath(config.bq_database, path_prefix = FLAGS.clgen_local_path_prefix)),
          symlink,
        )
      # elif config.HasField("fetch_github"):
      #   os.symlink(
      #     str(ExpandConfigPath(config.fetch_github, path_prefix = FLAGS.clgen_local_path_prefix)),
      #     symlink,
      #   )
  distrib.barrier()

  # Database of encoded pre-processed files.
  encoded_id = ResolveEncodedId(self.content_id, self.config)
  if environment.WORLD_RANK == 0:
    cache.cachepath("corpus", "encoded", encoded_id).mkdir(exist_ok = True, parents = True)
  distrib.barrier()
  db_path = cache.cachepath("corpus", "encoded", encoded_id, "encoded.db")
  if self.config.HasField("pre_encoded_corpus_url"):
    self.encoded = encoded.EncodedContentFiles(config.pre_encoded_corpus_url, self.pre_train)
  else:
    self.encoded = encoded.EncodedContentFiles(f"sqlite:///{db_path}", self.pre_train)
  self.tokenizer_path = cache.cachepath("corpus", "encoded", encoded_id, "tokenizer.pkl")

  if environment.WORLD_RANK == 0 and not self.config.HasField("pre_encoded_corpus_url"):
    symlink = pathlib.Path(self.encoded.url[len("sqlite:///"):]).parent / "preprocessed"
    if not symlink.is_symlink():
      os.symlink(
        os.path.relpath(
          pathlib.Path(self.preprocessed.url[len("sqlite:///"):]).parent,
          pathlib.Path(self.encoded.url[len("sqlite:///"):]).parent,
        ),
        symlink,
      )

  self.hash = encoded_id
  self.cache = cache.mkcache("corpus", "encoded", encoded_id)
  if environment.WORLD_RANK == 0:
    commit.saveCommit(self.cache.path)
    commit.saveCommit(self.cache.path.parent.parent / "preprocessed" / preprocessed_id)
  distrib.barrier()
  l.logger().info("Initialized {}train corpus in {}".format("pre_" if self.pre_train else "", self.cache.path))
  return
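
# Usage sketch (illustrative, not part of the module). Construction only wires
# up databases, symlinks and caches; the corpus contents are materialized by a
# later Create() call, as the sampler above does for its sample corpus. The
# pbtxt path is a hypothetical placeholder:
#
#   config = pbutil.FromFile(pathlib.Path("corpus.pbtxt"), corpus_pb2.Corpus())
#   corpus = Corpus(config)
#   corpus.Create()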