예제 #1
0
 def _get_classifier_model():
   """Builds the BERT classifier model and attaches a training optimizer.

   Uses `bert_config`, `num_classes`, `max_seq_length`, `initial_lr`,
   `steps_per_epoch`, `epochs` and `warmup_steps`, none of which are defined
   in this function (they come from an enclosing or module scope).

   Returns:
     The `(classifier_model, core_model)` pair unpacked from
     `bert_models.classifier_model`, with `classifier_model.optimizer` set to
     the result of `optimization.create_optimizer`.
   """
   built_models = bert_models.classifier_model(
       bert_config, tf.float32, num_classes, max_seq_length)
   classifier_model, core_model = built_models
   # The optimizer schedule spans every training step of every epoch.
   total_train_steps = steps_per_epoch * epochs
   classifier_model.optimizer = optimization.create_optimizer(
       initial_lr, total_train_steps, warmup_steps)
   return classifier_model, core_model
예제 #2
0
 def _get_classifier_model():
   """Gets a classifier model, optionally wrapping its optimizer for fp16.

   Returns:
     The `(classifier_model, core_model)` pair unpacked from
     `bert_models.classifier_model`. `classifier_model.optimizer` is the
     optimizer from `optimization.create_optimizer`. When
     `FLAGS.fp16_implementation == 'graph_rewrite'`, it is first wrapped by
     `tf.train.experimental.enable_mixed_precision_graph_rewrite`.
   """
   classifier_model, core_model = bert_models.classifier_model(
       bert_config, tf.float32, num_classes, max_seq_length)
   optimizer = optimization.create_optimizer(
       initial_lr, steps_per_epoch * epochs, warmup_steps)
   if FLAGS.fp16_implementation == 'graph_rewrite':
     # Under the 'graph_rewrite' fp16 implementation, the dtype derived from
     # the flags (flags_core.get_tf_dtype) is 'float32'. That keeps
     # tf.compat.v2.keras.mixed_precision and the graph-rewrite mixed
     # precision from both being applied.
     optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
         optimizer)
   classifier_model.optimizer = optimizer
   return classifier_model, core_model
예제 #3
0
def export_classifier(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model;
      it is read for the 'num_labels' and 'max_seq_length' keys.

  Raises:
    ValueError: if `model_export_path` is falsy (e.g. an empty string or None),
      i.e. the export path is not specified.
  """
  if not model_export_path:
    # Wrap the argument in a 1-tuple so %-formatting never treats a
    # tuple-typed path as the argument list. An empty tuple would otherwise
    # raise TypeError instead of the intended ValueError.
    raise ValueError('Export path is not specified: %s' % (model_export_path,))
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  # Only the first element of the returned pair (the classifier model) is
  # exported.
  classifier_model = bert_models.classifier_model(
      bert_config, tf.float32, input_meta_data['num_labels'],
      input_meta_data['max_seq_length'])[0]
  # checkpoint_dir points at FLAGS.model_dir. Presumably export_bert_model
  # restores the trained weights from there (TODO confirm against its docs).
  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=FLAGS.model_dir)
예제 #4
0
 def _model_fn():
     """Builds and returns the BERT classifier model for `input_meta_data`.

     Only the first element of `bert_models.classifier_model(...)` is
     returned. `bert_config` and `input_meta_data` come from outside this
     function.
     """
     label_count = input_meta_data['num_labels']
     seq_len = input_meta_data['max_seq_length']
     built_models = bert_models.classifier_model(
         bert_config, tf.float32, label_count, seq_len)
     return built_models[0]