示例#1
0
def test_AssertIsBuildable_architecture_post_layer_dropout_micros(
  abc_model_config,
):
  """Missing or out-of-range post_layer_dropout_micros raises UserError."""
  expected = (
    "NetworkArchitecture.post_layer_dropout_micros must be "
    ">= 0 and <= 1000000"
  )
  arch = abc_model_config.architecture

  def _check_rejected():
    # Every invalid state must produce the same range error message.
    with test.Raises(errors.UserError) as e_info:
      builders.AssertIsBuildable(abc_model_config)
    assert expected == str(e_info.value)

  arch.ClearField("post_layer_dropout_micros")
  _check_rejected()
  # One value just below and one just above the permitted range.
  for bad_value in (-1, 1000001):
    arch.post_layer_dropout_micros = bad_value
    _check_rejected()
示例#2
0
def test_AssertIsBuildable_training_num_epochs(abc_model_config):
    """An unset or negative training.num_epochs raises UserError."""
    expected_message = "TrainingOptions.num_epochs must be > 0"
    training = abc_model_config.training

    training.ClearField('num_epochs')
    with pytest.raises(errors.UserError) as cleared_info:
        builders.AssertIsBuildable(abc_model_config)
    assert expected_message == str(cleared_info.value)

    training.num_epochs = -1
    with pytest.raises(errors.UserError) as negative_info:
        builders.AssertIsBuildable(abc_model_config)
    assert expected_message == str(negative_info.value)
示例#3
0
def test_AssertIsBuildable_architecture_num_layers(abc_model_config):
    """A missing or negative architecture.num_layers raises UserError."""
    architecture = abc_model_config.architecture
    # Each callable puts num_layers into an invalid state; they run in order.
    invalidations = (
        lambda: architecture.ClearField('num_layers'),
        lambda: setattr(architecture, 'num_layers', -1),
    )
    for invalidate in invalidations:
        invalidate()
        with pytest.raises(errors.UserError) as e_info:
            builders.AssertIsBuildable(abc_model_config)
        assert "NetworkArchitecture.num_layers must be > 0" == str(
            e_info.value)
示例#4
0
def test_AssertIsBuildable_architecture_embedding_size(abc_model_config):
  """A missing or negative architecture.embedding_size raises UserError."""
  architecture = abc_model_config.architecture
  # embedding_size is ignored unless the backend is KERAS, so select it first.
  architecture.backend = model_pb2.NetworkArchitecture.KERAS
  message = "NetworkArchitecture.embedding_size must be > 0"

  architecture.ClearField("embedding_size")
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert message == str(e_info.value)

  architecture.embedding_size = -1
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert message == str(e_info.value)
示例#5
0
def test_AssertIsBuildable_adam_optimizer_normalized_gradient_clip_micros(
        abc_model_config):
    """UserError when normalized_gradient_clip_micros is unset or negative."""
    adam = abc_model_config.training.adam_optimizer
    expected = "AdamOptimizer.normalized_gradient_clip_micros must be >= 0"

    def assert_user_error():
        with pytest.raises(errors.UserError) as e_info:
            builders.AssertIsBuildable(abc_model_config)
        assert expected == str(e_info.value)

    adam.ClearField('normalized_gradient_clip_micros')
    assert_user_error()
    adam.normalized_gradient_clip_micros = -1
    assert_user_error()
示例#6
0
def test_AssertIsBuildable_adam_optimizer_initial_learning_rate_micros(
        abc_model_config):
    """UserError when initial_learning_rate_micros is unset or negative."""
    adam = abc_model_config.training.adam_optimizer
    # None stands for "clear the field"; any other value is assigned to it.
    for bad_value in (None, -1):
        if bad_value is None:
            adam.ClearField('initial_learning_rate_micros')
        else:
            adam.initial_learning_rate_micros = bad_value
        with pytest.raises(errors.UserError) as e_info:
            builders.AssertIsBuildable(abc_model_config)
        assert "AdamOptimizer.initial_learning_rate_micros must be >= 0" == (
            str(e_info.value))
示例#7
0
def test_AssertIsBuildable_architecture_neurons_per_layer(abc_model_config):
  """Unset or negative architecture.neurons_per_layer raises UserError."""
  want = "NetworkArchitecture.neurons_per_layer must be > 0"
  arch = abc_model_config.architecture

  arch.ClearField("neurons_per_layer")
  with test.Raises(errors.UserError) as cleared:
    builders.AssertIsBuildable(abc_model_config)
  assert want == str(cleared.value)

  arch.neurons_per_layer = -1
  with test.Raises(errors.UserError) as negative:
    builders.AssertIsBuildable(abc_model_config)
  assert want == str(negative.value)
示例#8
0
def test_AssertIsBuildable_adam_optimizer_learning_rate_decay_per_epoch_micros(
        abc_model_config):
    """UserError when learning_rate_decay_per_epoch_micros is unset or < 0."""
    field_name = 'learning_rate_decay_per_epoch_micros'
    adam = abc_model_config.training.adam_optimizer
    expected = ("AdamOptimizer.learning_rate_decay_per_epoch_micros "
                "must be >= 0")

    adam.ClearField(field_name)
    with pytest.raises(errors.UserError) as e_info:
        builders.AssertIsBuildable(abc_model_config)
    assert expected == str(e_info.value)

    setattr(adam, field_name, -1)
    with pytest.raises(errors.UserError) as e_info:
        builders.AssertIsBuildable(abc_model_config)
    assert expected == str(e_info.value)
示例#9
0
def test_AssertIsBuildable_architecture_neuron_type(abc_model_config):
    """Clearing architecture.neuron_type makes the config unbuildable."""
    expected = "Field not set: 'NetworkArchitecture.neuron_type'"
    abc_model_config.architecture.ClearField('neuron_type')
    with pytest.raises(errors.UserError) as e_info:
        builders.AssertIsBuildable(abc_model_config)
    assert expected == str(e_info.value)
示例#10
0
def test_AssertIsBuildable_adam_optimizer_beta_2_micros(abc_model_config):
    """UserError when beta_2_micros is unset or outside [0, 1000000]."""
    adam = abc_model_config.training.adam_optimizer
    expected = "AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000"

    def expect_user_error():
        with pytest.raises(errors.UserError) as e_info:
            builders.AssertIsBuildable(abc_model_config)
        assert expected == str(e_info.value)

    adam.ClearField('beta_2_micros')
    expect_user_error()
    # Values immediately outside both ends of the permitted range.
    for out_of_range in (-1, 1000001):
        adam.beta_2_micros = out_of_range
        expect_user_error()
示例#11
0
def test_AssertIsBuildable_training_shuffle_corpus_contentfiles_between_epochs(
        abc_model_config):
    """UserError if field not set; setting it again makes config buildable."""
    abc_model_config.training.ClearField(
        'shuffle_corpus_contentfiles_between_epochs')
    with pytest.raises(errors.UserError) as e_info:
        builders.AssertIsBuildable(abc_model_config)
    assert ("Field not set: 'TrainingOptions."
            "shuffle_corpus_contentfiles_between_epochs'") == str(e_info.value)
    # The field is presumably a bool (confirm against the proto), so there is
    # no out-of-range value to reject. Instead, check that once the field is
    # set again the config passes validation and is returned unchanged.
    abc_model_config.training.shuffle_corpus_contentfiles_between_epochs = True
    assert abc_model_config == builders.AssertIsBuildable(abc_model_config)
示例#12
0
def test_AssertIsBuildable_no_training(abc_model_config):
    """Removing Model.training yields a UserError naming the field."""
    expected = "Field not set: 'Model.training'"
    abc_model_config.ClearField('training')
    with pytest.raises(errors.UserError) as raised:
        builders.AssertIsBuildable(abc_model_config)
    assert expected == str(raised.value)
示例#13
0
def test_AssertIsBuildable_no_architecture(abc_model_config):
    """Removing Model.architecture yields a UserError naming the field."""
    expected = "Field not set: 'Model.architecture'"
    abc_model_config.ClearField('architecture')
    with pytest.raises(errors.UserError) as raised:
        builders.AssertIsBuildable(abc_model_config)
    assert expected == str(raised.value)
示例#14
0
def test_AssertIsBuildable_returns_config(abc_model_config):
    """AssertIsBuildable returns a config equal to the one it was given."""
    returned = builders.AssertIsBuildable(abc_model_config)
    assert abc_model_config == returned
示例#15
0
def test_AssertIsBuildable_no_corpus(abc_model_config):
  """Removing Model.corpus yields a UserError naming the field."""
  expected = "Field not set: 'Model.corpus'"
  abc_model_config.ClearField("corpus")
  with test.Raises(errors.UserError) as raised:
    builders.AssertIsBuildable(abc_model_config)
  assert expected == str(raised.value)
示例#16
0
    def __init__(self, config: model_pb2.Model):
        """Instantiate a model.

        The method performs these steps in order:

        1. Validates the config.
        2. Obtains the 'model' cache for the hash of the corpus and config.
        3. Creates the cache subdirectories and the symlinks to the encoded
           corpus and the atomizer.
        4. Checks the cached META.pbtxt metadata against the config, or
           writes it if absent.
        5. Instantiates the configured backend.

        Args:
          config: A Model message.

        Raises:
          TypeError: If the config argument is not a Model proto.
          UserError: In case of an invalid config, including a
            training.sequence_length below 1.
          InternalError: If a cached META.pbtxt exists whose config differs
            from this one (ignoring corpus.contentfiles and
            training.num_epochs).
        """
        # Error early, so that a cache isn't created.
        if not isinstance(config, model_pb2.Model):
            t = type(config).__name__
            raise TypeError(f"Config must be a Model proto. Received: '{t}'")
        # Validate config options.
        # sequence_length is checked here, before the builders check below.
        if config.training.sequence_length < 1:
            raise errors.UserError(
                'TrainingOptions.sequence_length must be >= 1')

        # Store a private copy of the validated config, so that later changes
        # to the caller's message do not affect this model.
        self.config = model_pb2.Model()
        self.config.CopyFrom(builders.AssertIsBuildable(config))
        self.corpus = corpuses.Corpus(config.corpus)
        # The cache location is derived from both the corpus and the config.
        self.hash = self._ComputeHash(self.corpus, self.config)
        self.cache = cache.mkcache('model', self.hash)
        # Create the necessary cache directories.
        (self.cache.path / 'checkpoints').mkdir(exist_ok=True)
        (self.cache.path / 'samples').mkdir(exist_ok=True)
        (self.cache.path / 'logs').mkdir(exist_ok=True)

        # Create symlink to encoded corpus.
        # The link target is the directory containing the encoded corpus
        # database. The path is taken from its URL with the 'sqlite:///'
        # prefix stripped, and is made relative to the cache directory.
        # Creation is skipped if a symlink already exists at this path.
        symlink = self.cache.path / 'corpus'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(
                    pathlib.Path(
                        self.corpus.encoded.url[len('sqlite:///'):]).parent,
                    self.cache.path), symlink)

        # Create symlink to the atomizer (relative to the cache directory).
        symlink = self.cache.path / 'atomizer'
        if not symlink.is_symlink():
            os.symlink(
                os.path.relpath(self.corpus.atomizer_path, self.cache.path),
                symlink)

        # Validate metadata against cache.
        if self.cache.get('META.pbtxt'):
            cached_meta = pbutil.FromFile(
                pathlib.Path(self.cache['META.pbtxt']),
                internal_pb2.ModelMeta())
            # Exclude num_epochs and corpus location from metadata comparison.
            config_to_compare = model_pb2.Model()
            config_to_compare.CopyFrom(self.config)
            config_to_compare.corpus.ClearField('contentfiles')
            config_to_compare.training.ClearField('num_epochs')
            # These fields should have already been cleared, but we'll do it again
            # so that metadata comparisons don't fail when the cached meta schema
            # is updated.
            cached_to_compare = model_pb2.Model()
            cached_to_compare.CopyFrom(cached_meta.config)
            cached_to_compare.corpus.ClearField('contentfiles')
            cached_to_compare.training.ClearField('num_epochs')
            if config_to_compare != cached_to_compare:
                raise errors.InternalError('Metadata mismatch')
            # The cached metadata matches the config, so reuse it.
            self.meta = cached_meta
        else:
            # No cached metadata: record the current config and write it out.
            self.meta = internal_pb2.ModelMeta()
            self.meta.config.CopyFrom(self.config)
            self._WriteMetafile()

        # Dispatch on architecture.backend. A backend value missing from this
        # mapping would raise KeyError. Presumably AssertIsBuildable rejects
        # such configs earlier — TODO confirm.
        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW:
            tensorflow_backend.TensorFlowBackend,
            model_pb2.NetworkArchitecture.KERAS: keras_backend.KerasBackend,
        }[config.architecture.backend](self.config, self.cache, self.corpus)
示例#17
0
    def __init__(self, config: model_pb2.Model):
        """Instantiate a model.

        The method performs these steps in order:

        1. Validates the config and applies command-line overrides for the
           training step and epoch counts.
        2. Sets up the distributed lock directory.
        3. Builds the corpuses and obtains the model cache for the computed
           hash.
        4. On rank 0 only, populates the cache with symlinks, the tokenizer,
           validated metadata and the current commit.
        5. Instantiates the configured backend.

        Args:
          config: A Model message.

        Raises:
          TypeError: If the config argument is not a Model proto.
          UserError: In case of an invalid config.
          SystemError: If a cached META.pbtxt exists whose config differs from
            this one once the fields ignored for comparison are cleared.
        """
        # Error early, so that a cache isn't created.
        if not isinstance(config, model_pb2.Model):
            t = type(config).__name__
            raise TypeError(f"Config must be a Model proto. Received: '{t}'")

        self.config = model_pb2.Model()
        # Validate config options.
        self.config.CopyFrom(builders.AssertIsBuildable(config))
        # Command-line flags, when non-zero, override the configured values.
        if FLAGS.num_train_steps:
            self.config.training.num_train_steps = FLAGS.num_train_steps
        if FLAGS.num_pretrain_steps:
            self.config.training.num_pretrain_steps = FLAGS.num_pretrain_steps
        if FLAGS.num_epochs:
            self.config.training.num_epochs = FLAGS.num_epochs

        # Initialize distrib lock path. Rank 0 creates the "locks" cache; the
        # other ranks poll until it exists before opening it.
        if environment.WORLD_SIZE > 1:
            if environment.WORLD_RANK == 0:
                lock_cache = cache.mkcache("locks")
                lock_cache.path.mkdir(exist_ok=True)
            else:
                while not cache.cachepath("locks").exists():
                    time.sleep(0.5)
                lock_cache = cache.mkcache("locks")
            distrib.init(lock_cache.path)

        # Initialize corpuses
        self.corpus = corpuses.Corpus(config.corpus)
        self.pre_train_corpus = None
        if config.HasField("pre_train_corpus"):
            self.pre_train_corpus = corpuses.Corpus(config.pre_train_corpus)

        self.hash = self._ComputeHash(self.pre_train_corpus, self.corpus,
                                      self.config)
        self._created = False

        # Release the lock even if cache creation raises, so it is never left
        # held by this process.
        distrib.lock()
        try:
            self.cache = cache.mkcache("model", self.hash)
        finally:
            distrib.unlock()

        def _strip_volatile_fields(src: model_pb2.Model) -> model_pb2.Model:
            """Return a copy of src without the fields ignored for META checks.

            The cleared fields are the corpus locations, the step/epoch counts,
            the batch size and the data-generator epoch/validation settings.
            """
            stripped = model_pb2.Model()
            stripped.CopyFrom(src)
            stripped.corpus.ClearField("contentfiles")
            if stripped.HasField("pre_train_corpus"):
                stripped.pre_train_corpus.ClearField("contentfiles")
                stripped.training.ClearField("num_pretrain_steps")
            stripped.training.ClearField("num_epochs")
            stripped.training.ClearField("num_train_steps")
            stripped.training.ClearField("batch_size")
            if stripped.training.HasField("data_generator"):
                stripped.training.data_generator.ClearField("steps_per_epoch")
                stripped.training.data_generator.ClearField("validation_set")
            return stripped

        if environment.WORLD_RANK == 0:
            # Create the necessary cache directories.
            (self.cache.path / "checkpoints").mkdir(exist_ok=True)
            (self.cache.path / "samples").mkdir(exist_ok=True)
            # Create symlink to encoded corpus. The link target is the directory
            # holding the encoded sqlite database, taken from its URL with the
            # "sqlite:///" prefix stripped.
            symlink = self.cache.path / "corpus"
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(
                        pathlib.Path(self.corpus.encoded.url[len("sqlite:///"
                                                                 ):]).parent,
                        self.cache.path,
                    ),
                    symlink,
                )
            if self.pre_train_corpus:
                symlink = self.cache.path / "pre_train_corpus"
                if not symlink.is_symlink():
                    os.symlink(
                        os.path.relpath(
                            pathlib.Path(self.pre_train_corpus.encoded.
                                         url[len("sqlite:///"):]).parent,
                            self.cache.path,
                        ),
                        symlink,
                    )

            # Create symlink to the tokenizer. If a backup tokenizer exists
            # inside checkpoints, restore it over the corpus tokenizer path.
            symlink = self.cache.path / "tokenizer"
            if not symlink.is_symlink():
                os.symlink(
                    os.path.relpath(self.corpus.tokenizer_path,
                                    self.cache.path), symlink)
            backup_tokenizer = (self.cache.path / "checkpoints" /
                                "backup_tokenizer.pkl")
            if backup_tokenizer.exists():
                shutil.copyfile(backup_tokenizer, self.corpus.tokenizer_path)

            # Validate metadata against cache.
            if self.cache.get("META.pbtxt"):
                cached_meta = pbutil.FromFile(
                    pathlib.Path(self.cache["META.pbtxt"]),
                    internal_pb2.ModelMeta())
                # Both sides are stripped, so that comparisons do not fail
                # when the cached meta schema is updated.
                config_to_compare = _strip_volatile_fields(self.config)
                cached_to_compare = _strip_volatile_fields(cached_meta.config)
                if (cached_to_compare.training.sequence_length !=
                        config_to_compare.training.sequence_length):
                    l.logger().warning(
                        "Mismatch between pre-trained and current config "
                        "sequence_length! This can only be intended in BERT "
                        "model!")
                # A sequence_length mismatch only warns, so exclude it from the
                # strict comparison below.
                cached_to_compare.training.ClearField("sequence_length")
                config_to_compare.training.ClearField("sequence_length")
                if config_to_compare != cached_to_compare:
                    raise SystemError("Metadata mismatch: {} \n\n {}".format(
                        config_to_compare, cached_to_compare))
                self.meta = cached_meta
            else:
                self.meta = internal_pb2.ModelMeta()
                self.meta.config.CopyFrom(self.config)
                self._WriteMetafile()

            ## Store current commit
            commit.saveCommit(self.cache.path)

        # Dispatch on architecture.backend. An unmapped value raises KeyError;
        # presumably AssertIsBuildable rejects such configs — TODO confirm.
        self.backend = {
            model_pb2.NetworkArchitecture.TENSORFLOW_SEQ:
            tf_sequential.tfSequential,
            model_pb2.NetworkArchitecture.KERAS_SEQ:
            keras_sequential.kerasSequential,
            model_pb2.NetworkArchitecture.TENSORFLOW_BERT: tf_bert.tfBert,
            model_pb2.NetworkArchitecture.TORCH_BERT: torch_bert.torchBert,
        }[config.architecture.backend](self.config, self.cache, self.hash)
        l.logger().info("Initialized {} in {}".format(self.backend,
                                                      self.cache.path))