def _get_squad_model():
  """Builds the SQuAD model and attaches a configured optimizer to it.

  Relies on `bert_config`, `max_seq_length`, `steps_per_epoch`, `epochs` and
  `warmup_steps` being defined outside this function (enclosing scope).

  Returns:
    A `(squad_model, core_model)` pair from `bert_models.squad_model`, where
    `squad_model.optimizer` has been set.
  """
  model, core = bert_models.squad_model(
      bert_config,
      max_seq_length,
      hub_module_url=FLAGS.hub_module_url,
      hub_module_trainable=FLAGS.hub_module_trainable)

  # The learning-rate schedule spans the whole training run.
  total_train_steps = steps_per_epoch * epochs
  base_optimizer = optimization.create_optimizer(
      FLAGS.learning_rate, total_train_steps, warmup_steps, FLAGS.end_lr,
      FLAGS.optimizer_type)

  # Let the performance helper adapt the optimizer to the float16 setting.
  use_fp16 = common_flags.use_float16()
  model.optimizer = performance.configure_optimizer(
      base_optimizer, use_float16=use_fp16)
  return model, core
def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
                               input_meta_data):
  """Gets a squad model to make predictions.

  Builds the SQuAD model inside `strategy.scope()` using float32 and restores
  its weights from a checkpoint.

  Args:
    strategy: distribution strategy whose `scope()` the model is built in.
    bert_config: BERT configuration passed to `bert_models.squad_model`.
    checkpoint_path: checkpoint to restore weights from. If None, the latest
      checkpoint in `FLAGS.model_dir` is used.
    input_meta_data: dictionary of input metadata; must contain the key
      'max_seq_length'.

  Returns:
    The SQuAD Keras model. If no checkpoint could be located, a warning is
    logged and the model keeps its freshly initialized weights.
  """
  with strategy.scope():
    # Prediction always uses float32, even if training uses mixed precision.
    # Note: this sets the process-wide Keras policy, not only this model's.
    tf.keras.mixed_precision.set_global_policy('float32')
    squad_model, _ = bert_models.squad_model(
        bert_config,
        input_meta_data['max_seq_length'],
        hub_module_url=FLAGS.hub_module_url)

  if checkpoint_path is None:
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    if checkpoint_path is None:
      # latest_checkpoint returns None when the directory holds no checkpoint;
      # restore(None) below then restores nothing, so make that visible
      # instead of silently predicting with untrained weights.
      logging.warning(
          'No checkpoint found in %s; predictions will use untrained weights.',
          FLAGS.model_dir)
  logging.info('Restoring checkpoints from %s', checkpoint_path)
  checkpoint = tf.train.Checkpoint(model=squad_model)
  # Only the model is needed for inference; ignore unmatched checkpoint values
  # (e.g. optimizer slots).
  checkpoint.restore(checkpoint_path).expect_partial()
  return squad_model
def export_bert_squad_tfhub(bert_config: configs.BertConfig,
                            model_checkpoint_path: Text,
                            hub_destination: Text,
                            vocab_file: Text,
                            do_lower_case: bool = None):
  """Rebuilds the BERT SQuAD model from a checkpoint and saves it for TF-Hub.

  Args:
    bert_config: configuration used to build the SQuAD model.
    model_checkpoint_path: checkpoint whose weights are restored into the model.
    hub_destination: directory the SavedModel is written to.
    vocab_file: vocabulary file attached to the export as a
      `tf.saved_model.Asset`.
    do_lower_case: value stored on the export as the non-trainable
      `do_lower_case` variable. When None, it is inferred as True exactly when
      "uncased" appears in `vocab_file`.
  """
  # No explicit setting: infer it from the vocab file name.
  if do_lower_case is None:
    do_lower_case = "uncased" in vocab_file
    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
                 do_lower_case, vocab_file)

  model, _ = bert_models.squad_model(bert_config, max_seq_length=None)

  checkpoint = tf.train.Checkpoint(model=model)
  restore_status = checkpoint.restore(model_checkpoint_path)
  # Fail if any object already tracked by the model has no checkpoint entry.
  restore_status.assert_existing_objects_matched()

  # Export metadata is attached as model attributes so the SavedModel
  # tracks it alongside the weights.
  model.vocab_file = tf.saved_model.Asset(vocab_file)
  model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
  model.save(hub_destination, include_optimizer=False, save_format="tf")
def test_squad_model(self):
  """Checks the output structure and shapes of `bert_models.squad_model`."""
  model, core_model = bert_models.squad_model(
      self._bert_test_config,
      max_seq_length=5,
      initializer=None,
      hub_module_url=None,
      hub_module_trainable=None)
  for keras_model in (model, core_model):
    self.assertIsInstance(keras_model, tf.keras.Model)

  # The SQuAD model yields two outputs: start positions and end positions.
  span_outputs = model.output
  self.assertIsInstance(span_outputs, list)
  self.assertLen(span_outputs, 2)

  # The core model yields two outputs: sequence and classification output.
  encoder_outputs = core_model.output
  self.assertIsInstance(encoder_outputs, list)
  self.assertLen(encoder_outputs, 2)
  sequence_output, pooled_output = encoder_outputs

  # Expected shape: [batch size, None, hidden_size].
  self.assertEqual(sequence_output.shape.as_list(), [None, None, 16])
  # Expected shape: [batch size, hidden_size].
  self.assertEqual(pooled_output.shape.as_list(), [None, 16])
def export_squad(model_export_path, input_meta_data, bert_config):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model;
      must contain the key 'max_seq_length'.
    bert_config: BERT configuration object (presumably a `configs.BertConfig`)
      passed to `bert_models.squad_model` to define the core BERT layers.

  Raises:
    ValueError: if the export path is not specified (e.g. an empty string or
      None).
  """
  if not model_export_path:
    # Format as a 1-tuple so '%' renders the value itself; a bare (empty)
    # tuple would otherwise be taken as the argument list and raise TypeError.
    raise ValueError('Export path is not specified: %s' % (model_export_path,))

  # Export uses float32 for now, even if training uses mixed precision.
  # Note: this sets the process-wide Keras policy.
  tf.keras.mixed_precision.set_global_policy('float32')
  squad_model, _ = bert_models.squad_model(bert_config,
                                           input_meta_data['max_seq_length'])
  # FLAGS.model_dir is handed to the export helper as the checkpoint directory.
  model_saving_utils.export_bert_model(
      model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)