Example #1
def LsModels(cache_root: pathlib.Path) -> None:
    """Print the training progress of every model found under the cache root."""
    for model_dir in (cache_root / "model").iterdir():
        meta_file = model_dir / "META.pbtxt"
        if pbutil.ProtoIsReadable(meta_file, internal_pb2.ModelMeta()):
            model = models.Model(
                pbutil.FromFile(meta_file, internal_pb2.ModelMeta()).config)
            telemetry = list(model.TrainingTelemetry())
            num_epochs = model.config.training.num_epochs
            n = len(telemetry)
            print(f"{model_dir} {n} / {num_epochs} epochs")
        elif meta_file.is_file():
            app.Warning("Meta file %s cannot be read.", meta_file)
        else:
            app.Warning("Meta file %s not found.", meta_file)
Example #2
def parseModelSummary(meta):
    """Build a one-line architecture/training summary from META.pbtxt text lines."""
    m = pbutil.FromString('\n'.join(meta), internal_pb2.ModelMeta())
    if m.config.architecture.backend == model_pb2.NetworkArchitecture.TENSORFLOW_BERT:
        summary = (
            "BERT, hs: {}, nhl: {}, atth: {}, imsz: {}, pemb: {}, preds: {}, dp: {}, mprob: {}, {}"
            .format(
                m.config.architecture.hidden_size,
                m.config.architecture.num_hidden_layers,
                m.config.architecture.num_attention_heads,
                m.config.architecture.intermediate_size,
                m.config.architecture.max_position_embeddings,
                m.config.training.max_predictions_per_seq,
                m.config.training.dupe_factor,
                round(m.config.training.masked_lm_prob, 3),
                "mask" if m.config.training.data_generator.HasField("mask")
                else "hole-{},{}".format(
                    (m.config.training.data_generator.hole.absolute_length
                     if m.config.training.data_generator.hole.HasField(
                         "absolute_length") else
                     m.config.training.data_generator.hole.relative_length),
                    "unf" if m.config.training.data_generator.hole.HasField(
                        "uniform_distribution") else "norm-{},{}".format(
                            round(
                                m.config.training.data_generator.hole.
                                normal_distribution.mean, 2),
                            round(
                                m.config.training.data_generator.hole.
                                normal_distribution.variance, 2)))))
    else:
        raise NotImplementedError
    return summary
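A hedged usage sketch: parseModelSummary expects the text lines of a META.pbtxt file, which it rejoins before parsing. The path below is a hypothetical placeholder:

import pathlib

# Hypothetical path to a cached model's metadata; substitute a real cache entry.
meta_path = pathlib.Path("/path/to/model_cache/META.pbtxt")
print(parseModelSummary(meta_path.read_text().splitlines()))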
Example #3
def test_Model_metafile(clgen_cache_dir, abc_model_config):
    """A newly instantiated model's cache has a metafile."""
    del clgen_cache_dir
    m = models.Model(abc_model_config)
    assert (m.cache.path / 'META.pbtxt').is_file()
    assert pbutil.ProtoIsReadable(m.cache.path / 'META.pbtxt',
                                  internal_pb2.ModelMeta())
Example #4
    def __init__(self, path: pathlib.Path):
        """Instantiate a model from the contents of an existing cache directory."""
        self.path = path.absolute()
        self.cache = cache.FSCache(self.path)
        self.corpus = NullCorpus()
        self.config = pbutil.FromFile(
            self.path / 'META.pbtxt', internal_pb2.ModelMeta()).config
        self.atomizer = atomizers.AtomizerBase.FromFile(self.path / 'atomizer')
        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW:
            tensorflow_backend.TensorFlowBackend,
            model_pb2.NetworkArchitecture.KERAS: keras_backend.KerasBackend,
        }[self.config.architecture.backend](self.config, self.cache, self.atomizer)
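A sketch of how this constructor might be called, assuming the method belongs to a wrapper class named here, hypothetically, CachedModel (the class name is not shown in the snippet):

import pathlib

# CachedModel is a hypothetical name for the class owning the __init__ above.
model = CachedModel(pathlib.Path("/path/to/cache/model/<model_hash>"))
print(model.config.architecture.backend)
print(type(model.backend).__name__)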
Example #5
    def __init__(self, config: model_pb2.Model):
        """Instantiate a model.

    Args:
      config: A Model message.

    Raises:
      TypeError: If the config argument is not a Model proto.
      UserError: In case on an invalid config.
    """
        # Error early, so that a cache isn't created.
        if not isinstance(config, model_pb2.Model):
            t = type(config).__name__
            raise TypeError(f"Config must be a Model proto. Received: '{t}'")
        # Validate config options.
        if config.training.sequence_length < 1:
            raise errors.UserError(
                'TrainingOptions.sequence_length must be >= 1')

        self.config = model_pb2.Model()
        self.config.CopyFrom(builders.AssertIsBuildable(config))
        self.corpus = corpuses.Corpus(config.corpus)
        self.hash = self._ComputeHash(self.corpus, self.config)
        self.cache = cache.mkcache('model', self.hash)
        # Create the necessary cache directories.
        (self.cache.path / 'checkpoints').mkdir(exist_ok=True)
        (self.cache.path / 'samples').mkdir(exist_ok=True)
        (self.cache.path / 'logs').mkdir(exist_ok=True)

        # Create symlink to encoded corpus.
        symlink = self.cache.path / 'corpus'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(
                    pathlib.Path(
                        self.corpus.encoded.url[len('sqlite:///'):]).parent,
                    self.cache.path), symlink)

        # Create symlink to the atomizer.
        symlink = self.cache.path / 'atomizer'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(self.corpus.atomizer_path, self.cache.path),
                symlink)

        # Validate metadata against cache.
        if self.cache.get('META.pbtxt'):
            cached_meta = pbutil.FromFile(
                pathlib.Path(self.cache['META.pbtxt']),
                internal_pb2.ModelMeta())
            # Exclude num_epochs and corpus location from metadata comparison.
            config_to_compare = model_pb2.Model()
            config_to_compare.CopyFrom(self.config)
            config_to_compare.corpus.ClearField('contentfiles')
            config_to_compare.training.ClearField('num_epochs')
            # These fields should have already been cleared, but we'll do it again
            # so that metadata comparisons don't fail when the cached meta schema
            # is updated.
            cached_to_compare = model_pb2.Model()
            cached_to_compare.CopyFrom(cached_meta.config)
            cached_to_compare.corpus.ClearField('contentfiles')
            cached_to_compare.training.ClearField('num_epochs')
            if config_to_compare != cached_to_compare:
                raise errors.InternalError('Metadata mismatch')
            self.meta = cached_meta
        else:
            self.meta = internal_pb2.ModelMeta()
            self.meta.config.CopyFrom(self.config)
            self._WriteMetafile()

        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW:
            tensorflow_backend.TensorFlowBackend,
            model_pb2.NetworkArchitecture.KERAS: keras_backend.KerasBackend,
        }[config.architecture.backend](self.config, self.cache, self.corpus)
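For completeness, a hedged sketch of constructing this model from a config stored as a protobuf text file, assuming the __init__ above belongs to models.Model as used in Examples #1 and #3; the filename is an assumption for illustration:

import pathlib

# Hypothetical path to a Model config in protobuf text format.
config = pbutil.FromFile(pathlib.Path("model.pbtxt"), model_pb2.Model())
model = models.Model(config)  # Raises TypeError / UserError on invalid configs.
print(model.hash)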
Example #6
    def __init__(self, config: model_pb2.Model):
        """Instantiate a model.

    Args:
      config: A Model message.

    Raises:
      TypeError: If the config argument is not a Model proto.
      UserError: In case on an invalid config.
    """
        # Error early, so that a cache isn't created.
        if not isinstance(config, model_pb2.Model):
            t = type(config).__name__
            raise TypeError(f"Config must be a Model proto. Received: '{t}'")

        self.config = model_pb2.Model()
        # Validate config options.
        self.config.CopyFrom(builders.AssertIsBuildable(config))
        if FLAGS.num_train_steps:
            self.config.training.num_train_steps = FLAGS.num_train_steps
        if FLAGS.num_pretrain_steps:
            self.config.training.num_pretrain_steps = FLAGS.num_pretrain_steps
        if FLAGS.num_epochs:
            self.config.training.num_epochs = FLAGS.num_epochs

        # Initialize distrib lock path.
        if environment.WORLD_SIZE > 1:
            if environment.WORLD_RANK == 0:
                lock_cache = cache.mkcache("locks")
                lock_cache.path.mkdir(exist_ok=True)
            else:
                while not cache.cachepath("locks").exists():
                    time.sleep(0.5)
                lock_cache = cache.mkcache("locks")
            distrib.init(lock_cache.path)

        # Initialize corpuses
        self.corpus = corpuses.Corpus(config.corpus)
        self.pre_train_corpus = None
        if config.HasField("pre_train_corpus"):
            self.pre_train_corpus = corpuses.Corpus(config.pre_train_corpus)

        self.hash = self._ComputeHash(self.pre_train_corpus, self.corpus,
                                      self.config)
        self._created = False

        distrib.lock()
        self.cache = cache.mkcache("model", self.hash)
        distrib.unlock()

        if environment.WORLD_RANK == 0:
            # Create the necessary cache directories.
            (self.cache.path / "checkpoints").mkdir(exist_ok=True)
            (self.cache.path / "samples").mkdir(exist_ok=True)
            # Create symlink to encoded corpus.
            symlink = self.cache.path / "corpus"
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(
                        pathlib.Path(
                            self.corpus.encoded.url[len("sqlite:///"):]).parent,
                        self.cache.path,
                    ),
                    symlink,
                )
            if self.pre_train_corpus:
                symlink = self.cache.path / "pre_train_corpus"
                if not symlink.is_symlink():
                    os.symlink(
                        os.path.relpath(
                            pathlib.Path(
                                self.pre_train_corpus.encoded.url[
                                    len("sqlite:///"):]).parent,
                            self.cache.path,
                        ),
                        symlink,
                    )

            # Create symlink to the tokenizer and create a backup inside checkpoints.
            symlink = self.cache.path / "tokenizer"
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(self.corpus.tokenizer_path,
                                    self.cache.path), symlink)
            if (self.cache.path / "checkpoints" /
                    "backup_tokenizer.pkl").exists():
                shutil.copyfile(
                    self.cache.path / "checkpoints" / "backup_tokenizer.pkl",
                    self.corpus.tokenizer_path)

            # Validate metadata against cache.
            if self.cache.get("META.pbtxt"):
                cached_meta = pbutil.FromFile(
                    pathlib.Path(self.cache["META.pbtxt"]),
                    internal_pb2.ModelMeta())
                # Exclude num_epochs and corpus location from metadata comparison.
                config_to_compare = model_pb2.Model()
                config_to_compare.CopyFrom(self.config)
                config_to_compare.corpus.ClearField("contentfiles")
                if config_to_compare.HasField("pre_train_corpus"):
                    config_to_compare.pre_train_corpus.ClearField(
                        "contentfiles")
                config_to_compare.training.ClearField("num_epochs")
                config_to_compare.training.ClearField("num_train_steps")
                if config_to_compare.HasField("pre_train_corpus"):
                    config_to_compare.training.ClearField("num_pretrain_steps")
                config_to_compare.training.ClearField("batch_size")
                if config_to_compare.training.HasField("data_generator"):
                    config_to_compare.training.data_generator.ClearField(
                        "steps_per_epoch")
                    config_to_compare.training.data_generator.ClearField(
                        "validation_set")
                # These fields should have already been cleared, but we'll do it again
                # so that metadata comparisons don't fail when the cached meta schema
                # is updated.
                cached_to_compare = model_pb2.Model()
                cached_to_compare.CopyFrom(cached_meta.config)
                cached_to_compare.corpus.ClearField("contentfiles")
                if cached_to_compare.HasField("pre_train_corpus"):
                    cached_to_compare.pre_train_corpus.ClearField(
                        "contentfiles")
                cached_to_compare.training.ClearField("num_epochs")
                cached_to_compare.training.ClearField("num_train_steps")
                if cached_to_compare.HasField("pre_train_corpus"):
                    cached_to_compare.training.ClearField("num_pretrain_steps")
                cached_to_compare.training.ClearField("batch_size")
                if cached_to_compare.training.HasField("data_generator"):
                    cached_to_compare.training.data_generator.ClearField(
                        "steps_per_epoch")
                    cached_to_compare.training.data_generator.ClearField(
                        "validation_set")
                if cached_to_compare.training.sequence_length != config_to_compare.training.sequence_length:
                    l.logger().warning(
                        "Mismatch between pre-trained and current config "
                        "sequence_length! This can only be intended in BERT "
                        "models!")
                cached_to_compare.training.ClearField("sequence_length")
                config_to_compare.training.ClearField("sequence_length")
                if config_to_compare != cached_to_compare:
                    raise SystemError("Metadata mismatch: {} \n\n {}".format(
                        config_to_compare, cached_to_compare))
                self.meta = cached_meta
            else:
                self.meta = internal_pb2.ModelMeta()
                self.meta.config.CopyFrom(self.config)
                self._WriteMetafile()

            ## Store current commit
            commit.saveCommit(self.cache.path)

        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW_SEQ:
            tf_sequential.tfSequential,
            model_pb2.NetworkArchitecture.KERAS_SEQ:
            keras_sequential.kerasSequential,
            model_pb2.NetworkArchitecture.TENSORFLOW_BERT: tf_bert.tfBert,
            model_pb2.NetworkArchitecture.TORCH_BERT: torch_bert.torchBert,
        }[config.architecture.backend](self.config, self.cache, self.hash)
        l.logger().info("Initialized {} in {}".format(self.backend,
                                                      self.cache.path))
        return