def _ComputeHash(config: sampler_pb2.Sampler) -> str:
    """Return the hash that identifies a sampler configuration.

    The config proto is serialized with SerializeToString() and the
    resulting bytes are handed to crypto.sha1().

    Args:
      config: A Sampler message.

    Returns:
      The value returned by crypto.sha1() for the serialized config.
    """
    serialized_config = config.SerializeToString()
    return crypto.sha1(serialized_config)
def test_Sampler_batch_size(abc_sampler_config: sampler_pb2.Sampler):
    """Check that Sampler.batch_size mirrors the proto's batch_size field."""
    expected_batch_size = 99
    abc_sampler_config.batch_size = expected_batch_size
    sampler = samplers.Sampler(abc_sampler_config)
    assert expected_batch_size == sampler.batch_size
def _AssertActiveSamplingIsValid(config: sampler_pb2.Sampler) -> None:
    """Validate the `active` sampling options under sample_corpus.corpus_config."""
    active = config.sample_corpus.corpus_config.active
    for field_name in (
        "active_limit_per_feed",
        "active_search_depth",
        "active_search_width",
    ):
        pbutil.AssertFieldIsSet(active, field_name)
    pbutil.AssertFieldConstraint(
        active,
        "batch_size_per_feed",
        # Reject non-positive values up front so that a zero is reported as a
        # constraint violation instead of escaping as a ZeroDivisionError.
        lambda x: x > 0 and config.batch_size % x == 0,
        "batch_size {} must be a multiple of batch_size_per_feed".format(
            config.batch_size
        ),
    )
    pbutil.AssertFieldConstraint(
        active,
        "feature_space",
        lambda x: x in extractor.extractors.keys(),
        "feature_space can only be one of {}".format(
            ', '.join(extractor.extractors.keys())
        ),
    )
    if active.HasField("target"):
        pbutil.AssertFieldConstraint(
            active,
            "target",
            lambda x: x in feature_sampler.targets.keys(),
            "target can only be one of {}".format(
                ', '.join(feature_sampler.targets.keys())
            ),
        )
    elif active.HasField("active_learner"):
        active_models.AssertConfigIsValid(active.active_learner)
    else:
        # Neither a feature target nor an active learner has been provided.
        raise ValueError(active)


def _AssertMaskLengthIsValid(mask_config, kind: str) -> None:
    """Validate the length and length-distribution settings of a hole / mask_seq.

    Args:
      mask_config: The `hole` or `mask_seq` sub-message of corpus_config.
      kind: Name of the sub-message ("hole" or "mask_seq"), used in messages.

    Raises:
      ValueError: If no length distribution has been set.
    """
    if mask_config.HasField("absolute_length"):
        pbutil.AssertFieldConstraint(
            mask_config,
            "absolute_length",
            lambda x: x > 0,
            "absolute length is the upper bound range of a {}'s length. "
            "Therefore should be > 0.".format(kind),
        )
    else:
        pbutil.AssertFieldConstraint(
            mask_config,
            "relative_length",
            lambda x: 0.0 < x <= 1.0,
            "relative length must be between 0 and 100% of a kernel's actual length.",
        )
    if mask_config.HasField("normal_distribution"):
        pbutil.AssertFieldIsSet(mask_config.normal_distribution, "mean")
        pbutil.AssertFieldIsSet(mask_config.normal_distribution, "variance")
    elif not mask_config.HasField("uniform_distribution"):
        raise ValueError(
            "{} length distribution has not been set.".format(kind.capitalize())
        )


def AssertConfigIsValid(config: sampler_pb2.Sampler) -> sampler_pb2.Sampler:
    """Assert that a sampler configuration contains no invalid values.

    The sampler input is validated according to which of `start_text` or
    `sample_corpus` is set; if neither is set, at least one of `train_set`,
    `validation_set`, `sample_set` or `live_sampling` must be. After that,
    `batch_size`, `sequence_length` and `temperature_micros` are always
    required to be > 0.

    Args:
      config: A sampler configuration proto.

    Returns:
      The sampler configuration proto.

    Raises:
      ValueError: If there are configuration errors. A pbutil.ProtoValueError
        raised during validation is re-raised as a ValueError, with the
        original exception chained as its cause.
    """
    try:
        if config.HasField("start_text"):
            pbutil.AssertFieldConstraint(
                config,
                "start_text",
                lambda s: len(s),
                "Sampler.start_text must be a string",
            )
        elif config.HasField("sample_corpus"):
            sample_corpus = config.sample_corpus
            if not sample_corpus.HasField("corpus_config"):
                raise ValueError("sample_corpus has no corpus_config field.")
            corpus_config = sample_corpus.corpus_config

            # Sampling type: exactly one of normal / online / active is expected.
            if corpus_config.HasField("normal"):
                pbutil.AssertFieldIsSet(corpus_config, "normal")
            elif corpus_config.HasField("online"):
                pbutil.AssertFieldIsSet(corpus_config, "online")
            elif corpus_config.HasField("active"):
                _AssertActiveSamplingIsValid(config)
            else:
                raise ValueError(
                    "Sampling type is undefined: {}".format(corpus_config)
                )

            pbutil.AssertFieldIsSet(corpus_config, "max_predictions_per_seq")
            pbutil.AssertFieldIsSet(corpus_config, "masked_lm_prob")
            pbutil.AssertFieldIsSet(corpus_config, "mask_technique")

            # Masking technique: mask, hole or mask_seq.
            if corpus_config.HasField("mask"):
                pbutil.AssertFieldIsSet(corpus_config.mask, "random_placed_mask")
            elif corpus_config.HasField("hole"):
                _AssertMaskLengthIsValid(corpus_config.hole, "hole")
            elif corpus_config.HasField("mask_seq"):
                _AssertMaskLengthIsValid(corpus_config.mask_seq, "mask_seq")

            # The sample corpus is either a full corpus or a single start text.
            if sample_corpus.HasField("corpus"):
                corpuses.AssertConfigIsValid(sample_corpus.corpus)
            else:
                pbutil.AssertFieldIsSet(sample_corpus, "start_text")
        elif not (
            config.HasField("train_set")
            or config.HasField("validation_set")
            or config.HasField("sample_set")
            or config.HasField("live_sampling")
        ):
            # No sampler input of any kind has been specified.
            raise ValueError(config)

        pbutil.AssertFieldConstraint(
            config, "batch_size", lambda x: 0 < x, "Sampler.batch_size must be > 0"
        )
        pbutil.AssertFieldConstraint(
            config,
            "sequence_length",
            lambda x: 0 < x,
            "Sampler.sequence_length must be > 0",
        )
        pbutil.AssertFieldConstraint(
            config,
            "temperature_micros",
            lambda x: 0 < x,
            "Sampler.temperature_micros must be > 0",
        )
        return config
    except pbutil.ProtoValueError as e:
        # Keep the original proto error attached as the cause for debugging.
        raise ValueError(e) from e
def test_Sampler_temperature(abc_sampler_config: sampler_pb2.Sampler):
    """Check that Sampler.temperature is derived from temperature_micros.

    A temperature_micros of 1000000 is expected to yield a temperature of 1.0.
    """
    one_in_micros = 1000000
    abc_sampler_config.temperature_micros = one_in_micros
    sampler = samplers.Sampler(abc_sampler_config)
    assert pytest.approx(1.0) == sampler.temperature
def __init__(self, config: sampler_pb2.Sampler, sample_db_name = "samples.db"):
    """Instantiate a sampler.

    Validates and copies the config, derives the sampler hash, creates the
    sampler cache (and, when a sample corpus is configured, its directory
    and a pickled text dump of the corpus), and writes the META.pbtxt file.
    Filesystem writes are only performed when environment.WORLD_RANK == 0.

    Args:
      config: A Sampler message.
      sample_db_name: Stored as self.sample_db_name. Defaults to "samples.db".

    Raises:
      TypeError: If the config argument is not a Sampler proto.
      ValueError: If the config contains invalid values (raised by
        AssertConfigIsValid).
    """
    if not isinstance(config, sampler_pb2.Sampler):
        t = type(config).__name__
        raise TypeError(f"Config must be a Sampler proto. Received: '{t}'")

    # Work on a private copy so later changes to the caller's proto do not
    # affect this sampler.
    self.config = sampler_pb2.Sampler()
    self.config.CopyFrom(AssertConfigIsValid(config))
    self.hash = self._ComputeHash(self.config)
    self.terminators = GetTerminationCriteria(self.config.termination_criteria)
    if self.config.HasField("start_text"):
        self.start_text = self.config.start_text
    else:
        self.start_text = ""

    self.temperature = self.config.temperature_micros / 1e6
    self.batch_size = self.config.batch_size
    self.sequence_length = self.config.sequence_length
    self.sample_db_name = sample_db_name

    # Create the necessary cache directories. The lock is always released,
    # even if cache creation fails, so an error here cannot leave it held.
    distrib.lock()
    try:
        self.cache = cache.mkcache("sampler", self.hash)
    finally:
        distrib.unlock()
    self.samples_directory = self.cache.path / "samples"
    if environment.WORLD_RANK == 0:
        self.samples_directory.mkdir(exist_ok = True)

    self.corpus_directory = None
    self.sample_corpus = None
    if self.config.HasField("sample_corpus"):
        self.corpus_directory = self.cache.path / "sample_corpus"
        if environment.WORLD_RANK == 0:
            self.corpus_directory.mkdir(exist_ok = True)
        if self.config.sample_corpus.HasField("corpus"):
            self.sample_corpus = corpuses.Corpus(self.config.sample_corpus.corpus)
            self.sample_corpus.Create()
            # Strip the "sqlite:///" scheme to get the encoded database's
            # filesystem path, then link its parent directory.
            self.symlinkSampleCorpus(
                pathlib.Path(self.sample_corpus.encoded.url[len("sqlite:///") :]).parent
            )
            text_data = [
                self.sample_corpus.tokenizer.tokensToString(x)
                for x in self.sample_corpus.GetTrainingData()
            ]
        else:
            self.start_text = self.config.sample_corpus.start_text
            text_data = [self.start_text]
        # Text data is dumped in order to specialize with all different model tokenizers.
        if environment.WORLD_RANK == 0:
            with open(self.corpus_directory / "text_corpus.pkl", 'wb') as outf:
                pickle.dump(text_data, outf)

    if self.has_active_learning:
        self.active_learner = active_models.Model(
            self.config.sample_corpus.corpus_config.active.active_learner
        )

    if environment.WORLD_RANK == 0:
        meta = internal_pb2.SamplerMeta()
        meta.config.CopyFrom(self.config)
        pbutil.ToFile(meta, path = self.cache.path / "META.pbtxt")
        commit.saveCommit(self.cache.path)

    # Set in Specialize().
    self.encoded_start_text = None
    self.tokenized_start_text = None