Example #1
    def __init__(self, config: active_learning_pb2.ActiveLearner):
        """Instantiate a model.

    Args:
      config: A Model message.

    Raises:
      TypeError: If the config argument is not a Model proto.
      UserError: In case on an invalid config.
    """
        # Error early, so that a cache isn't created.
        if not isinstance(config, active_learning_pb2.ActiveLearner):
            t = type(config).__name__
            raise TypeError(
                f"Config must be an ActiveLearner proto. Received: '{t}'")

        self.config = active_learning_pb2.ActiveLearner()
        # Validate config options.
        self.config.CopyFrom(AssertConfigIsValid(config))

        distrib.lock()
        self.cache = cache.mkcache("active_model")
        distrib.unlock()

        self.downstream_task = downstream_tasks.DownstreamTask.FromTask(
            self.config.downstream_task, self.config.training_corpus)

        if environment.WORLD_RANK == 0:
            ## Store current commit
            commit.saveCommit(self.cache.path)
        self.backend = active_committee.QueryByCommittee(
            self.config, self.cache, self.downstream_task)
        l.logger().info("Initialized {} in {}".format(self.backend,
                                                      self.cache.path))
        return
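
For reference, Example #6 below constructs this class as active_models.Model(...) from a sampler config. A minimal direct-instantiation sketch, treating the import paths and the module name as assumptions that the excerpt itself does not show:

    # Hypothetical import paths; only active_learning_pb2 is confirmed by the excerpt.
    from proto import active_learning_pb2
    import active_models

    config = active_learning_pb2.ActiveLearner()
    # ... populate config.downstream_task, config.training_corpus, etc. ...
    learner = active_models.Model(config)  # TypeError for anything that is not an ActiveLearner proto
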
Example #2
  def __init__(self, url: str, is_pre_train: bool = False, must_exist: bool = False, is_replica: bool = False):

    self.is_pre_train = is_pre_train
    if environment.WORLD_RANK == 0 or is_replica:
      encoded_path = pathlib.Path(url.replace("sqlite:///", "")).parent
      self.length_monitor   = monitors.CumulativeHistMonitor(encoded_path, "encoded_kernel_length")
      if not self.is_pre_train:
        self.token_monitor    = monitors.NormalizedFrequencyMonitor(encoded_path, "token_distribution")
        self.feature_monitors = {ftype: monitors.CategoricalDistribMonitor(encoded_path, "{}_distribution".format(ftype)) for ftype in extractor.extractors.keys()}
      super(EncodedContentFiles, self).__init__(url, Base, must_exist=must_exist)
    if environment.WORLD_SIZE > 1 and not is_replica:
      # Set up engine connections to the replicated encoded chunks.
      self.base_path = pathlib.Path(url.replace("sqlite:///", "")).resolve().parent
      hash_id = self.base_path.name
      try:
        tdir = pathlib.Path(FLAGS.local_filesystem).resolve() / hash_id / "node_encoded"
      except Exception:
        tdir = pathlib.Path("/tmp").resolve() / hash_id / "node_encoded"
      distrib.lock()
      tdir.mkdir(parents = True, exist_ok = True)
      distrib.unlock()
      self.replicated_path = tdir / "encoded_{}.db".format(environment.WORLD_RANK)
      self.replicated = EncodedContentFiles(
        url = "sqlite:///{}".format(str(self.replicated_path)),
        is_pre_train = is_pre_train,
        must_exist = must_exist,
        is_replica = True
      )
      self.length_monitor = self.replicated.length_monitor
      if not self.is_pre_train:
        self.token_monitor    = self.replicated.token_monitor
        self.feature_monitors = self.replicated.feature_monitors
      distrib.barrier()
    return
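
The constructor above fans the encoded database out to one private copy per node when running distributed. A minimal caller-side sketch, with the URL and paths purely illustrative:

    # Illustrative only: rank 0 opens the master sqlite database, and every rank
    # additionally gets a per-node replica under FLAGS.local_filesystem (or /tmp).
    db = EncodedContentFiles(url = "sqlite:///corpus/encoded.db", must_exist = True)
    if environment.WORLD_SIZE > 1:
      print(db.replicated_path)  # e.g. <local_filesystem>/corpus/node_encoded/encoded_<rank>.db
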
Example #3
  def __init__(
    self,
    path: pathlib.Path,
    must_exist: bool = False,
    flush_secs: int = 30,
    plot_sample_status = False,
    commit_sample_frequency: int = 1024,
  ):
    distrib.lock()
    self.db = samples_database.SamplesDatabase("sqlite:///{}".format(str(path)), must_exist = must_exist)
    distrib.unlock()
    self.sample_id   = self.db.count
    self.visited     = set(self.db.get_hash_entries)
    self.flush_queue = []
    self.plot_sample_status = plot_sample_status

    if self.plot_sample_status:
      self.saturation_monitor = monitors.CumulativeHistMonitor(path.parent, "cumulative_sample_count")
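
A short usage sketch for this constructor; the enclosing class name is not shown in the excerpt, so SamplesObserver below is a hypothetical stand-in:

    # Hypothetical class name; only the __init__ body is shown above.
    observer = SamplesObserver(pathlib.Path("samples.db"), plot_sample_status = True)
    print(observer.sample_id)  # number of samples already stored in the database
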
Example #4
  def __init__(self, url: str, must_exist: bool = False, is_replica: bool = False):
    if environment.WORLD_RANK == 0 or is_replica:
      super(PreprocessedContentFiles, self).__init__(
        url, Base, must_exist=must_exist
      )
    if environment.WORLD_SIZE > 1 and not is_replica:
      # Set up engine connections to the replicated preprocessed chunks.
      self.base_path = pathlib.Path(url.replace("sqlite:///", "")).resolve().parent
      hash_id = self.base_path.name
      try:
        tdir = pathlib.Path(FLAGS.local_filesystem).resolve() / hash_id / "node_preprocessed"
      except Exception:
        tdir = pathlib.Path("/tmp").resolve() / hash_id / "node_preprocessed"
      distrib.lock()
      tdir.mkdir(parents = True, exist_ok = True)
      distrib.unlock()
      self.replicated_path = tdir / "preprocessed_{}.db".format(environment.WORLD_RANK)
      self.replicated = PreprocessedContentFiles(
        url = "sqlite:///{}".format(str(self.replicated_path)),
        must_exist = must_exist,
        is_replica = True
      )
      distrib.barrier()
    return
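
This mirrors Example #2 for the preprocessed database; a one-line sketch of the call, with the URL illustrative:

    db = PreprocessedContentFiles(url = "sqlite:///corpus/preprocessed.db", must_exist = False)
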
Example #5
    def __init__(self, config: model_pb2.Model):
        """Instantiate a model.

    Args:
      config: A Model message.

    Raises:
      TypeError: If the config argument is not a Model proto.
      UserError: In case on an invalid config.
    """
        # Error early, so that a cache isn't created.
        if not isinstance(config, model_pb2.Model):
            t = type(config).__name__
            raise TypeError(f"Config must be a Model proto. Received: '{t}'")

        self.config = model_pb2.Model()
        # Validate config options.
        self.config.CopyFrom(builders.AssertIsBuildable(config))
        if FLAGS.num_train_steps:
            self.config.training.num_train_steps = FLAGS.num_train_steps
        if FLAGS.num_pretrain_steps:
            self.config.training.num_pretrain_steps = FLAGS.num_pretrain_steps
        if FLAGS.num_epochs:
            self.config.training.num_epochs = FLAGS.num_epochs

        # Initialize distrib lock path.
        if environment.WORLD_SIZE > 1:
            if environment.WORLD_RANK == 0:
                lock_cache = cache.mkcache("locks")
                lock_cache.path.mkdir(exist_ok=True)
            else:
                while not cache.cachepath("locks").exists():
                    time.sleep(0.5)
                lock_cache = cache.mkcache("locks")
            distrib.init(lock_cache.path)

        # Initialize corpuses
        self.corpus = corpuses.Corpus(config.corpus)
        self.pre_train_corpus = None
        if config.HasField("pre_train_corpus"):
            self.pre_train_corpus = corpuses.Corpus(config.pre_train_corpus)

        self.hash = self._ComputeHash(self.pre_train_corpus, self.corpus,
                                      self.config)
        self._created = False

        distrib.lock()
        self.cache = cache.mkcache("model", self.hash)
        distrib.unlock()

        if environment.WORLD_RANK == 0:
            # Create the necessary cache directories.
            (self.cache.path / "checkpoints").mkdir(exist_ok=True)
            (self.cache.path / "samples").mkdir(exist_ok=True)
            # Create symlink to encoded corpus.
            symlink = self.cache.path / "corpus"
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(
                        pathlib.Path(self.corpus.encoded.url[len("sqlite:///"
                                                                 ):]).parent,
                        self.cache.path,
                    ),
                    symlink,
                )
            if self.pre_train_corpus:
                symlink = self.cache.path / "pre_train_corpus"
                if not symlink.is_symlink():
                    os.symlink(
                        os.path.relpath(
                            pathlib.Path(self.pre_train_corpus.encoded.
                                         url[len("sqlite:///"):]).parent,
                            self.cache.path,
                        ),
                        symlink,
                    )

            # Create symlink to the tokenizer and create a backup inside checkpoints.
            symlink = self.cache.path / "tokenizer"
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(self.corpus.tokenizer_path,
                                    self.cache.path), symlink)
            if (self.cache.path / "checkpoints" /
                    "backup_tokenizer.pkl").exists():
                shutil.copyfile(
                    self.cache.path / "checkpoints" / "backup_tokenizer.pkl",
                    self.corpus.tokenizer_path)

            # Validate metadata against cache.
            if self.cache.get("META.pbtxt"):
                cached_meta = pbutil.FromFile(
                    pathlib.Path(self.cache["META.pbtxt"]),
                    internal_pb2.ModelMeta())
                # Exclude num_epochs and corpus location from metadata comparison.
                config_to_compare = model_pb2.Model()
                config_to_compare.CopyFrom(self.config)
                config_to_compare.corpus.ClearField("contentfiles")
                if config_to_compare.HasField("pre_train_corpus"):
                    config_to_compare.pre_train_corpus.ClearField(
                        "contentfiles")
                config_to_compare.training.ClearField("num_epochs")
                config_to_compare.training.ClearField("num_train_steps")
                if config_to_compare.HasField("pre_train_corpus"):
                    config_to_compare.training.ClearField("num_pretrain_steps")
                config_to_compare.training.ClearField("batch_size")
                if config_to_compare.training.HasField("data_generator"):
                    config_to_compare.training.data_generator.ClearField(
                        "steps_per_epoch")
                    config_to_compare.training.data_generator.ClearField(
                        "validation_set")
                # These fields should have already been cleared, but we'll do it again
                # so that metadata comparisons don't fail when the cached meta schema
                # is updated.
                cached_to_compare = model_pb2.Model()
                cached_to_compare.CopyFrom(cached_meta.config)
                cached_to_compare.corpus.ClearField("contentfiles")
                if cached_to_compare.HasField("pre_train_corpus"):
                    cached_to_compare.pre_train_corpus.ClearField(
                        "contentfiles")
                cached_to_compare.training.ClearField("num_epochs")
                cached_to_compare.training.ClearField("num_train_steps")
                if cached_to_compare.HasField("pre_train_corpus"):
                    cached_to_compare.training.ClearField("num_pretrain_steps")
                cached_to_compare.training.ClearField("batch_size")
                if cached_to_compare.training.HasField("data_generator"):
                    cached_to_compare.training.data_generator.ClearField(
                        "steps_per_epoch")
                    cached_to_compare.training.data_generator.ClearField(
                        "validation_set")
                if cached_to_compare.training.sequence_length != config_to_compare.training.sequence_length:
                    l.logger().warning(
                        "Mismatch between pre-trained and current config sequence_length! "
                        "This is only expected for BERT models.")
                cached_to_compare.training.ClearField("sequence_length")
                config_to_compare.training.ClearField("sequence_length")
                if config_to_compare != cached_to_compare:
                    raise SystemError("Metadata mismatch: {} \n\n {}".format(
                        config_to_compare, cached_to_compare))
                self.meta = cached_meta
            else:
                self.meta = internal_pb2.ModelMeta()
                self.meta.config.CopyFrom(self.config)
                self._WriteMetafile()

            ## Store current commit
            commit.saveCommit(self.cache.path)

        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW_SEQ:
            tf_sequential.tfSequential,
            model_pb2.NetworkArchitecture.KERAS_SEQ:
            keras_sequential.kerasSequential,
            model_pb2.NetworkArchitecture.TENSORFLOW_BERT: tf_bert.tfBert,
            model_pb2.NetworkArchitecture.TORCH_BERT: torch_bert.torchBert,
        }[config.architecture.backend](self.config, self.cache, self.hash)
        l.logger().info("Initialized {} in {}".format(self.backend,
                                                      self.cache.path))
        return
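
A minimal instantiation sketch for this constructor; the enclosing class name and the pbtxt path are assumptions, while pbutil.FromFile and model_pb2.Model appear in the excerpt itself:

    # Parse a Model proto from a textproto file and build the wrapper (names assumed).
    config = pbutil.FromFile(pathlib.Path("model.pbtxt"), model_pb2.Model())
    model = Model(config)
    # The constructor hashes corpus + config and lays out checkpoints/ and samples/ in the cache.
    print(model.hash, model.cache.path)
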
Example #6
  def __init__(self, config: sampler_pb2.Sampler, sample_db_name: str = "samples.db"):
    """Instantiate a sampler.

    Args:
      config: A Sampler message.

    Raises:
      TypeError: If the config argument is not a Sampler proto.
      UserError: If the config contains invalid values.
    """
    if not isinstance(config, sampler_pb2.Sampler):
      t = type(config).__name__
      raise TypeError(f"Config must be a Sampler proto. Received: '{t}'")
    self.config = sampler_pb2.Sampler()
    self.config.CopyFrom(AssertConfigIsValid(config))
    self.hash = self._ComputeHash(self.config)
    self.terminators = GetTerminationCriteria(self.config.termination_criteria)
    if config.HasField("start_text"):
      self.start_text = self.config.start_text
    else:
      self.start_text = ""

    self.temperature = self.config.temperature_micros / 1e6
    self.batch_size = self.config.batch_size
    self.sequence_length = self.config.sequence_length
    self.sample_db_name = sample_db_name

    # Create the necessary cache directories.
    distrib.lock()
    self.cache = cache.mkcache("sampler", self.hash)
    distrib.unlock()
    self.samples_directory = self.cache.path / "samples"
    if environment.WORLD_RANK == 0:
      self.samples_directory.mkdir(exist_ok = True)
    self.corpus_directory = None
    self.sample_corpus    = None
    if self.config.HasField("sample_corpus"):
      self.corpus_directory = self.cache.path / "sample_corpus"
      if environment.WORLD_RANK == 0:
        self.corpus_directory.mkdir(exist_ok = True)
      if self.config.sample_corpus.HasField("corpus"):
        self.sample_corpus = corpuses.Corpus(self.config.sample_corpus.corpus)
        self.sample_corpus.Create()
        self.symlinkSampleCorpus(
          pathlib.Path(self.sample_corpus.encoded.url[len("sqlite:///") :]).parent
        )
        text_data = [
          self.sample_corpus.tokenizer.tokensToString(x) for x in self.sample_corpus.GetTrainingData()
        ]
      else:
        self.start_text = self.config.sample_corpus.start_text
        text_data = [self.start_text]
      # Text data is dumped in order to specialize with all different model tokenizers.
      if environment.WORLD_RANK == 0:
        with open(self.cache.path / "sample_corpus" / "text_corpus.pkl", 'wb') as outf:
          pickle.dump(text_data, outf)

    if self.has_active_learning:
      self.active_learner = active_models.Model(config.sample_corpus.corpus_config.active.active_learner)

    if environment.WORLD_RANK == 0:
      meta = internal_pb2.SamplerMeta()
      meta.config.CopyFrom(self.config)
      pbutil.ToFile(meta, path = self.cache.path / "META.pbtxt")
      commit.saveCommit(self.cache.path)

    # Set in Specialize().
    self.encoded_start_text = None
    self.tokenized_start_text = None
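
And a matching sketch for the sampler; the class name and config file outside this excerpt are assumptions:

    config = pbutil.FromFile(pathlib.Path("sampler.pbtxt"), sampler_pb2.Sampler())
    sampler = Sampler(config, sample_db_name = "samples.db")
    # temperature is stored as micros in the proto and converted on construction.
    print(sampler.temperature, sampler.batch_size, sampler.sequence_length)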