def __init__(self, config: active_learning_pb2.ActiveLearner):
  """Instantiate a model.

  Args:
    config: An ActiveLearner message.

  Raises:
    TypeError: If the config argument is not an ActiveLearner proto.
    UserError: In case of an invalid config.
  """
  # Error early, so that a cache isn't created.
  if not isinstance(config, active_learning_pb2.ActiveLearner):
    t = type(config).__name__
    raise TypeError(f"Config must be an ActiveLearner proto. Received: '{t}'")

  self.config = active_learning_pb2.ActiveLearner()
  # Validate config options.
  self.config.CopyFrom(AssertConfigIsValid(config))

  distrib.lock()
  self.cache = cache.mkcache("active_model")
  distrib.unlock()

  self.downstream_task = downstream_tasks.DownstreamTask.FromTask(
    self.config.downstream_task, self.config.training_corpus
  )

  if environment.WORLD_RANK == 0:
    ## Store current commit
    commit.saveCommit(self.cache.path)

  self.backend = active_committee.QueryByCommittee(
    self.config, self.cache, self.downstream_task
  )
  l.logger().info("Initialized {} in {}".format(self.backend, self.cache.path))
  return
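# Usage sketch (illustrative, not part of the original source). It assumes the
# enclosing class is the `active_models.Model` referenced by the sampler below,
# and reads the config with the pbutil helper this codebase already uses:
#
#   config  = pbutil.FromFile(pathlib.Path("learner.pbtxt"),
#                             active_learning_pb2.ActiveLearner())
#   learner = active_models.Model(config)
#   # Anything that is not an ActiveLearner proto fails fast, before any
#   # cache directory is created:
#   active_models.Model({})  # -> TypeError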
def __init__(self, url: str, is_pre_train: bool = False, must_exist: bool = False, is_replica: bool = False):
  self.is_pre_train = is_pre_train
  if environment.WORLD_RANK == 0 or is_replica:
    encoded_path = pathlib.Path(url.replace("sqlite:///", "")).parent
    self.length_monitor = monitors.CumulativeHistMonitor(encoded_path, "encoded_kernel_length")
    if not self.is_pre_train:
      self.token_monitor = monitors.NormalizedFrequencyMonitor(encoded_path, "token_distribution")
      self.feature_monitors = {
        ftype: monitors.CategoricalDistribMonitor(encoded_path, "{}_distribution".format(ftype))
        for ftype in extractor.extractors.keys()
      }
    super(EncodedContentFiles, self).__init__(url, Base, must_exist = must_exist)
  if environment.WORLD_SIZE > 1 and not is_replica:
    # Set up engine connections to the replicated encoded chunks.
    self.base_path = pathlib.Path(url.replace("sqlite:///", "")).resolve().parent
    hash_id = self.base_path.name
    try:
      tdir = pathlib.Path(FLAGS.local_filesystem).resolve() / hash_id / "node_encoded"
    except Exception:
      tdir = pathlib.Path("/tmp").resolve() / hash_id / "node_encoded"
    distrib.lock()
    tdir.mkdir(parents = True, exist_ok = True)
    distrib.unlock()
    self.replicated_path = tdir / "encoded_{}.db".format(environment.WORLD_RANK)
    self.replicated = EncodedContentFiles(
      url          = "sqlite:///{}".format(str(self.replicated_path)),
      is_pre_train = is_pre_train,
      must_exist   = must_exist,
      is_replica   = True,
    )
    self.length_monitor = self.replicated.length_monitor
    if not self.is_pre_train:
      self.token_monitor    = self.replicated.token_monitor
      self.feature_monitors = self.replicated.feature_monitors
    distrib.barrier()
  return
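# Usage sketch (illustrative, not part of the original source). In a single
# process this connects to the database directly; with WORLD_SIZE > 1 each
# rank transparently receives a per-rank replica such as
# <FLAGS.local_filesystem>/<hash_id>/node_encoded/encoded_<rank>.db, and the
# monitors are shared from that replica. The url follows the sqlite scheme
# stripped by the replace() calls above:
#
#   db = EncodedContentFiles("sqlite:///corpus/encoded.db", must_exist = True)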
def __init__(
  self,
  path: pathlib.Path,
  must_exist: bool = False,
  flush_secs: int = 30,
  plot_sample_status = False,
  commit_sample_frequency: int = 1024,
):
  distrib.lock()
  self.db = samples_database.SamplesDatabase("sqlite:///{}".format(str(path)), must_exist = must_exist)
  distrib.unlock()
  self.sample_id = self.db.count
  self.visited = set(self.db.get_hash_entries)
  self.flush_queue = []
  self.plot_sample_status = plot_sample_status
  if self.plot_sample_status:
    self.saturation_monitor = monitors.CumulativeHistMonitor(path.parent, "cumulative_sample_count")
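# Usage sketch (illustrative; the class name below is an assumption, since the
# class declaration is not shown in this excerpt):
#
#   observer = SamplesDatabaseObserver(
#     pathlib.Path("samples.db").absolute(),
#     plot_sample_status = True,  # also starts the cumulative saturation monitor
#   )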
def __init__(self, url: str, must_exist: bool = False, is_replica: bool = False):
  if environment.WORLD_RANK == 0 or is_replica:
    super(PreprocessedContentFiles, self).__init__(
      url, Base, must_exist = must_exist
    )
  if environment.WORLD_SIZE > 1 and not is_replica:
    # Set up engine connections to the replicated preprocessed chunks.
    self.base_path = pathlib.Path(url.replace("sqlite:///", "")).resolve().parent
    hash_id = self.base_path.name
    try:
      tdir = pathlib.Path(FLAGS.local_filesystem).resolve() / hash_id / "node_preprocessed"
    except Exception:
      tdir = pathlib.Path("/tmp").resolve() / hash_id / "node_preprocessed"
    distrib.lock()
    tdir.mkdir(parents = True, exist_ok = True)
    distrib.unlock()
    self.replicated_path = tdir / "preprocessed_{}.db".format(environment.WORLD_RANK)
    self.replicated = PreprocessedContentFiles(
      url        = "sqlite:///{}".format(str(self.replicated_path)),
      must_exist = must_exist,
      is_replica = True,
    )
    distrib.barrier()
  return
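# Usage sketch (illustrative, not part of the original source), mirroring the
# EncodedContentFiles pattern above: rank 0 (or a replica) connects directly,
# while distributed ranks each get preprocessed_<rank>.db under either
# FLAGS.local_filesystem or /tmp:
#
#   db = PreprocessedContentFiles("sqlite:///corpus/preprocessed.db")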
def __init__(self, config: model_pb2.Model):
  """Instantiate a model.

  Args:
    config: A Model message.

  Raises:
    TypeError: If the config argument is not a Model proto.
    UserError: In case of an invalid config.
  """
  # Error early, so that a cache isn't created.
  if not isinstance(config, model_pb2.Model):
    t = type(config).__name__
    raise TypeError(f"Config must be a Model proto. Received: '{t}'")

  self.config = model_pb2.Model()
  # Validate config options.
  self.config.CopyFrom(builders.AssertIsBuildable(config))
  if FLAGS.num_train_steps:
    self.config.training.num_train_steps = FLAGS.num_train_steps
  if FLAGS.num_pretrain_steps:
    self.config.training.num_pretrain_steps = FLAGS.num_pretrain_steps
  if FLAGS.num_epochs:
    self.config.training.num_epochs = FLAGS.num_epochs

  # Initialize distrib lock path.
  if environment.WORLD_SIZE > 1:
    if environment.WORLD_RANK == 0:
      lock_cache = cache.mkcache("locks")
      lock_cache.path.mkdir(exist_ok = True)
    else:
      while not cache.cachepath("locks").exists():
        time.sleep(0.5)
      lock_cache = cache.mkcache("locks")
    distrib.init(lock_cache.path)

  # Initialize corpuses
  self.corpus = corpuses.Corpus(config.corpus)
  self.pre_train_corpus = None
  if config.HasField("pre_train_corpus"):
    self.pre_train_corpus = corpuses.Corpus(config.pre_train_corpus)

  self.hash = self._ComputeHash(self.pre_train_corpus, self.corpus, self.config)
  self._created = False

  distrib.lock()
  self.cache = cache.mkcache("model", self.hash)
  distrib.unlock()

  if environment.WORLD_RANK == 0:
    # Create the necessary cache directories.
    (self.cache.path / "checkpoints").mkdir(exist_ok = True)
    (self.cache.path / "samples").mkdir(exist_ok = True)
    # Create symlink to encoded corpus.
    symlink = self.cache.path / "corpus"
    if not symlink.is_symlink():
      os.symlink(
        os.path.relpath(
          pathlib.Path(self.corpus.encoded.url[len("sqlite:///"):]).parent,
          self.cache.path,
        ),
        symlink,
      )
    if self.pre_train_corpus:
      symlink = self.cache.path / "pre_train_corpus"
      if not symlink.is_symlink():
        os.symlink(
          os.path.relpath(
            pathlib.Path(self.pre_train_corpus.encoded.url[len("sqlite:///"):]).parent,
            self.cache.path,
          ),
          symlink,
        )
    # Create symlink to the tokenizer and create a backup inside checkpoints.
    symlink = self.cache.path / "tokenizer"
    if not symlink.is_symlink():
      os.symlink(
        os.path.relpath(self.corpus.tokenizer_path, self.cache.path),
        symlink,
      )
    if (self.cache.path / "checkpoints" / "backup_tokenizer.pkl").exists():
      shutil.copyfile(
        self.cache.path / "checkpoints" / "backup_tokenizer.pkl",
        self.corpus.tokenizer_path,
      )

  # Validate metadata against cache.
  if self.cache.get("META.pbtxt"):
    cached_meta = pbutil.FromFile(
      pathlib.Path(self.cache["META.pbtxt"]), internal_pb2.ModelMeta()
    )
    # Exclude num_epochs and corpus location from metadata comparison.
    config_to_compare = model_pb2.Model()
    config_to_compare.CopyFrom(self.config)
    config_to_compare.corpus.ClearField("contentfiles")
    if config_to_compare.HasField("pre_train_corpus"):
      config_to_compare.pre_train_corpus.ClearField("contentfiles")
    config_to_compare.training.ClearField("num_epochs")
    config_to_compare.training.ClearField("num_train_steps")
    if config_to_compare.HasField("pre_train_corpus"):
      config_to_compare.training.ClearField("num_pretrain_steps")
    config_to_compare.training.ClearField("batch_size")
    if config_to_compare.training.HasField("data_generator"):
      config_to_compare.training.data_generator.ClearField("steps_per_epoch")
      config_to_compare.training.data_generator.ClearField("validation_set")
    # These fields should have already been cleared, but we'll do it again
    # so that metadata comparisons don't fail when the cached meta schema
    # is updated.
    cached_to_compare = model_pb2.Model()
    cached_to_compare.CopyFrom(cached_meta.config)
    cached_to_compare.corpus.ClearField("contentfiles")
    if cached_to_compare.HasField("pre_train_corpus"):
      cached_to_compare.pre_train_corpus.ClearField("contentfiles")
    cached_to_compare.training.ClearField("num_epochs")
    cached_to_compare.training.ClearField("num_train_steps")
    if cached_to_compare.HasField("pre_train_corpus"):
      cached_to_compare.training.ClearField("num_pretrain_steps")
    cached_to_compare.training.ClearField("batch_size")
    if cached_to_compare.training.HasField("data_generator"):
      cached_to_compare.training.data_generator.ClearField("steps_per_epoch")
      cached_to_compare.training.data_generator.ClearField("validation_set")
    if cached_to_compare.training.sequence_length != config_to_compare.training.sequence_length:
      l.logger().warning(
        "Mismatch between pre-trained and current config sequence_length! "
        "This is only intended for BERT models."
      )
      cached_to_compare.training.ClearField("sequence_length")
      config_to_compare.training.ClearField("sequence_length")
    if config_to_compare != cached_to_compare:
      raise SystemError("Metadata mismatch: {} \n\n {}".format(config_to_compare, cached_to_compare))
    self.meta = cached_meta
  else:
    self.meta = internal_pb2.ModelMeta()
    self.meta.config.CopyFrom(self.config)
    self._WriteMetafile()

  ## Store current commit
  commit.saveCommit(self.cache.path)

  self.backend = {
    model_pb2.NetworkArchitecture.TENSORFLOW_SEQ : tf_sequential.tfSequential,
    model_pb2.NetworkArchitecture.KERAS_SEQ      : keras_sequential.kerasSequential,
    model_pb2.NetworkArchitecture.TENSORFLOW_BERT: tf_bert.tfBert,
    model_pb2.NetworkArchitecture.TORCH_BERT     : torch_bert.torchBert,
  }[config.architecture.backend](self.config, self.cache, self.hash)
  l.logger().info("Initialized {} in {}".format(self.backend, self.cache.path))
  return
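# Usage sketch (illustrative, not part of the original source). The backend is
# dispatched from the architecture enum at the end of __init__, so a config
# with backend TORCH_BERT yields a torch_bert.torchBert instance:
#
#   config = pbutil.FromFile(pathlib.Path("model.pbtxt"), model_pb2.Model())
#   model  = Model(config)  # validates, caches under its hash, builds the backend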
def __init__(self, config: sampler_pb2.Sampler, sample_db_name = "samples.db"):
  """Instantiate a sampler.

  Args:
    config: A Sampler message.

  Raises:
    TypeError: If the config argument is not a Sampler proto.
    UserError: If the config contains invalid values.
  """
  if not isinstance(config, sampler_pb2.Sampler):
    t = type(config).__name__
    raise TypeError(f"Config must be a Sampler proto. Received: '{t}'")

  self.config = sampler_pb2.Sampler()
  self.config.CopyFrom(AssertConfigIsValid(config))
  self.hash = self._ComputeHash(self.config)
  self.terminators = GetTerminationCriteria(self.config.termination_criteria)
  if config.HasField("start_text"):
    self.start_text = self.config.start_text
  else:
    self.start_text = ""
  self.temperature = self.config.temperature_micros / 1e6
  self.batch_size = self.config.batch_size
  self.sequence_length = self.config.sequence_length
  self.sample_db_name = sample_db_name

  # Create the necessary cache directories.
  distrib.lock()
  self.cache = cache.mkcache("sampler", self.hash)
  distrib.unlock()
  self.samples_directory = self.cache.path / "samples"
  if environment.WORLD_RANK == 0:
    self.samples_directory.mkdir(exist_ok = True)
  self.corpus_directory = None
  self.sample_corpus = None
  if self.config.HasField("sample_corpus"):
    self.corpus_directory = self.cache.path / "sample_corpus"
    if environment.WORLD_RANK == 0:
      self.corpus_directory.mkdir(exist_ok = True)
    if self.config.sample_corpus.HasField("corpus"):
      self.sample_corpus = corpuses.Corpus(self.config.sample_corpus.corpus)
      self.sample_corpus.Create()
      self.symlinkSampleCorpus(
        pathlib.Path(self.sample_corpus.encoded.url[len("sqlite:///"):]).parent
      )
      text_data = [
        self.sample_corpus.tokenizer.tokensToString(x)
        for x in self.sample_corpus.GetTrainingData()
      ]
    else:
      self.start_text = self.config.sample_corpus.start_text
      text_data = [self.start_text]
    # Text data is dumped in order to specialize with all different model tokenizers.
    if environment.WORLD_RANK == 0:
      with open(self.cache.path / "sample_corpus" / "text_corpus.pkl", 'wb') as outf:
        pickle.dump(text_data, outf)

  if self.has_active_learning:
    self.active_learner = active_models.Model(config.sample_corpus.corpus_config.active.active_learner)

  if environment.WORLD_RANK == 0:
    meta = internal_pb2.SamplerMeta()
    meta.config.CopyFrom(self.config)
    pbutil.ToFile(meta, path = self.cache.path / "META.pbtxt")
    commit.saveCommit(self.cache.path)

  # Set in Specialize().
  self.encoded_start_text = None
  self.tokenized_start_text = None
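# Usage sketch (illustrative, not part of the original source):
#
#   config  = pbutil.FromFile(pathlib.Path("sampler.pbtxt"), sampler_pb2.Sampler())
#   sampler = Sampler(config)
#   # temperature_micros is scaled to a float temperature on construction:
#   assert sampler.temperature == config.temperature_micros / 1e6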