Example #1
0
def AssertConfigIsValid(config: sampler_pb2.Sampler) -> sampler_pb2.Sampler:
  """Assert that a sampler configuration contains no invalid values.

  The sampling source is validated first, in this priority order:
  ``start_text``, then ``sample_corpus``. If neither is set, at least one of
  ``train_set``, ``validation_set``, ``sample_set`` or ``live_sampling`` must
  be set. Afterwards ``batch_size``, ``sequence_length`` and
  ``temperature_micros`` are checked to be strictly positive.

  Args:
    config: A sampler configuration proto.

  Returns:
    The same sampler configuration proto object that was passed in.

  Raises:
    ValueError: If there are configuration errors. Errors reported by
      ``pbutil`` as ``pbutil.ProtoValueError`` are re-raised as
      ``ValueError`` with the original error chained as the cause.
  """
  try:
    if config.HasField("start_text"):
      pbutil.AssertFieldConstraint(
        config,
        "start_text",
        # len(s) is falsy only for "", so this presumably rejects an empty
        # start_text.
        lambda s: len(s),
        "Sampler.start_text must be a string",
      )
    elif config.HasField("sample_corpus"):
      _AssertSampleCorpusIsValid(config)
    elif not any(
      config.HasField(field)
      for field in ("train_set", "validation_set", "sample_set", "live_sampling")
    ):
      raise ValueError(config)
    pbutil.AssertFieldConstraint(
      config, "batch_size", lambda x: 0 < x, "Sampler.batch_size must be > 0"
    )
    pbutil.AssertFieldConstraint(
      config,
      "sequence_length",
      lambda x: 0 < x,
      "Sampler.sequence_length must be > 0",
    )
    pbutil.AssertFieldConstraint(
      config,
      "temperature_micros",
      lambda x: 0 < x,
      "Sampler.temperature_micros must be > 0",
    )
    return config
  except pbutil.ProtoValueError as e:
    # Chain the cause so the original pbutil traceback is preserved.
    raise ValueError(e) from e


def _AssertSampleCorpusIsValid(config) -> None:
  """Validate ``config.sample_corpus``: its corpus_config, then its text source."""
  sample_corpus = config.sample_corpus
  if not sample_corpus.HasField("corpus_config"):
    raise ValueError("sample_corpus has no corpus_config field.")
  _AssertCorpusConfigIsValid(sample_corpus.corpus_config, config.batch_size)

  if sample_corpus.HasField("corpus"):
    corpuses.AssertConfigIsValid(sample_corpus.corpus)
  else:
    # Without a corpus, the sample corpus is built from start_text.
    pbutil.AssertFieldIsSet(sample_corpus, "start_text")


def _AssertCorpusConfigIsValid(corpus_config, batch_size: int) -> None:
  """Validate the sampling type and masking technique of a corpus_config."""
  if corpus_config.HasField("normal"):
    pbutil.AssertFieldIsSet(corpus_config, "normal")
  elif corpus_config.HasField("online"):
    pbutil.AssertFieldIsSet(corpus_config, "online")
  elif corpus_config.HasField("active"):
    _AssertActiveConfigIsValid(corpus_config.active, batch_size)
  else:
    raise ValueError("Sampling type is undefined: {}".format(corpus_config))

  pbutil.AssertFieldIsSet(corpus_config, "max_predictions_per_seq")
  pbutil.AssertFieldIsSet(corpus_config, "masked_lm_prob")

  pbutil.AssertFieldIsSet(corpus_config, "mask_technique")
  # NOTE(review): no else branch below; presumably the mask_technique check
  # above rejects configs with none of mask/hole/mask_seq set -- confirm.
  if corpus_config.HasField("mask"):
    pbutil.AssertFieldIsSet(corpus_config.mask, "random_placed_mask")
  elif corpus_config.HasField("hole"):
    _AssertLengthSpecIsValid(corpus_config.hole, "hole")
  elif corpus_config.HasField("mask_seq"):
    _AssertLengthSpecIsValid(corpus_config.mask_seq, "mask_seq")


def _AssertActiveConfigIsValid(active, batch_size: int) -> None:
  """Validate an active-sampling config against the sampler's batch_size."""
  pbutil.AssertFieldIsSet(active, "active_limit_per_feed")
  pbutil.AssertFieldIsSet(active, "active_search_depth")
  pbutil.AssertFieldIsSet(active, "active_search_width")
  pbutil.AssertFieldConstraint(
    active,
    "batch_size_per_feed",
    # The x > 0 guard reports 0 as a config error instead of letting a
    # ZeroDivisionError escape from the modulo.
    lambda x: x > 0 and batch_size % x == 0,
    "batch_size {} must be a multiple of batch_size_per_feed".format(batch_size),
  )
  feature_spaces = set(extractor.extractors.keys())
  pbutil.AssertFieldConstraint(
    active,
    "feature_space",
    lambda x: x in feature_spaces,
    "feature_space can only be one of {}".format(', '.join(list(extractor.extractors.keys())))
  )
  if active.HasField("target"):
    targets = set(feature_sampler.targets.keys())
    pbutil.AssertFieldConstraint(
      active,
      "target",
      lambda x: x in targets,
      "target can only be one of {}".format(', '.join(list(feature_sampler.targets.keys())))
    )
  elif active.HasField("active_learner"):
    active_models.AssertConfigIsValid(active.active_learner)
  else:
    raise ValueError(
      "corpus_config.active requires either 'target' or 'active_learner': {}".format(active)
    )


def _AssertLengthSpecIsValid(spec, label: str) -> None:
  """Validate the length bound and length distribution of a hole/mask_seq spec.

  Args:
    spec: The ``hole`` or ``mask_seq`` sub-message of a corpus_config.
    label: ``"hole"`` or ``"mask_seq"``; used in error messages.
  """
  if spec.HasField("absolute_length"):
    pbutil.AssertFieldConstraint(
      spec,
      "absolute_length",
      lambda x: x > 0,
      "absolute length is the upper bound range of a {}'s length. Therefore should be > 0.".format(label)
    )
  else:
    pbutil.AssertFieldConstraint(
      spec,
      "relative_length",
      lambda x: 0.0 < x <= 1.0,
      "relative length must be between 0 and 100% of a kernel's actual length."
    )
  if spec.HasField("normal_distribution"):
    pbutil.AssertFieldIsSet(spec.normal_distribution, "mean")
    pbutil.AssertFieldIsSet(spec.normal_distribution, "variance")
  elif not spec.HasField("uniform_distribution"):
    raise ValueError("{} length distribution has not been set.".format(label.capitalize()))
Example #2
0
  def __init__(self, config: sampler_pb2.Sampler, sample_db_name = "samples.db"):
    """Instantiate a sampler.

    The constructor:
      - validates and copies ``config``;
      - derives the sampling parameters from the copy;
      - creates the cache directory, plus the sample corpus directory and a
        pickled dump of its text when ``sample_corpus`` is set;
      - on rank 0, writes META.pbtxt and saves the commit.

    Args:
      config: A Sampler message.
      sample_db_name: Name of the samples database (default "samples.db");
        only stored as ``self.sample_db_name`` here.

    Raises:
      TypeError: If the config argument is not a Sampler proto.
      ValueError: If the config contains invalid values (raised by
        AssertConfigIsValid).
    """
    if not isinstance(config, sampler_pb2.Sampler):
      t = type(config).__name__
      raise TypeError(f"Config must be a Sampler proto. Received: '{t}'")
    self.config = sampler_pb2.Sampler()
    self.config.CopyFrom(AssertConfigIsValid(config))
    self.hash = self._ComputeHash(self.config)
    self.terminators = GetTerminationCriteria(self.config.termination_criteria)
    if self.config.HasField("start_text"):
      self.start_text = self.config.start_text
    else:
      self.start_text = ""

    self.temperature = self.config.temperature_micros / 1e6
    self.batch_size = self.config.batch_size
    self.sequence_length = self.config.sequence_length
    self.sample_db_name = sample_db_name

    # Create the necessary cache directories. The unlock is in a `finally`
    # so an exception from mkcache cannot leave the distributed lock held.
    distrib.lock()
    try:
      self.cache = cache.mkcache("sampler", self.hash)
    finally:
      distrib.unlock()
    self.samples_directory = self.cache.path / "samples"
    if environment.WORLD_RANK == 0:
      self.samples_directory.mkdir(exist_ok = True)
    self.corpus_directory = None
    self.sample_corpus    = None
    if self.config.HasField("sample_corpus"):
      self.corpus_directory = self.cache.path / "sample_corpus"
      if environment.WORLD_RANK == 0:
        self.corpus_directory.mkdir(exist_ok = True)
      if self.config.sample_corpus.HasField("corpus"):
        self.sample_corpus = corpuses.Corpus(self.config.sample_corpus.corpus)
        self.sample_corpus.Create()
        # NOTE(review): assumes encoded.url always starts with "sqlite:///".
        self.symlinkSampleCorpus(
          pathlib.Path(self.sample_corpus.encoded.url[len("sqlite:///") :]).parent
        )
        text_data = [
          self.sample_corpus.tokenizer.tokensToString(x) for x in self.sample_corpus.GetTrainingData()
        ]
      else:
        self.start_text = self.config.sample_corpus.start_text
        text_data = [self.start_text]
      # Text data is dumped in order to specialize with all different model tokenizers.
      if environment.WORLD_RANK == 0:
        with open(self.corpus_directory / "text_corpus.pkl", 'wb') as outf:
          pickle.dump(text_data, outf)

    if self.has_active_learning:
      self.active_learner = active_models.Model(
        self.config.sample_corpus.corpus_config.active.active_learner
      )

    if environment.WORLD_RANK == 0:
      meta = internal_pb2.SamplerMeta()
      meta.config.CopyFrom(self.config)
      pbutil.ToFile(meta, path = self.cache.path / "META.pbtxt")
      commit.saveCommit(self.cache.path)

    # Set in Specialize().
    self.encoded_start_text = None
    self.tokenized_start_text = None