def _get_squad_model():
    """Build the SQuAD model and attach a configured optimizer to it.

    Reads `bert_config`, `max_seq_length`, `steps_per_epoch`, `epochs` and
    `warmup_steps` as free variables resolved from an enclosing or module
    scope that is not visible here.

    Returns:
      A `(squad_model, core_model)` tuple as produced by
      `bert_models.squad_model`, with `squad_model.optimizer` set.
    """
    model, encoder = bert_models.squad_model(
        bert_config,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)

    # The optimizer's schedule spans the whole training run.
    total_train_steps = steps_per_epoch * epochs
    base_optimizer = optimization.create_optimizer(
        FLAGS.learning_rate, total_train_steps, warmup_steps, FLAGS.end_lr,
        FLAGS.optimizer_type)

    # Let performance.configure_optimizer adapt the optimizer to the float16
    # setting before it is attached to the model.
    model.optimizer = performance.configure_optimizer(
        base_optimizer, use_float16=common_flags.use_float16())
    return model, encoder
def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
                               input_meta_data):
    """Gets a squad model to make predictions.

    Builds the SQuAD model under `strategy.scope()` with the global Keras
    mixed-precision policy set to float32, then restores its weights from
    `checkpoint_path`, or from the latest checkpoint in `FLAGS.model_dir`
    when `checkpoint_path` is None.

    Args:
      strategy: distribution strategy whose scope the model is built in.
      bert_config: BERT configuration passed to `bert_models.squad_model`.
      checkpoint_path: checkpoint to restore, or None to use the latest
        checkpoint found in `FLAGS.model_dir`.
      input_meta_data: dict whose 'max_seq_length' entry sets the model's
        sequence length.

    Returns:
      The SQuAD Keras model with checkpoint weights restored (when a
      checkpoint was found).
    """
    with strategy.scope():
        # Prediction always uses float32, even if training uses mixed precision.
        # Note: this sets the process-wide Keras policy and is not reverted.
        tf.keras.mixed_precision.set_global_policy('float32')
        squad_model, _ = bert_models.squad_model(
            bert_config,
            input_meta_data['max_seq_length'],
            hub_module_url=FLAGS.hub_module_url)

    if checkpoint_path is None:
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint_path is None:
            # latest_checkpoint returns None when the directory has no
            # checkpoint; restoring None loads nothing, so the model would
            # silently predict with its initial weights. Surface that.
            logging.warning(
                'No checkpoint found in %s; the model keeps its initial '
                '(unrestored) weights.', FLAGS.model_dir)
    logging.info('Restoring checkpoints from %s', checkpoint_path)
    checkpoint = tf.train.Checkpoint(model=squad_model)
    # expect_partial() silences warnings about checkpoint values that this
    # prediction-only model does not consume (e.g. optimizer state).
    checkpoint.restore(checkpoint_path).expect_partial()
    return squad_model
def export_bert_squad_tfhub(bert_config: configs.BertConfig,
                            model_checkpoint_path: Text,
                            hub_destination: Text,
                            vocab_file: Text,
                            do_lower_case: bool = None):
  """Rebuilds the BERT SQuAD model from a checkpoint and saves it for TF-Hub.

  Args:
    bert_config: configuration used to build the SQuAD model.
    model_checkpoint_path: checkpoint whose weights are restored into the
      model; every existing model object must be matched by it.
    hub_destination: directory the SavedModel is written to.
    vocab_file: vocabulary file path, attached to the model as a
      `tf.saved_model.Asset`.
    do_lower_case: stored on the model as a non-trainable `do_lower_case`
      variable. When None, it is inferred as whether the substring "uncased"
      occurs in `vocab_file`.
  """
  if do_lower_case is None:
    # The whole path string is searched, so a parent directory whose name
    # contains "uncased" also counts, not just the file's basename.
    do_lower_case = "uncased" in vocab_file
    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
                 do_lower_case, vocab_file)

  squad_model, _ = bert_models.squad_model(bert_config, max_seq_length=None)

  ckpt = tf.train.Checkpoint(model=squad_model)
  restore_status = ckpt.restore(model_checkpoint_path)
  # Fail loudly if the checkpoint does not line up with the model's objects.
  restore_status.assert_existing_objects_matched()

  # Attach vocab and casing info as model attributes before saving,
  # presumably so Keras tracks them into the exported SavedModel.
  squad_model.vocab_file = tf.saved_model.Asset(vocab_file)
  squad_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)

  squad_model.save(hub_destination, include_optimizer=False, save_format="tf")
    def test_squad_model(self):
        """squad_model returns a SQuAD head model plus its core encoder model."""
        squad_model, encoder = bert_models.squad_model(
            self._bert_test_config,
            max_seq_length=5,
            initializer=None,
            hub_module_url=None,
            hub_module_trainable=None)
        for built in (squad_model, encoder):
            self.assertIsInstance(built, tf.keras.Model)

        # The SQuAD model emits two outputs: start positions and end positions.
        head_outputs = squad_model.output
        self.assertIsInstance(head_outputs, list)
        self.assertLen(head_outputs, 2)

        # The core model emits two outputs: sequence and classification output.
        encoder_outputs = encoder.output
        self.assertIsInstance(encoder_outputs, list)
        self.assertLen(encoder_outputs, 2)
        sequence_output, classification_output = encoder_outputs
        # Expected shapes: [batch size, None, hidden_size] and
        # [batch size, hidden_size], with hidden_size 16 in the test config.
        self.assertEqual(sequence_output.shape.as_list(), [None, None, 16])
        self.assertEqual(classification_output.shape.as_list(), [None, 16])
def export_squad(model_export_path, input_meta_data, bert_config):
    """Exports a trained model as a `SavedModel` for inference.

    Args:
      model_export_path: a string specifying the path to the SavedModel
        directory.
      input_meta_data: dictionary containing meta data about input and model;
        its 'max_seq_length' entry sets the exported model's sequence length.
      bert_config: Bert configuration file to define core bert layers.

    Raises:
      ValueError: if `model_export_path` is not specified (falsy, e.g. an
        empty string or None).
    """
    if not model_export_path:
        raise ValueError('Export path is not specified: %s' %
                         model_export_path)
    # Export uses float32 for now, even if training uses mixed precision.
    # Note: this sets the process-wide Keras policy and is not reverted.
    tf.keras.mixed_precision.set_global_policy('float32')
    squad_model, _ = bert_models.squad_model(bert_config,
                                             input_meta_data['max_seq_length'])
    # FLAGS.model_dir is handed to the saving utility as the checkpoint source;
    # presumably it restores trained weights from there before exporting.
    model_saving_utils.export_bert_model(model_export_path,
                                         model=squad_model,
                                         checkpoint_dir=FLAGS.model_dir)