Exemplo n.º 1
0
def AssertConfigIsValid(config: typing.Union[corpus_pb2.Corpus, corpus_pb2.PreTrainCorpus]
                       ) -> typing.Union[corpus_pb2.Corpus, corpus_pb2.PreTrainCorpus]:
  """Assert that a corpus (or pre-train corpus) config proto is valid.

  Configs that set `pre_encoded_corpus_url` are returned immediately with no
  further checks. Otherwise `contentfiles` and `contentfile_separator` must be
  set and every entry of `preprocessor` must resolve to a preprocessor
  function. For a `corpus_pb2.Corpus`, a tokenizer with `token_type` one of
  "character", "word" or "ast" is required; "word" tokenizers must also point
  `token_list` at an existing file (resolved with ExpandConfigPath and
  FLAGS.clgen_local_path_prefix). Any other config must not carry its own
  tokenizer.

  Args:
    config: A Corpus or PreTrainCorpus proto.

  Returns:
    The input config, unmodified.

  Raises:
    pbutil.ProtoValueError: If a required field is unset or violates its
      constraint (raised by the pbutil assertion helpers).
    ValueError: If a non-Corpus (pre-train) config sets `tokenizer`.
    Whatever preprocessors.GetPreprocessorFunction raises for an unknown
    preprocessor name.
  """
  # Early-exit to support corpuses derived from databases of pre-encoded
  # content files.
  # TODO(github.com/ChrisCummins/clgen/issues/130): Refactor after splitting
  # Corpus class.
  if config.HasField("pre_encoded_corpus_url"):
    return config

  pbutil.AssertFieldIsSet(config, "contentfiles")
  if isinstance(config, corpus_pb2.Corpus):
    pbutil.AssertFieldIsSet(config, "tokenizer")
    pbutil.AssertFieldIsSet(config.tokenizer, "token_type")
    pbutil.AssertFieldConstraint(
      config.tokenizer,
      "token_type",
      lambda x: x in ("character", "word", "ast"),
      "tokenizer is either character, word or ast based.",
    )
    if config.tokenizer.token_type == "word":
      pbutil.AssertFieldConstraint(
        config.tokenizer,
        "token_list",
        lambda x: os.path.isfile(
          str(ExpandConfigPath(x, path_prefix=FLAGS.clgen_local_path_prefix))
        ),
        "Invalid token_list file",
      )
  elif config.HasField("tokenizer"):
    # Only a full Corpus may define its own tokenizer.
    raise ValueError("Pre-train corpus cannot have a distinct tokenizer.")

  pbutil.AssertFieldIsSet(config, "contentfile_separator")
  # The lookup itself is the check: an unknown name fails here. The returned
  # functions are not needed.
  for preprocessor_name in config.preprocessor:
    preprocessors.GetPreprocessorFunction(preprocessor_name)

  return config
Exemplo n.º 2
0
 def __init__(self, config: sampler_pb2.SymmetricalTokenDepth):
   """Read and validate the depth-increase / depth-decrease token pair.

   Stores the validated tokens as `self.left_token` (depth increase) and
   `self.right_token` (depth decrease).

   Args:
     config: A SymmetricalTokenDepth proto.

   Raises:
     ValueError: If either token fails its non-empty check, or if both
       tokens are identical.
   """
   try:
     self.left_token = pbutil.AssertFieldConstraint(
       config,
       "depth_increase_token",
       lambda s: len(s) > 0,
       "SymmetricalTokenDepth.depth_increase_token must be a string",
     )
     self.right_token = pbutil.AssertFieldConstraint(
       config,
       "depth_decrease_token",
       lambda s: len(s) > 0,
       "SymmetricalTokenDepth.depth_decrease_token must be a string",
     )
   except pbutil.ProtoValueError as e:
     # Surface as ValueError while keeping the proto error as explicit cause.
     raise ValueError(e) from e
   if self.left_token == self.right_token:
     raise ValueError("SymmetricalTokenDepth tokens must be different")
Exemplo n.º 3
0
 def __init__(self, config: sampler_pb2.MaxTokenLength):
   """Read and validate the maximum sample length.

   Stores the validated value as `self.max_len`.

   Args:
     config: A MaxTokenLength proto.

   Raises:
     ValueError: If `maximum_tokens_in_sample` is unset or not > 1.
   """
   try:
     self.max_len = pbutil.AssertFieldConstraint(
       config,
       "maximum_tokens_in_sample",
       lambda x: x > 1,
       # Message now matches the enforced check (previously claimed "> 0").
       "MaxTokenLength.maximum_tokens_in_sample must be > 1",
     )
   except pbutil.ProtoValueError as e:
     # Surface as ValueError while keeping the proto error as explicit cause.
     raise ValueError(e) from e
Exemplo n.º 4
0
    def FromConfig(cls, config: github_pb2.GithubMiner):
        """Constructs a github miner from its protobuf configuration.

        `path`, `data_format` and `miner` must be set. The backend is chosen
        by whether `big_query` or `recursive` is set on the config.

        Args:
          config: A GithubMiner proto.

        Returns:
          A BigQuery miner if `big_query` is set, otherwise a RecursiveFetcher
          if `recursive` is set.

        Raises:
          NotImplementedError: If the recursive miner is configured with a
            data_format other than `folder`.
          SystemError: If neither `big_query` nor `recursive` is set.
          Any error raised by the pbutil field assertions.
        """
        pbutil.AssertFieldIsSet(config, "path")
        pbutil.AssertFieldIsSet(config, "data_format")
        pbutil.AssertFieldIsSet(config, "miner")

        if config.HasField("big_query"):
            supported_languages = frozenset(
                {'generic', 'opencl', 'c', 'cpp', 'java', 'python'})
            pbutil.AssertFieldIsSet(config.big_query, "credentials")
            pbutil.AssertFieldConstraint(
                config.big_query,
                "language",
                lambda x: x in supported_languages,
                "language must be one of opencl, c, cpp, java, python. 'generic' for language agnostic queries.",
            )
            if config.big_query.HasField("export_corpus"):
                pbutil.AssertFieldIsSet(config.big_query.export_corpus,
                                        "data_format")
                pbutil.AssertFieldIsSet(config.big_query.export_corpus,
                                        "access_token")
            return BigQuery(config)

        if config.HasField("recursive"):
            pbutil.AssertFieldIsSet(config.recursive, "access_token")
            pbutil.AssertFieldConstraint(
                config.recursive, "flush_limit_K", lambda x: x > 0,
                "flush limit cannot be non-positive.")
            pbutil.AssertFieldConstraint(
                config.recursive, "corpus_size_K", lambda x: x >= -1,
                "corpus size must either be -1 or non-negative.")
            if config.data_format != github_pb2.GithubMiner.DataFormat.folder:
                raise NotImplementedError(
                    "RecursiveFetcher only stores files in local folder.")
            return RecursiveFetcher(config)

        # NOTE(review): SystemError is meant for interpreter-internal errors;
        # ValueError would fit better, but callers may already catch this type.
        raise SystemError("{} miner not recognized".format(config))
Exemplo n.º 5
0
def AssertConfigIsValid(
    config: active_learning_pb2.ActiveLearner
) -> active_learning_pb2.ActiveLearner:
    """Check that an ActiveLearner proto is well-formed.

    `downstream_task` must be one of `downstream_tasks.TASKS`, and the
    `committee` field must be set. The config is then handed to
    `com_config.AssertConfigIsValid` for committee-specific checks.

    Args:
      config: An ActiveLearner proto.

    Returns:
      The input config, unmodified.

    Raises:
      NotImplementedError: If `committee` is not set.
    """
    task_listing = ', '.join(str(task) for task in downstream_tasks.TASKS)
    pbutil.AssertFieldConstraint(
        config,
        "downstream_task",
        lambda x: x in downstream_tasks.TASKS,
        "Downstream task has to be one of {}".format(task_listing),
    )
    if not config.HasField("committee"):
        raise NotImplementedError(config)
    # NOTE(review): the whole ActiveLearner message (not config.committee) is
    # forwarded here; confirm this is what com_config expects.
    com_config.AssertConfigIsValid(config)
    return config
Exemplo n.º 6
0
def AssertIsBuildable(config: model_pb2.Model) -> model_pb2.Model:
    """Assert that a model configuration is buildable.

    Validation runs in three stages:
      1. `corpus`, `architecture`, `training` and `architecture.backend` must
         be set.
      2. Backend-specific checks. KERAS_SEQ and TENSORFLOW_SEQ validate their
         own architecture fields. TENSORFLOW_BERT and TORCH_BERT validate the
         data generator plus the transformer architecture and training fields
         (see `_AssertBertConfigIsBuildable`). Any other backend value gets no
         backend-specific checks.
      3. Training options shared by all backends, then the optimizer (see
         `_AssertOptimizerIsValid`).

    Any change to the Model proto schema will require a change to this
    function.

    Args:
      config: A model proto.

    Returns:
      The input model proto, unmodified.

    Raises:
      SystemError: If neither `adam_optimizer` nor `rmsprop_optimizer` is set
        in `config.training`.
      Any error raised by the pbutil field assertions or by
      `lm_data_generator.AssertConfigIsValid` for an unset or out-of-range
      field (documented elsewhere in the project as UserError -- TODO confirm).
    """
    pbutil.AssertFieldIsSet(config, "corpus")
    pbutil.AssertFieldIsSet(config, "architecture")
    pbutil.AssertFieldIsSet(config, "training")
    pbutil.AssertFieldIsSet(config.architecture, "backend")

    architecture = config.architecture
    backend = architecture.backend
    arch_enum = model_pb2.NetworkArchitecture

    if backend == arch_enum.KERAS_SEQ:
        pbutil.AssertFieldIsSet(architecture, "neuron_type")
        pbutil.AssertFieldConstraint(
            architecture,
            "embedding_size",
            lambda x: 0 < x,
            "NetworkArchitecture.embedding_size must be > 0",
        )
    elif backend == arch_enum.TENSORFLOW_SEQ:
        pbutil.AssertFieldIsSet(architecture, "neuron_type")
        pbutil.AssertFieldConstraint(
            architecture,
            "neurons_per_layer",
            lambda x: 0 < x,
            "NetworkArchitecture.neurons_per_layer must be > 0",
        )
        pbutil.AssertFieldConstraint(
            architecture,
            "num_layers",
            lambda x: 0 < x,
            "NetworkArchitecture.num_layers must be > 0",
        )
        pbutil.AssertFieldConstraint(
            architecture,
            "post_layer_dropout_micros",
            lambda x: 0 <= x <= 1000000,
            "NetworkArchitecture.post_layer_dropout_micros "
            "must be >= 0 and <= 1000000",
        )
        pbutil.AssertFieldConstraint(
            config.training,
            "num_epochs",
            lambda x: 0 < x,
            "TrainingOptions.num_epochs must be > 0",
        )
    elif backend in (arch_enum.TENSORFLOW_BERT, arch_enum.TORCH_BERT):
        _AssertBertConfigIsBuildable(config)

    # Training options shared by every backend.
    pbutil.AssertFieldConstraint(
        config.training,
        "sequence_length",
        lambda x: 1 <= x,
        "TrainingOptions.sequence_length must be >= 1",
    )
    pbutil.AssertFieldIsSet(config.training,
                            "shuffle_corpus_contentfiles_between_epochs")
    pbutil.AssertFieldConstraint(
        config.training,
        "batch_size",
        lambda x: 0 < x,
        "TrainingOptions.batch_size must be > 0",
    )
    _AssertOptimizerIsValid(config)
    return config


def _AssertBertConfigIsBuildable(config: model_pb2.Model) -> None:
    """Validate the data generator, architecture and training fields used by BERT backends."""
    architecture = config.architecture
    training = config.training

    # Data generator is needed when using bert.
    pbutil.AssertFieldIsSet(training, "data_generator")
    lm_data_generator.AssertConfigIsValid(training.data_generator)

    ## .architecture params
    for field in ("hidden_size", "num_hidden_layers", "num_attention_heads",
                  "intermediate_size"):
        pbutil.AssertFieldIsSet(architecture, field)
    # A zero head count would make the divisibility check below raise
    # ZeroDivisionError instead of a validation error.
    pbutil.AssertFieldConstraint(
        architecture,
        "num_attention_heads",
        lambda x: 0 < x,
        "NetworkArchitecture.num_attention_heads must be > 0",
    )
    pbutil.AssertFieldConstraint(
        architecture, "hidden_size",
        lambda x: x % architecture.num_attention_heads == 0,
        "The hidden size is not a multiple of the number of attention "
        "heads.")
    for field in ("hidden_act", "hidden_dropout_prob",
                  "attention_probs_dropout_prob", "type_vocab_size",
                  "initializer_range", "layer_norm_eps"):
        pbutil.AssertFieldIsSet(architecture, field)

    ## Optional feature encoder attributes (presumably a bool field).
    if architecture.HasField("feature_encoder") and architecture.feature_encoder:
        for field in ("feature_sequence_length", "feature_embedding_size",
                      "feature_dropout_prob", "feature_singular_token_thr",
                      "feature_max_value_token", "feature_token_range",
                      "feature_num_attention_heads",
                      "feature_transformer_feedforward",
                      "feature_layer_norm_eps", "feature_num_hidden_layers"):
            pbutil.AssertFieldIsSet(architecture, field)

    ## .training params
    for field in ("max_predictions_per_seq", "num_train_steps",
                  "num_warmup_steps"):
        pbutil.AssertFieldIsSet(training, field)
    if config.HasField("pre_train_corpus"):
        for field in ("num_pretrain_steps", "num_prewarmup_steps"):
            pbutil.AssertFieldIsSet(training, field)
    for field in ("dupe_factor", "masked_lm_prob"):
        pbutil.AssertFieldIsSet(training, field)
    pbutil.AssertFieldConstraint(
        training,
        "random_seed",
        lambda x: 0 <= x,
        "TrainingOptions.random_seed must be >= 0",
    )


def _AssertOptimizerIsValid(config: model_pb2.Model) -> None:
    """Validate config.training's optimizer selection and its hyper-parameters.

    Raises:
      SystemError: If neither `adam_optimizer` nor `rmsprop_optimizer` is set.
    """
    training = config.training
    pbutil.AssertFieldIsSet(training, "optimizer")
    if training.HasField("adam_optimizer"):
        adam = training.adam_optimizer
        pbutil.AssertFieldConstraint(
            adam,
            "initial_learning_rate_micros",
            lambda x: 0 <= x,
            "AdamOptimizer.initial_learning_rate_micros must be >= 0",
        )
        # The remaining Adam fields are only validated for the two *_SEQ
        # backends; other backends skip them.
        seq_backends = (model_pb2.NetworkArchitecture.KERAS_SEQ,
                        model_pb2.NetworkArchitecture.TENSORFLOW_SEQ)
        if config.architecture.backend in seq_backends:
            pbutil.AssertFieldConstraint(
                adam,
                "learning_rate_decay_per_epoch_micros",
                lambda x: 0 <= x,
                "AdamOptimizer.learning_rate_decay_per_epoch_micros must be >= 0",
            )
            pbutil.AssertFieldConstraint(
                adam,
                "beta_1_micros",
                lambda x: 0 <= x <= 1000000,
                "AdamOptimizer.beta_1_micros must be >= 0 and <= 1000000",
            )
            pbutil.AssertFieldConstraint(
                adam,
                "beta_2_micros",
                lambda x: 0 <= x <= 1000000,
                "AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000",
            )
            pbutil.AssertFieldConstraint(
                adam,
                "normalized_gradient_clip_micros",
                lambda x: 0 <= x,
                "AdamOptimizer.normalized_gradient_clip_micros must be >= 0",
            )
    elif training.HasField("rmsprop_optimizer"):
        pbutil.AssertFieldConstraint(
            training.rmsprop_optimizer,
            "initial_learning_rate_micros",
            lambda x: 0 <= x,
            "RmsPropOptimizer.initial_learning_rate_micros must be >= 0",
        )
        pbutil.AssertFieldConstraint(
            training.rmsprop_optimizer,
            "learning_rate_decay_per_epoch_micros",
            lambda x: 0 <= x,
            "RmsPropOptimizer.learning_rate_decay_per_epoch_micros must be >= 0",
        )
    else:
        raise SystemError(
            "Unrecognized value: 'TrainingOptions.optimizer'")
Exemplo n.º 7
0
def _AssertMaskLengthIsValid(mask_config, kind: str) -> None:
  """Validate the length bound and length distribution of a hole/mask_seq message.

  Args:
    mask_config: A `hole` or `mask_seq` sub-message.
    kind: Name used in the absolute-length error message ("hole" or
      "mask_seq").

  Raises:
    ValueError: If neither `normal_distribution` nor `uniform_distribution`
      is set.
  """
  if mask_config.HasField("absolute_length"):
    pbutil.AssertFieldConstraint(
      mask_config,
      "absolute_length",
      lambda x : x > 0,
      "absolute length is the upper bound range of a {}'s length. Therefore should be > 0.".format(kind)
    )
  else:
    pbutil.AssertFieldConstraint(
      mask_config,
      "relative_length",
      lambda x : 0.0 < x <= 1.0,
      "relative length must be between 0 and 100% of a kernel's actual length."
    )
  if mask_config.HasField("normal_distribution"):
    pbutil.AssertFieldIsSet(mask_config.normal_distribution, "mean")
    pbutil.AssertFieldIsSet(mask_config.normal_distribution, "variance")
  elif not mask_config.HasField("uniform_distribution"):
    # NOTE(review): the message says "Hole" for mask_seq too; kept as-is.
    raise ValueError("Hole length distribution has not been set.")


def AssertConfigIsValid(config: model_pb2.DataGenerator) -> model_pb2.DataGenerator:
  """Parse a data generator protobuf message and check it for validity.

  Checks:
    * `datapoint_type` is "kernel" or "statement".
    * `datapoint_time` is "online" or "pre".
    * `use_start_end` and `steps_per_epoch` are set.
    * `validation_split` is in [0, 100].
    * `truncate_large_kernels` is set for kernel datapoints.
    * Each `validation_set` entry's mask / hole / mask_seq options.
    * The top-level masking technique. hole and mask_seq must also set
      `stage_training`.

  Args:
    config: A DataGenerator proto.

  Returns:
    The input config, unmodified.

  Raises:
    ValueError: If a hole/mask_seq length distribution is not set.
    Any error raised by the pbutil field assertions.
  """
  pbutil.AssertFieldConstraint(
    config,
    "datapoint_type",
    lambda x: x == "kernel" or x == "statement",
    "Valid options for datapoint_type are 'kernel' and 'statement'",
  )
  pbutil.AssertFieldConstraint(
    config,
    "datapoint_time",
    lambda x: x == "online" or x == "pre",
    "Valid options for datapoint_time are 'online' and 'pre'",
  )
  pbutil.AssertFieldIsSet(config, "use_start_end")
  pbutil.AssertFieldIsSet(config, "steps_per_epoch")
  pbutil.AssertFieldConstraint(
    config,
    "validation_split",
    lambda x : 0 <= x <= 100,
    "Validation split is expressed in [0-100]%."
  )
  if config.datapoint_type == "kernel":
    pbutil.AssertFieldIsSet(config, "truncate_large_kernels")

  for val_opt in config.validation_set:
    if val_opt.HasField("mask"):
      pbutil.AssertFieldIsSet(val_opt.mask, "random_placed_mask")
    elif val_opt.HasField("hole"):
      # Length fields live on the hole sub-message, not on val_opt itself.
      _AssertMaskLengthIsValid(val_opt.hole, "hole")
    elif val_opt.HasField("mask_seq"):
      _AssertMaskLengthIsValid(val_opt.mask_seq, "mask_seq")

  # Parse masking technique for bert's data generator
  pbutil.AssertFieldIsSet(config, "mask_technique")
  if config.HasField("mask"):
    pbutil.AssertFieldIsSet(config.mask, "random_placed_mask")
  elif config.HasField("hole"):
    _AssertMaskLengthIsValid(config.hole, "hole")
    pbutil.AssertFieldIsSet(config.hole, "stage_training")
  elif config.HasField("mask_seq"):
    _AssertMaskLengthIsValid(config.mask_seq, "mask_seq")
    pbutil.AssertFieldIsSet(config.mask_seq, "stage_training")
  return config
Exemplo n.º 8
0
def _AssertSamplerMaskLengthIsValid(mask_config, kind: str) -> None:
  """Validate the length bound and length distribution of a hole/mask_seq message.

  Args:
    mask_config: A `hole` or `mask_seq` sub-message of a sampler corpus_config.
    kind: Name used in the absolute-length error message ("hole" or
      "mask_seq").

  Raises:
    ValueError: If neither `normal_distribution` nor `uniform_distribution`
      is set.
  """
  if mask_config.HasField("absolute_length"):
    pbutil.AssertFieldConstraint(
      mask_config,
      "absolute_length",
      lambda x : x > 0,
      "absolute length is the upper bound range of a {}'s length. Therefore should be > 0.".format(kind)
    )
  else:
    pbutil.AssertFieldConstraint(
      mask_config,
      "relative_length",
      lambda x : 0.0 < x <= 1.0,
      "relative length must be between 0 and 100% of a kernel's actual length."
    )
  if mask_config.HasField("normal_distribution"):
    pbutil.AssertFieldIsSet(mask_config.normal_distribution, "mean")
    pbutil.AssertFieldIsSet(mask_config.normal_distribution, "variance")
  elif not mask_config.HasField("uniform_distribution"):
    # NOTE(review): the message says "Hole" for mask_seq too; kept as-is.
    raise ValueError("Hole length distribution has not been set.")


def AssertConfigIsValid(config: sampler_pb2.Sampler) -> sampler_pb2.Sampler:
  """Assert that a sampler configuration contains no invalid values.

  The sampler must have one of the following:
    * a non-empty `start_text`;
    * a `sample_corpus` with a valid `corpus_config` (normal / online / active
      sampling plus a masking technique), together with either a `corpus` or
      a `start_text`;
    * at least one of `train_set`, `validation_set`, `sample_set` or
      `live_sampling`.

  `batch_size`, `sequence_length` and `temperature_micros` must all be > 0.

  Args:
    config: A sampler configuration proto.

  Returns:
    The sampler configuration proto, unmodified.

  Raises:
    ValueError: If there are configuration errors. pbutil.ProtoValueError is
      converted to ValueError, with the original as its cause.
  """
  try:
    if config.HasField("start_text"):
      pbutil.AssertFieldConstraint(
        config,
        "start_text",
        lambda s: len(s) > 0,
        "Sampler.start_text must be a string",
      )
    elif config.HasField("sample_corpus"):
      sample_corpus = config.sample_corpus
      if not sample_corpus.HasField("corpus_config"):
        raise ValueError("sample_corpus has no corpus_config field.")
      corpus_config = sample_corpus.corpus_config

      if corpus_config.HasField("normal"):
        pbutil.AssertFieldIsSet(corpus_config, "normal")
      elif corpus_config.HasField("online"):
        pbutil.AssertFieldIsSet(corpus_config, "online")
      elif corpus_config.HasField("active"):
        active = corpus_config.active
        pbutil.AssertFieldIsSet(active, "active_limit_per_feed")
        pbutil.AssertFieldIsSet(active, "active_search_depth")
        pbutil.AssertFieldIsSet(active, "active_search_width")
        pbutil.AssertFieldConstraint(
          active,
          "batch_size_per_feed",
          # Reject a zero divisor before taking the modulo.
          lambda x : x > 0 and config.batch_size % x == 0,
          "batch_size {} must be a multiple of batch_size_per_feed".format(
            config.batch_size
          )
        )
        pbutil.AssertFieldConstraint(
          active,
          "feature_space",
          lambda x : x in extractor.extractors,
          "feature_space can only be one of {}".format(', '.join(list(extractor.extractors.keys())))
        )
        if active.HasField("target"):
          pbutil.AssertFieldConstraint(
            active,
            "target",
            lambda x : x in feature_sampler.targets,
            "target can only be one of {}".format(', '.join(list(feature_sampler.targets.keys())))
          )
        elif active.HasField("active_learner"):
          active_models.AssertConfigIsValid(active.active_learner)
        else:
          raise ValueError(active)
      else:
        raise ValueError("Sampling type is undefined: {}".format(corpus_config))

      pbutil.AssertFieldIsSet(corpus_config, "max_predictions_per_seq")
      pbutil.AssertFieldIsSet(corpus_config, "masked_lm_prob")

      pbutil.AssertFieldIsSet(corpus_config, "mask_technique")
      if corpus_config.HasField("mask"):
        pbutil.AssertFieldIsSet(corpus_config.mask, "random_placed_mask")
      elif corpus_config.HasField("hole"):
        _AssertSamplerMaskLengthIsValid(corpus_config.hole, "hole")
      elif corpus_config.HasField("mask_seq"):
        _AssertSamplerMaskLengthIsValid(corpus_config.mask_seq, "mask_seq")

      if sample_corpus.HasField("corpus"):
        corpuses.AssertConfigIsValid(sample_corpus.corpus)
      else:
        pbutil.AssertFieldIsSet(sample_corpus, "start_text")
    elif not any(
      config.HasField(field)
      for field in ("train_set", "validation_set", "sample_set", "live_sampling")
    ):
      raise ValueError(config)

    pbutil.AssertFieldConstraint(
      config, "batch_size", lambda x: 0 < x, "Sampler.batch_size must be > 0"
    )
    pbutil.AssertFieldConstraint(
      config,
      "sequence_length",
      lambda x: 0 < x,
      "Sampler.sequence_length must be > 0",
    )
    pbutil.AssertFieldConstraint(
      config,
      "temperature_micros",
      lambda x: 0 < x,
      "Sampler.temperature_micros must be > 0",
    )
    return config
  except pbutil.ProtoValueError as e:
    raise ValueError(e) from e