def test_AssertIsBuildable_architecture_post_layer_dropout_micros(
  abc_model_config,
):
  """UserError raised for invalid architecture.post_layer_dropout_micros."""
  abc_model_config.architecture.ClearField("post_layer_dropout_micros")
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert (
    "NetworkArchitecture.post_layer_dropout_micros must be "
    ">= 0 and <= 1000000"
  ) == str(e_info.value)
  abc_model_config.architecture.post_layer_dropout_micros = -1
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert (
    "NetworkArchitecture.post_layer_dropout_micros must be "
    ">= 0 and <= 1000000"
  ) == str(e_info.value)
  abc_model_config.architecture.post_layer_dropout_micros = 1000001
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert (
    "NetworkArchitecture.post_layer_dropout_micros must be "
    ">= 0 and <= 1000000"
  ) == str(e_info.value)
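# The three cases above (unset, negative, over 1000000) all exercise a single
# range constraint inside builders.AssertIsBuildable. A minimal sketch of what
# such a check could look like; the helper name _AssertMicrosInRange is a
# hypothetical, and only the error message is taken from the test assertions:
def _AssertMicrosInRange(msg, field_name: str, proto_name: str) -> int:
  """Require that a *_micros field is set and encodes a value in [0, 1]."""
  value = getattr(msg, field_name)
  # *_micros fields store a fixed-point fraction: 1000000 micros == 1.0.
  # An unset field fails the HasField() presence check, so the proto default
  # of 0 is not silently accepted.
  if not msg.HasField(field_name) or not 0 <= value <= 1000000:
    raise errors.UserError(
        f"{proto_name}.{field_name} must be >= 0 and <= 1000000")
  return value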
def test_AssertIsBuildable_training_num_epochs(abc_model_config):
  """UserError is raised if training.num_epochs field invalid."""
  abc_model_config.training.ClearField('num_epochs')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "TrainingOptions.num_epochs must be > 0" == str(e_info.value)
  abc_model_config.training.num_epochs = -1
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "TrainingOptions.num_epochs must be > 0" == str(e_info.value)
def test_AssertIsBuildable_architecture_num_layers(abc_model_config):
  """UserError is raised if architecture.num_layers field invalid."""
  abc_model_config.architecture.ClearField('num_layers')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "NetworkArchitecture.num_layers must be > 0" == str(e_info.value)
  abc_model_config.architecture.num_layers = -1
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "NetworkArchitecture.num_layers must be > 0" == str(e_info.value)
def test_AssertIsBuildable_architecture_embedding_size(abc_model_config):
  """UserError is raised if architecture.embedding_size field invalid."""
  # embedding_size is ignored unless backend == KERAS.
  abc_model_config.architecture.backend = model_pb2.NetworkArchitecture.KERAS
  abc_model_config.architecture.ClearField("embedding_size")
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "NetworkArchitecture.embedding_size must be > 0" == str(e_info.value)
  abc_model_config.architecture.embedding_size = -1
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "NetworkArchitecture.embedding_size must be > 0" == str(e_info.value)
def test_AssertIsBuildable_adam_optimizer_normalized_gradient_clip_micros(
    abc_model_config):
  """UserError if normalized_gradient_clip_micros field is invalid."""
  abc_model_config.training.adam_optimizer.ClearField(
      'normalized_gradient_clip_micros')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "AdamOptimizer.normalized_gradient_clip_micros must be >= 0" == str(
      e_info.value)
  abc_model_config.training.adam_optimizer.normalized_gradient_clip_micros = -1
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "AdamOptimizer.normalized_gradient_clip_micros must be >= 0" == str(
      e_info.value)
def test_AssertIsBuildable_adam_optimizer_initial_learning_rate_micros(
    abc_model_config):
  """UserError is raised if initial_learning_rate_micros field is invalid."""
  abc_model_config.training.adam_optimizer.ClearField(
      'initial_learning_rate_micros')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "AdamOptimizer.initial_learning_rate_micros must be >= 0" == str(
      e_info.value)
  abc_model_config.training.adam_optimizer.initial_learning_rate_micros = -1
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "AdamOptimizer.initial_learning_rate_micros must be >= 0" == str(
      e_info.value)
def test_AssertIsBuildable_architecture_neurons_per_layer(abc_model_config):
  """UserError is raised if architecture.neurons_per_layer field invalid."""
  abc_model_config.architecture.ClearField("neurons_per_layer")
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "NetworkArchitecture.neurons_per_layer must be > 0" == str(
    e_info.value
  )
  abc_model_config.architecture.neurons_per_layer = -1
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "NetworkArchitecture.neurons_per_layer must be > 0" == str(
    e_info.value
  )
def test_AssertIsBuildable_adam_optimizer_learning_rate_decay_per_epoch_micros(
    abc_model_config):
  """UserError if learning_rate_decay_per_epoch_micros field is invalid."""
  abc_model_config.training.adam_optimizer.ClearField(
      'learning_rate_decay_per_epoch_micros')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert ("AdamOptimizer.learning_rate_decay_per_epoch_micros "
          "must be >= 0") == str(e_info.value)
  abc_model_config.training.adam_optimizer.learning_rate_decay_per_epoch_micros = -1
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert ("AdamOptimizer.learning_rate_decay_per_epoch_micros "
          "must be >= 0") == str(e_info.value)
def test_AssertIsBuildable_architecture_neuron_type(abc_model_config):
  """UserError is raised if architecture.neuron_type field not set."""
  abc_model_config.architecture.ClearField('neuron_type')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "Field not set: 'NetworkArchitecture.neuron_type'" == str(
      e_info.value)
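# The "Field not set" tests in this file all expect the same message shape.
# A sketch of the presence check they exercise, modeled on pbutil-style
# helpers; the helper name is an assumption, not taken from the source:
def _AssertFieldIsSet(msg, field_name: str, proto_name: str):
  """Require that a proto field has been explicitly set."""
  if not msg.HasField(field_name):
    raise errors.UserError(f"Field not set: '{proto_name}.{field_name}'")
  return getattr(msg, field_name)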
def test_AssertIsBuildable_adam_optimizer_beta_2_micros(abc_model_config):
  """UserError if beta_2_micros field is invalid."""
  abc_model_config.training.adam_optimizer.ClearField('beta_2_micros')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000" == str(
      e_info.value)
  abc_model_config.training.adam_optimizer.beta_2_micros = -1
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000" == str(
      e_info.value)
  abc_model_config.training.adam_optimizer.beta_2_micros = 1000001
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "AdamOptimizer.beta_2_micros must be >= 0 and <= 1000000" == str(
      e_info.value)
def test_AssertIsBuildable_training_shuffle_corpus_contentfiles_between_epochs(
    abc_model_config):
  """UserError if field not set."""
  abc_model_config.training.ClearField(
      'shuffle_corpus_contentfiles_between_epochs')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert ("Field not set: 'TrainingOptions."
          "shuffle_corpus_contentfiles_between_epochs'") == str(e_info.value)
def test_AssertIsBuildable_no_training(abc_model_config):
  """Test that UserError is raised if training field not set."""
  abc_model_config.ClearField('training')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "Field not set: 'Model.training'" == str(e_info.value)
def test_AssertIsBuildable_no_architecture(abc_model_config):
  """Test that UserError is raised if architecture field not set."""
  abc_model_config.ClearField('architecture')
  with pytest.raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "Field not set: 'Model.architecture'" == str(e_info.value)
def test_AssertIsBuildable_returns_config(abc_model_config):
  """Test that the original config is returned."""
  assert abc_model_config == builders.AssertIsBuildable(abc_model_config)
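# Because AssertIsBuildable() returns its argument unchanged (as the test
# above verifies), callers can validate and copy a config in one expression.
# A usage sketch; `user_config` is an illustrative name:
validated = model_pb2.Model()
validated.CopyFrom(builders.AssertIsBuildable(user_config))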
def test_AssertIsBuildable_no_corpus(abc_model_config):
  """Test that UserError is raised if corpus field not set."""
  abc_model_config.ClearField("corpus")
  with test.Raises(errors.UserError) as e_info:
    builders.AssertIsBuildable(abc_model_config)
  assert "Field not set: 'Model.corpus'" == str(e_info.value)
def __init__(self, config: model_pb2.Model):
  """Instantiate a model.

  Args:
    config: A Model message.

  Raises:
    TypeError: If the config argument is not a Model proto.
    UserError: In case of an invalid config.
  """
  # Error early, so that a cache isn't created.
  if not isinstance(config, model_pb2.Model):
    t = type(config).__name__
    raise TypeError(f"Config must be a Model proto. Received: '{t}'")
  # Validate config options.
  if config.training.sequence_length < 1:
    raise errors.UserError('TrainingOptions.sequence_length must be >= 1')
  self.config = model_pb2.Model()
  self.config.CopyFrom(builders.AssertIsBuildable(config))
  self.corpus = corpuses.Corpus(config.corpus)
  self.hash = self._ComputeHash(self.corpus, self.config)
  self.cache = cache.mkcache('model', self.hash)
  # Create the necessary cache directories.
  (self.cache.path / 'checkpoints').mkdir(exist_ok=True)
  (self.cache.path / 'samples').mkdir(exist_ok=True)
  (self.cache.path / 'logs').mkdir(exist_ok=True)
  # Create symlink to encoded corpus.
  symlink = self.cache.path / 'corpus'
  if not symlink.is_symlink():
    os.symlink(
        os.path.relpath(
            pathlib.Path(
                self.corpus.encoded.url[len('sqlite:///'):]).parent,
            self.cache.path), symlink)
  # Create symlink to the atomizer.
  symlink = self.cache.path / 'atomizer'
  if not symlink.is_symlink():
    os.symlink(
        os.path.relpath(self.corpus.atomizer_path, self.cache.path), symlink)
  # Validate metadata against cache.
  if self.cache.get('META.pbtxt'):
    cached_meta = pbutil.FromFile(
        pathlib.Path(self.cache['META.pbtxt']), internal_pb2.ModelMeta())
    # Exclude num_epochs and corpus location from metadata comparison.
    config_to_compare = model_pb2.Model()
    config_to_compare.CopyFrom(self.config)
    config_to_compare.corpus.ClearField('contentfiles')
    config_to_compare.training.ClearField('num_epochs')
    # These fields should have already been cleared, but we'll do it again
    # so that metadata comparisons don't fail when the cached meta schema
    # is updated.
    cached_to_compare = model_pb2.Model()
    cached_to_compare.CopyFrom(cached_meta.config)
    cached_to_compare.corpus.ClearField('contentfiles')
    cached_to_compare.training.ClearField('num_epochs')
    if config_to_compare != cached_to_compare:
      raise errors.InternalError('Metadata mismatch')
    self.meta = cached_meta
  else:
    self.meta = internal_pb2.ModelMeta()
    self.meta.config.CopyFrom(self.config)
    self._WriteMetafile()
  self.backend = {
      model_pb2.NetworkArchitecture.TENSORFLOW:
          tensorflow_backend.TensorFlowBackend,
      model_pb2.NetworkArchitecture.KERAS:
          keras_backend.KerasBackend,
  }[config.architecture.backend](self.config, self.cache, self.corpus)
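# _ComputeHash() is used above but not shown here. A sketch of how a
# deterministic model hash could be derived from the corpus and config; this
# is an assumption about the implementation, not the actual method body
# (requires `import hashlib`):
@staticmethod
def _ComputeHash(corpus, config: model_pb2.Model) -> str:
  """Sketch: hash the corpus identity plus the config, minus the corpus field."""
  config_to_hash = model_pb2.Model()
  config_to_hash.CopyFrom(config)
  # The corpus is identified by its own hash, so exclude it from the proto.
  config_to_hash.ClearField("corpus")
  return hashlib.sha1(
      corpus.hash.encode("utf-8") +
      config_to_hash.SerializeToString()).hexdigest()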
def __init__(self, config: model_pb2.Model):
  """Instantiate a model.

  Args:
    config: A Model message.

  Raises:
    TypeError: If the config argument is not a Model proto.
    UserError: In case of an invalid config.
  """
  # Error early, so that a cache isn't created.
  if not isinstance(config, model_pb2.Model):
    t = type(config).__name__
    raise TypeError(f"Config must be a Model proto. Received: '{t}'")

  self.config = model_pb2.Model()
  # Validate config options.
  self.config.CopyFrom(builders.AssertIsBuildable(config))
  if FLAGS.num_train_steps:
    self.config.training.num_train_steps = FLAGS.num_train_steps
  if FLAGS.num_pretrain_steps:
    self.config.training.num_pretrain_steps = FLAGS.num_pretrain_steps
  if FLAGS.num_epochs:
    self.config.training.num_epochs = FLAGS.num_epochs

  # Initialize distrib lock path.
  if environment.WORLD_SIZE > 1:
    if environment.WORLD_RANK == 0:
      lock_cache = cache.mkcache("locks")
      lock_cache.path.mkdir(exist_ok=True)
    else:
      while not cache.cachepath("locks").exists():
        time.sleep(0.5)
      lock_cache = cache.mkcache("locks")
    distrib.init(lock_cache.path)

  # Initialize corpuses.
  self.corpus = corpuses.Corpus(config.corpus)
  self.pre_train_corpus = None
  if config.HasField("pre_train_corpus"):
    self.pre_train_corpus = corpuses.Corpus(config.pre_train_corpus)

  self.hash = self._ComputeHash(self.pre_train_corpus, self.corpus,
                                self.config)
  self._created = False

  distrib.lock()
  self.cache = cache.mkcache("model", self.hash)
  distrib.unlock()

  if environment.WORLD_RANK == 0:
    # Create the necessary cache directories.
    (self.cache.path / "checkpoints").mkdir(exist_ok=True)
    (self.cache.path / "samples").mkdir(exist_ok=True)
    # Create symlink to encoded corpus.
    symlink = self.cache.path / "corpus"
    if not symlink.is_symlink():
      os.symlink(
          os.path.relpath(
              pathlib.Path(
                  self.corpus.encoded.url[len("sqlite:///"):]).parent,
              self.cache.path,
          ),
          symlink,
      )
    if self.pre_train_corpus:
      symlink = self.cache.path / "pre_train_corpus"
      if not symlink.is_symlink():
        os.symlink(
            os.path.relpath(
                pathlib.Path(
                    self.pre_train_corpus.encoded.url[len("sqlite:///"):]
                ).parent,
                self.cache.path,
            ),
            symlink,
        )
    # Create symlink to the tokenizer and create a backup inside checkpoints.
    symlink = self.cache.path / "tokenizer"
    if not symlink.is_symlink():
      os.symlink(
          os.path.relpath(self.corpus.tokenizer_path, self.cache.path),
          symlink)
    if (self.cache.path / "checkpoints" / "backup_tokenizer.pkl").exists():
      shutil.copyfile(
          self.cache.path / "checkpoints" / "backup_tokenizer.pkl",
          self.corpus.tokenizer_path)

  # Validate metadata against cache.
  if self.cache.get("META.pbtxt"):
    cached_meta = pbutil.FromFile(
        pathlib.Path(self.cache["META.pbtxt"]), internal_pb2.ModelMeta())
    # Exclude num_epochs and corpus location from metadata comparison.
    config_to_compare = model_pb2.Model()
    config_to_compare.CopyFrom(self.config)
    config_to_compare.corpus.ClearField("contentfiles")
    if config_to_compare.HasField("pre_train_corpus"):
      config_to_compare.pre_train_corpus.ClearField("contentfiles")
    config_to_compare.training.ClearField("num_epochs")
    config_to_compare.training.ClearField("num_train_steps")
    if config_to_compare.HasField("pre_train_corpus"):
      config_to_compare.training.ClearField("num_pretrain_steps")
    config_to_compare.training.ClearField("batch_size")
    if config_to_compare.training.HasField("data_generator"):
      config_to_compare.training.data_generator.ClearField("steps_per_epoch")
      config_to_compare.training.data_generator.ClearField("validation_set")
    # These fields should have already been cleared, but we'll do it again
    # so that metadata comparisons don't fail when the cached meta schema
    # is updated.
    cached_to_compare = model_pb2.Model()
    cached_to_compare.CopyFrom(cached_meta.config)
    cached_to_compare.corpus.ClearField("contentfiles")
    if cached_to_compare.HasField("pre_train_corpus"):
      cached_to_compare.pre_train_corpus.ClearField("contentfiles")
    cached_to_compare.training.ClearField("num_epochs")
    cached_to_compare.training.ClearField("num_train_steps")
    if cached_to_compare.HasField("pre_train_corpus"):
      cached_to_compare.training.ClearField("num_pretrain_steps")
    cached_to_compare.training.ClearField("batch_size")
    if cached_to_compare.training.HasField("data_generator"):
      cached_to_compare.training.data_generator.ClearField("steps_per_epoch")
      cached_to_compare.training.data_generator.ClearField("validation_set")
    if (cached_to_compare.training.sequence_length !=
        config_to_compare.training.sequence_length):
      l.logger().warning(
          "Mismatch between pre-trained and current config sequence_length! "
          "This can only be intended in BERT model!")
      cached_to_compare.training.ClearField("sequence_length")
      config_to_compare.training.ClearField("sequence_length")
    if config_to_compare != cached_to_compare:
      raise SystemError("Metadata mismatch: {} \n\n {}".format(
          config_to_compare, cached_to_compare))
    self.meta = cached_meta
  else:
    self.meta = internal_pb2.ModelMeta()
    self.meta.config.CopyFrom(self.config)
    self._WriteMetafile()

  # Store current commit.
  commit.saveCommit(self.cache.path)
  self.backend = {
      model_pb2.NetworkArchitecture.TENSORFLOW_SEQ: tf_sequential.tfSequential,
      model_pb2.NetworkArchitecture.KERAS_SEQ: keras_sequential.kerasSequential,
      model_pb2.NetworkArchitecture.TENSORFLOW_BERT: tf_bert.tfBert,
      model_pb2.NetworkArchitecture.TORCH_BERT: torch_bert.torchBert,
  }[config.architecture.backend](self.config, self.cache, self.hash)
  l.logger().info("Initialized {} in {}".format(self.backend,
                                                self.cache.path))
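# Usage sketch for the constructor above. The enclosing class name `Model`
# and the config filename are illustrative assumptions, not taken from the
# source:
config = pbutil.FromFile(pathlib.Path("model.pbtxt"), model_pb2.Model())
model = Model(config)  # Raises TypeError / UserError on a bad config.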