Example #1
  def SampleMaskLMBatchGenerator(self,
                                 model_opts              : model_pb2.TrainingOptions,
                                 sampler                 : "samplers.Sampler",
                                 tokenizer               : tokenizers.TokenizerBase,
                                 seed                    : int,
                                 sample_batch_size       : int,
                                 max_position_embeddings : int,
                                 cache_path              : pathlib.Path,
                                 feature_encoder         : bool                        = False,
                                 feature_tokenizer       : tokenizers.FeatureTokenizer = None,
                                 feature_sequence_length : int                         = None,
                                 ) -> "data_generator.MaskLMBatchGenerator":
    """Initializes data generator for inference."""
    self.cache                   = cache.mkcache(cache_path, "dataset")
    self.cache.path.mkdir(exist_ok = True, parents = True)

    self.dataset                 = {}
    self.sampler                 = sampler
    self.corpus                  = sampler.sample_corpus
    self.tokenizer               = tokenizer
    self.config                  = model_opts.data_generator
    self.rngen                   = np.random
    self.sample_batch_size       = sample_batch_size
    self.max_position_embeddings = max_position_embeddings

    self.feature_encoder         = feature_encoder
    self.feature_tokenizer       = feature_tokenizer
    self.feature_sequence_length = feature_sequence_length

    self.training_opts                 = model_opts
    self.training_opts.sequence_length = sampler.sequence_length
    self.training_opts.batch_size      = sampler.batch_size
    return self
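A note on the snippet above: it accepts a `seed` argument but binds `self.rngen` to the global `np.random` module, so the seed is never consumed in this method. For comparison only (this is not clgen code), a seeded NumPy generator that would make sampling reproducible looks like this:

import numpy as np

seed = 1234
rngen = np.random.default_rng(seed)   # dedicated, seeded generator instead of the global np.random module
print(rngen.integers(0, 10, size=4))  # identical output for identical seeds
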
Example #2
    def __init__(self, config: active_learning_pb2.ActiveLearner):
        """Instantiate a model.

    Args:
      config: A Model message.

    Raises:
      TypeError: If the config argument is not a Model proto.
      UserError: In case on an invalid config.
    """
        # Error early, so that a cache isn't created.
        if not isinstance(config, active_learning_pb2.ActiveLearner):
            t = type(config).__name__
            raise TypeError(
                f"Config must be an ActiveLearner proto. Received: '{t}'")

        self.config = active_learning_pb2.ActiveLearner()
        # Validate config options.
        self.config.CopyFrom(AssertConfigIsValid(config))

        distrib.lock()
        self.cache = cache.mkcache("active_model")
        distrib.unlock()

        self.downstream_task = downstream_tasks.DownstreamTask.FromTask(
            self.config.downstream_task, self.config.training_corpus)

        if environment.WORLD_RANK == 0:
            ## Store current commit
            commit.saveCommit(self.cache.path)
        self.backend = active_committee.QueryByCommittee(
            self.config, self.cache, self.downstream_task)
        l.logger().info("Initialized {} in {}".format(self.backend,
                                                      self.cache.path))
        return
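The constructor above follows a validate-then-copy pattern: reject the wrong proto type early (before any cache is created), validate, and keep a private copy via CopyFrom so later mutations never leak back into the caller's message. A minimal, self-contained sketch of that pattern, using a stock protobuf wrapper type instead of clgen's compiled ActiveLearner message:

from google.protobuf import wrappers_pb2

def make_local_copy(config: wrappers_pb2.StringValue) -> wrappers_pb2.StringValue:
  # Error early: reject anything that is not the expected proto type.
  if not isinstance(config, wrappers_pb2.StringValue):
    raise TypeError(f"Config must be a StringValue proto. Received: '{type(config).__name__}'")
  local = wrappers_pb2.StringValue()
  local.CopyFrom(config)  # defensive copy; mutating `local` leaves the caller's message untouched
  return local

original = wrappers_pb2.StringValue(value="config-A")
copy = make_local_copy(original)
copy.value = "config-B"
assert original.value == "config-A"
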
Example #3
  def TrainMaskLMBatchGenerator(self,
                                corpus                  : "corpuses.Corpus",
                                training_opts           : model_pb2.TrainingOptions,
                                cache_path              : pathlib.Path,
                                num_train_steps         : int                         = None,
                                pre_train               : bool                        = False,
                                feature_encoder         : bool                        = False,
                                feature_tokenizer       : tokenizers.FeatureTokenizer = None,
                                feature_sequence_length : int                         = None,
                                ) -> "data_generator.MaskLMDataGenerator":
    """Initializes data generator for training."""
    self.cache         = cache.mkcache(cache_path, "dataset")
    self.cache.path.mkdir(exist_ok = True, parents = True)

    self.dataset       = {}
    self.corpus        = corpus
    self.tokenizer     = corpus.tokenizer
    self.config        = training_opts.data_generator
    self.training_opts = training_opts
    self.rngen         = np.random # random.Random(training_opts.random_seed)
    self.pre_train     = pre_train
    self.feature_encoder         = feature_encoder
    self.feature_tokenizer       = feature_tokenizer
    self.feature_sequence_length = feature_sequence_length
    if num_train_steps:
      self.num_train_steps = num_train_steps
    else:
      self.num_train_steps = self.training_opts.num_train_steps
    shaped_corpus = self.createCorpus(self.cache.path)
    if self.config.datapoint_time == "pre":
      if self.feature_encoder:
        raise NotImplementedError("Pre-masking the corpus does not work with a feature encoding model.")
      # 'pre' pre-processes/masks the training/validation/sampling corpus for the model to use.
      # 'online' stores the raw data and masks it on the fly.
      self.configDataset(shaped_corpus)
    return self
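The inline comments above distinguish 'pre' (the corpus is masked once, ahead of time) from 'online' (raw data is kept and masked on the fly). A toy sketch of that distinction, with made-up tokens and a trivial masking rule rather than clgen's actual MaskLMDataGenerator logic:

import random

MASK = "[MASK]"

def mask_tokens(tokens, rng, prob=0.15):
  # Replace a random subset of tokens with the mask symbol.
  return [MASK if rng.random() < prob else t for t in tokens]

corpus = [["kernel", "void", "A", "(", ")", "{", "}"]] * 3
rng = random.Random(0)

# 'pre': mask the whole corpus once, up front, and reuse the masked copies.
pre_masked = [mask_tokens(seq, rng) for seq in corpus]

# 'online': keep the raw corpus and mask lazily, so every pass sees fresh masks.
def online_batches(corpus, rng):
  for seq in corpus:
    yield mask_tokens(seq, rng)

first_online = next(online_batches(corpus, rng))
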
Example #4
    def __init__(self, config: model_pb2.Model):
        """Instantiate a model.

    Args:
      config: A Model message.

    Raises:
      TypeError: If the config argument is not a Model proto.
      UserError: In case on an invalid config.
    """
        # Error early, so that a cache isn't created.
        if not isinstance(config, model_pb2.Model):
            t = type(config).__name__
            raise TypeError(f"Config must be a Model proto. Received: '{t}'")

        self.config = model_pb2.Model()
        # Validate config options.
        self.config.CopyFrom(builders.AssertIsBuildable(config))
        if FLAGS.num_train_steps:
            self.config.training.num_train_steps = FLAGS.num_train_steps
        if FLAGS.num_pretrain_steps:
            self.config.training.num_pretrain_steps = FLAGS.num_pretrain_steps
        if FLAGS.num_epochs:
            self.config.training.num_epochs = FLAGS.num_epochs

        # Initialize distrib lock path.
        if environment.WORLD_SIZE > 1:
            if environment.WORLD_RANK == 0:
                lock_cache = cache.mkcache("locks")
                lock_cache.path.mkdir(exist_ok=True)
            else:
                while not cache.cachepath("locks").exists():
                    time.sleep(0.5)
                lock_cache = cache.mkcache("locks")
            distrib.init(lock_cache.path)

        # Initialize corpuses
        self.corpus = corpuses.Corpus(config.corpus)
        self.pre_train_corpus = None
        if config.HasField("pre_train_corpus"):
            self.pre_train_corpus = corpuses.Corpus(config.pre_train_corpus)

        self.hash = self._ComputeHash(self.pre_train_corpus, self.corpus,
                                      self.config)
        self._created = False

        distrib.lock()
        self.cache = cache.mkcache("model", self.hash)
        distrib.unlock()

        if environment.WORLD_RANK == 0:
            # Create the necessary cache directories.
            (self.cache.path / "checkpoints").mkdir(exist_ok=True)
            (self.cache.path / "samples").mkdir(exist_ok=True)
            # Create symlink to encoded corpus.
            symlink = self.cache.path / "corpus"
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(
                        pathlib.Path(
                            self.corpus.encoded.url[len("sqlite:///"):]).parent,
                        self.cache.path,
                    ),
                    symlink,
                )
            if self.pre_train_corpus:
                symlink = self.cache.path / "pre_train_corpus"
                if not symlink.is_symlink():
                    os.symlink(
                        os.path.relpath(
                            pathlib.Path(
                                self.pre_train_corpus.encoded.url[
                                    len("sqlite:///"):]).parent,
                            self.cache.path,
                        ),
                        symlink,
                    )

            # Create symlink to the tokenizer and create a backup inside checkpoints.
            symlink = self.cache.path / "tokenizer"
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(self.corpus.tokenizer_path,
                                    self.cache.path), symlink)
            if (self.cache.path / "checkpoints" /
                    "backup_tokenizer.pkl").exists():
                shutil.copyfile(
                    self.cache.path / "checkpoints" / "backup_tokenizer.pkl",
                    self.corpus.tokenizer_path)

            # Validate metadata against cache.
            if self.cache.get("META.pbtxt"):
                cached_meta = pbutil.FromFile(
                    pathlib.Path(self.cache["META.pbtxt"]),
                    internal_pb2.ModelMeta())
                # Exclude num_epochs and corpus location from metadata comparison.
                config_to_compare = model_pb2.Model()
                config_to_compare.CopyFrom(self.config)
                config_to_compare.corpus.ClearField("contentfiles")
                if config_to_compare.HasField("pre_train_corpus"):
                    config_to_compare.pre_train_corpus.ClearField(
                        "contentfiles")
                config_to_compare.training.ClearField("num_epochs")
                config_to_compare.training.ClearField("num_train_steps")
                if config_to_compare.HasField("pre_train_corpus"):
                    config_to_compare.training.ClearField("num_pretrain_steps")
                config_to_compare.training.ClearField("batch_size")
                if config_to_compare.training.HasField("data_generator"):
                    config_to_compare.training.data_generator.ClearField(
                        "steps_per_epoch")
                    config_to_compare.training.data_generator.ClearField(
                        "validation_set")
                # These fields should have already been cleared, but we'll do it again
                # so that metadata comparisons don't fail when the cached meta schema
                # is updated.
                cached_to_compare = model_pb2.Model()
                cached_to_compare.CopyFrom(cached_meta.config)
                cached_to_compare.corpus.ClearField("contentfiles")
                if cached_to_compare.HasField("pre_train_corpus"):
                    cached_to_compare.pre_train_corpus.ClearField(
                        "contentfiles")
                cached_to_compare.training.ClearField("num_epochs")
                cached_to_compare.training.ClearField("num_train_steps")
                if cached_to_compare.HasField("pre_train_corpus"):
                    cached_to_compare.training.ClearField("num_pretrain_steps")
                cached_to_compare.training.ClearField("batch_size")
                if cached_to_compare.training.HasField("data_generator"):
                    cached_to_compare.training.data_generator.ClearField(
                        "steps_per_epoch")
                    cached_to_compare.training.data_generator.ClearField(
                        "validation_set")
                if cached_to_compare.training.sequence_length != config_to_compare.training.sequence_length:
                    l.logger().warning(
                        "Mismatch between pre-trained and current config sequence_length! "
                        "This is only intended for the BERT model!")
                cached_to_compare.training.ClearField("sequence_length")
                config_to_compare.training.ClearField("sequence_length")
                if config_to_compare != cached_to_compare:
                    raise SystemError("Metadata mismatch: {} \n\n {}".format(
                        config_to_compare, cached_to_compare))
                self.meta = cached_meta
            else:
                self.meta = internal_pb2.ModelMeta()
                self.meta.config.CopyFrom(self.config)
                self._WriteMetafile()

            ## Store current commit
            commit.saveCommit(self.cache.path)

        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW_SEQ:
            tf_sequential.tfSequential,
            model_pb2.NetworkArchitecture.KERAS_SEQ:
            keras_sequential.kerasSequential,
            model_pb2.NetworkArchitecture.TENSORFLOW_BERT: tf_bert.tfBert,
            model_pb2.NetworkArchitecture.TORCH_BERT: torch_bert.torchBert,
        }[config.architecture.backend](self.config, self.cache, self.hash)
        l.logger().info("Initialized {} in {}".format(self.backend,
                                                      self.cache.path))
        return
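Backend selection at the end of __init__ above is a dispatch table: a dict maps the architecture enum to a backend class, and the chosen class is instantiated immediately with the config, cache, and hash. A self-contained sketch of the same pattern, with placeholder classes and integer constants standing in for model_pb2.NetworkArchitecture values:

class TorchBertBackend:
  def __init__(self, config):
    self.config = config

class KerasSequentialBackend:
  def __init__(self, config):
    self.config = config

# Hypothetical stand-ins for the NetworkArchitecture enum values.
TORCH_BERT, KERAS_SEQ = 0, 1

def build_backend(architecture: int, config: dict):
  # Look the class up in the dispatch table, then instantiate it.
  return {
    TORCH_BERT: TorchBertBackend,
    KERAS_SEQ:  KerasSequentialBackend,
  }[architecture](config)

backend = build_backend(TORCH_BERT, {"sequence_length": 512})
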
Example #5
  def __init__(self, config: sampler_pb2.Sampler, sample_db_name = "samples.db"):
    """Instantiate a sampler.

    Args:
      config: A Sampler message.

    Raises:
      TypeError: If the config argument is not a Sampler proto.
      UserError: If the config contains invalid values.
    """
    if not isinstance(config, sampler_pb2.Sampler):
      t = type(config).__name__
      raise TypeError(f"Config must be a Sampler proto. Received: '{t}'")
    self.config = sampler_pb2.Sampler()
    self.config.CopyFrom(AssertConfigIsValid(config))
    self.hash = self._ComputeHash(self.config)
    self.terminators = GetTerminationCriteria(self.config.termination_criteria)
    if config.HasField("start_text"):
      self.start_text = self.config.start_text
    else:
      self.start_text = ""

    self.temperature = self.config.temperature_micros / 1e6
    self.batch_size = self.config.batch_size
    self.sequence_length = self.config.sequence_length
    self.sample_db_name = sample_db_name

    # Create the necessary cache directories.
    distrib.lock()
    self.cache = cache.mkcache("sampler", self.hash)
    distrib.unlock()
    self.samples_directory = self.cache.path / "samples"
    if environment.WORLD_RANK == 0:
      self.samples_directory.mkdir(exist_ok = True)
    self.corpus_directory = None
    self.sample_corpus    = None
    if self.config.HasField("sample_corpus"):
      self.corpus_directory = self.cache.path / "sample_corpus"
      if environment.WORLD_RANK == 0:
        self.corpus_directory.mkdir(exist_ok = True)
      if self.config.sample_corpus.HasField("corpus"):
        self.sample_corpus = corpuses.Corpus(self.config.sample_corpus.corpus)
        self.sample_corpus.Create()
        self.symlinkSampleCorpus(
          pathlib.Path(self.sample_corpus.encoded.url[len("sqlite:///") :]).parent
        )
        text_data = [
          self.sample_corpus.tokenizer.tokensToString(x) for x in self.sample_corpus.GetTrainingData()
        ]
      else:
        self.start_text = self.config.sample_corpus.start_text
        text_data = [self.start_text]
      # The raw text is dumped so it can later be specialized (re-encoded) with any model's tokenizer.
      if environment.WORLD_RANK == 0:
        with open(self.cache.path / "sample_corpus" / "text_corpus.pkl", 'wb') as outf:
          pickle.dump(text_data, outf)

    if self.has_active_learning:
      self.active_learner = active_models.Model(config.sample_corpus.corpus_config.active.active_learner)

    if environment.WORLD_RANK == 0:
      meta = internal_pb2.SamplerMeta()
      meta.config.CopyFrom(self.config)
      pbutil.ToFile(meta, path = self.cache.path / "META.pbtxt")
      commit.saveCommit(self.cache.path)

    # Set in Specialize().
    self.encoded_start_text = None
    self.tokenized_start_text = None
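When a sample corpus is configured, the sampler above pickles its raw text to sample_corpus/text_corpus.pkl so it can later be re-encoded by whichever tokenizer a model uses. A stdlib-only sketch of that dump/load round trip, with a temporary directory and made-up strings in place of the real cache and corpus:

import pathlib
import pickle
import tempfile

text_data = ["kernel void A() {}", "kernel void B(int x) {}"]

cache_dir = pathlib.Path(tempfile.mkdtemp()) / "sample_corpus"
cache_dir.mkdir(parents=True, exist_ok=True)

# Dump the raw text so any model's tokenizer can re-encode it later.
with open(cache_dir / "text_corpus.pkl", "wb") as outf:
  pickle.dump(text_data, outf)

# Consumers reload the text and tokenize it with their own tokenizer.
with open(cache_dir / "text_corpus.pkl", "rb") as inf:
  assert pickle.load(inf) == text_data
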
Example #6
File: corpuses.py  Project: fivosts/clgen
  def __init__(self, config: typing.Union[corpus_pb2.Corpus, corpus_pb2.PreTrainCorpus]):
    """Instantiate a corpus from a proto config.

    If this is a new corpus, a number of files will be created, which may
    take some time.

    Args:
      config: A Corpus message.

    Raises:
      TypeError: If the config argument is not a Corpus proto.
      UserError: In case the corpus is not found, or config contains invalid
        options.
      EmptyCorpusException: In case the corpus contains no data.
    """
    if not isinstance(config, corpus_pb2.Corpus) and not isinstance(config, corpus_pb2.PreTrainCorpus):
      raise TypeError(f"Config must be a Corpus proto. Received: '{type(config).__name__}'")

    # Make a local copy of the configuration.
    if isinstance(config, corpus_pb2.Corpus):
      self.config    = corpus_pb2.Corpus()
      self.pre_train = False
    else:
      self.config    = corpus_pb2.PreTrainCorpus()
      self.pre_train = True

    self.config.CopyFrom(AssertConfigIsValid(config))
    self._tokenizer = None
    self._created = False

    # An in-memory cache of the encoded contentfiles indices arrays.
    # Set and used in GetTrainingData().
    self._indices_arrays: typing.Optional[typing.List[np.array]] = None

    if environment.WORLD_RANK == 0:
      cache.cachepath("corpus").mkdir(parents=True, exist_ok=True)
    distrib.barrier()
    self.content_id = ResolveContentId(self.config)
    # Database of pre-processed files.
    preprocessed_id = ResolvePreprocessedId(self.content_id, self.config)
    if environment.WORLD_RANK == 0:
      cache.cachepath("corpus", "preprocessed", preprocessed_id).mkdir(exist_ok=True, parents=True)
    distrib.barrier()
    preprocessed_db_path = cache.cachepath("corpus", "preprocessed",
                                           preprocessed_id, "preprocessed.db")

    if self.config.HasField("content_id") and not preprocessed_db_path.is_file():
      raise ValueError(f"Content ID not found: '{self.content_id}'")
    self.preprocessed = preprocessed.PreprocessedContentFiles(
      f"sqlite:///{preprocessed_db_path}"
    )
    # Create symlink to contentfiles.
    if environment.WORLD_RANK == 0:
      symlink = (pathlib.Path(self.preprocessed.url[len("sqlite:///") :]).parent / "contentfiles")
      if not symlink.is_symlink():
        if config.HasField("local_directory"):
          os.symlink(
            str(ExpandConfigPath(config.local_directory,   path_prefix=FLAGS.clgen_local_path_prefix)),
            symlink,
          )
        elif config.HasField("local_tar_archive"):
          os.symlink(
            str(ExpandConfigPath(config.local_tar_archive, path_prefix=FLAGS.clgen_local_path_prefix)),
            symlink,
          )
        elif config.HasField("bq_database"):
          os.symlink(
            str(ExpandConfigPath(config.bq_database, path_prefix=FLAGS.clgen_local_path_prefix)),
            symlink,
          )  
        # elif config.HasField("fetch_github"):
        #   os.symlink(
        #     str(ExpandConfigPath(config.fetch_github, path_prefix=FLAGS.clgen_local_path_prefix)),
        #     symlink,
        #   )
    distrib.barrier()
    # Database of encoded pre-processed files.
    encoded_id = ResolveEncodedId(self.content_id, self.config)
    if environment.WORLD_RANK == 0:
      cache.cachepath("corpus", "encoded", encoded_id).mkdir(exist_ok=True, parents=True)
    distrib.barrier()
    db_path = cache.cachepath("corpus", "encoded", encoded_id, "encoded.db")
    if self.config.HasField("pre_encoded_corpus_url"):
      self.encoded = encoded.EncodedContentFiles(config.pre_encoded_corpus_url, self.pre_train)
    else:
      self.encoded = encoded.EncodedContentFiles(f"sqlite:///{db_path}", self.pre_train)
    self.tokenizer_path = cache.cachepath(
      "corpus", "encoded", encoded_id, "tokenizer.pkl"
    )
    if environment.WORLD_RANK == 0 and not self.config.HasField("pre_encoded_corpus_url"):
      symlink = (pathlib.Path(self.encoded.url[len("sqlite:///") :]).parent / "preprocessed")
      if not symlink.is_symlink():
        os.symlink(
          os.path.relpath(
            pathlib.Path(self.preprocessed.url[len("sqlite:///") :]).parent,
            pathlib.Path(self.encoded.url[len("sqlite:///") :]).parent,
            ),
          symlink,
        )
    self.hash = encoded_id
    self.cache = cache.mkcache("corpus", "encoded", encoded_id)
    if environment.WORLD_RANK == 0:
      commit.saveCommit(self.cache.path)
      commit.saveCommit(self.cache.path.parent.parent / "preprocessed" / preprocessed_id)
    distrib.barrier()
    l.logger().info("Initialized {}train corpus in {}".format("pre_" if self.pre_train else "", self.cache.path))
    return
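Two small mechanics recur throughout the constructor above: stripping the sqlite:/// scheme off a database URL to recover its filesystem path, and linking sibling cache directories with relative symlinks built by os.path.relpath. A self-contained sketch of both, using a temporary directory rather than clgen's cache layout:

import os
import pathlib
import tempfile

root = pathlib.Path(tempfile.mkdtemp())
encoded_dir = root / "encoded"
preprocessed_dir = root / "preprocessed"
encoded_dir.mkdir()
preprocessed_dir.mkdir()

# The corpus stores its database location as a SQLAlchemy-style URL.
url = f"sqlite:///{encoded_dir / 'encoded.db'}"

# Strip the "sqlite:///" scheme to recover the filesystem path, as in the snippet above.
db_dir = pathlib.Path(url[len("sqlite:///"):]).parent
assert db_dir == encoded_dir

# Link the encoded directory back to the preprocessed one with a relative symlink,
# mirroring the "preprocessed" symlink created in __init__.
symlink = db_dir / "preprocessed"
if not symlink.is_symlink():
  os.symlink(os.path.relpath(preprocessed_dir, db_dir), symlink)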