def export_classifier(model_export_path, input_meta_data,
                      restore_model_using_load_weights):
  """Exports a trained model as a `SavedModel` for inference.

  The classifier architecture is rebuilt from `FLAGS.bert_config_file` and
  `input_meta_data`, then handed to `model_saving_utils.export_bert_model`
  together with `FLAGS.model_dir` as the checkpoint directory.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model;
      must provide the 'num_labels' and 'max_seq_length' keys.
    restore_model_using_load_weights: Whether to use checkpoint.restore() API
      for custom checkpoint or to use model.load_weights() API. The value is
      forwarded unchanged to `model_saving_utils.export_bert_model`.
      There are two different ways to save checkpoints: one uses
      tf.train.Checkpoint and the other uses Keras model.save_weights().
      The custom training loop implementation uses the tf.train.Checkpoint
      API, while the Keras ModelCheckpoint callback internally uses
      model.save_weights(). Since these two APIs cannot be used together, the
      model loading logic must take into account how the model checkpoint was
      saved.

  Raises:
    ValueError: if the export path is not specified (empty string or None).
  """
  # Validate before touching any config/flags so a bad call fails fast.
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  # classifier_model returns a tuple; the first element is the classifier
  # model that gets exported.
  classifier_model = bert_models.classifier_model(
      bert_config, tf.float32, input_meta_data['num_labels'],
      input_meta_data['max_seq_length'])[0]

  model_saving_utils.export_bert_model(
      model_export_path,
      model=classifier_model,
      checkpoint_dir=FLAGS.model_dir,
      restore_model_using_load_weights=restore_model_using_load_weights)
Пример #2
0
 def _get_classifier_model():
     """Builds the classifier/core model pair and attaches an optimizer.

     Returns:
       The ``(classifier_model, core_model)`` pair returned by
       ``bert_models.classifier_model``. The classifier's ``optimizer``
       attribute is set from ``optimization.create_optimizer`` and, when
       ``FLAGS.fp16_implementation == 'graph_rewrite'``, re-wrapped by
       ``tf.train.experimental.enable_mixed_precision_graph_rewrite``.
     """
     model, core = bert_models.classifier_model(
         bert_config, tf.float32, num_classes, max_seq_length)
     total_train_steps = steps_per_epoch * epochs
     model.optimizer = optimization.create_optimizer(
         initial_lr, total_train_steps, warmup_steps)
     if FLAGS.fp16_implementation != 'graph_rewrite':
         return model, core
     # Upstream note: with the graph_rewrite implementation the dtype from
     # flags_core.get_tf_dtype(...) is 'float32', so the Keras mixed-precision
     # policy and the graph-rewrite pass are not both applied.
     model.optimizer = (
         tf.train.experimental.enable_mixed_precision_graph_rewrite(
             model.optimizer))
     return model, core
Пример #3
0
def export_classifier(model_export_path, input_meta_data):
    """Export the trained classifier as an inference `SavedModel`.

    The model is rebuilt from ``FLAGS.bert_config_file`` plus the sizes in
    ``input_meta_data`` and passed to ``model_saving_utils.export_bert_model``
    with ``FLAGS.model_dir`` as the checkpoint directory.

    Args:
      model_export_path: Destination directory for the SavedModel.
      input_meta_data: Mapping that must provide ``'num_labels'`` and
        ``'max_seq_length'``.

    Raises:
      ValueError: If ``model_export_path`` is empty or ``None``.
    """
    if not model_export_path:
        raise ValueError('Export path is not specified: %s' %
                         model_export_path)

    config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    num_labels = input_meta_data['num_labels']
    seq_len = input_meta_data['max_seq_length']

    # Only the first element (the classifier) of the returned tuple is needed.
    built = bert_models.classifier_model(config, tf.float32, num_labels,
                                         seq_len)
    classifier = built[0]

    model_saving_utils.export_bert_model(
        model_export_path,
        model=classifier,
        checkpoint_dir=FLAGS.model_dir,
    )
Пример #4
0
def predict_customized(strategy, input_meta_data, bert_config, eval_data_path,
                       num_steps):
    """Runs distributed prediction from the latest checkpoint and logs accuracy.

    Builds a classifier under ``strategy``, restores the newest checkpoint in
    ``FLAGS.model_dir`` into it, runs ``num_steps`` distributed prediction
    steps over ``eval_data_path`` and logs the accuracy computed over real
    examples only (those whose ``is_real_example`` feature equals 1).

    Args:
      strategy: A ``tf.distribute`` strategy used to build the model and run
        the replicated prediction step.
      input_meta_data: Mapping that must provide ``'max_seq_length'`` and
        ``'num_labels'``.
      bert_config: Configuration passed to ``bert_models.classifier_model``.
      eval_data_path: Input data path for
        ``input_pipeline.create_classifier_dataset``.
      num_steps: Number of prediction steps (distributed batches) to run.

    Returns:
      A list of every item produced by ``get_raw_results`` across all steps.
    """
    max_seq_length = input_meta_data['max_seq_length']
    num_classes = input_meta_data['num_labels']
    predict_dataset = input_pipeline.create_classifier_dataset(
        eval_data_path,
        max_seq_length,
        FLAGS.eval_batch_size,
        is_training=False)
    predict_iterator = iter(
        strategy.experimental_distribute_dataset(predict_dataset))
    with strategy.scope():
        # Force a float32 global Keras policy for prediction.
        tf.keras.mixed_precision.experimental.set_policy('float32')
        classifier_model, _ = (bert_models.classifier_model(
            bert_config, tf.float32, num_classes, max_seq_length))
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    logging.info('Restoring checkpoints from %s', checkpoint_path)
    checkpoint = tf.train.Checkpoint(model=classifier_model)
    # expect_partial() silences warnings about checkpoint values that this
    # model does not restore.
    checkpoint.restore(checkpoint_path).expect_partial()

    @tf.function
    def predict_step(iterator):
        """Predicts on distributed devices."""
        def _replicated_step(inputs_d):
            """Replicated prediction calculation."""
            inputs, label = inputs_d
            logits = classifier_model(inputs, training=False)
            return dict(logits=logits,
                        label_ids=label,
                        mask=inputs["is_real_example"])

        outputs = strategy.experimental_run_v2(_replicated_step,
                                               args=(next(iterator), ))
        # Convert each per-replica value into a tuple of local results.
        return tf.nest.map_structure(strategy.experimental_local_results,
                                     outputs)

    correct = 0
    total = 0
    all_results = []
    for _ in range(num_steps):
        predictions = predict_step(predict_iterator)
        all_results.extend(get_raw_results(predictions))
        if len(all_results) % 100 == 0:
            logging.info('Made predictions for %d records.', len(all_results))
        step_correct, step_total = _count_correct_real_examples(predictions)
        correct += step_correct
        total += step_total
    # Avoid ZeroDivisionError (e.g. num_steps == 0 or only padding examples)
    # so the collected predictions are still returned.
    accuracy = float(correct) / float(total) if total else 0.0
    logging.info("Train step: %d  /  acc = %d/%d = %f", num_steps, correct,
                 total, accuracy)
    return all_results


def _count_correct_real_examples(predictions):
    """Returns ``(num_correct, num_real)`` for one step of replica outputs.

    ``predictions`` maps ``'logits'``, ``'label_ids'`` and ``'mask'`` to
    sequences with one array-like per replica. Only examples whose mask equals
    1 are counted, so padding examples cannot inflate the correct count.
    Replica outputs are concatenated directly, which also handles replicas
    with different batch sizes (e.g. a final partial batch).
    """
    logits = np.vstack(predictions['logits'])
    labels = np.hstack(predictions['label_ids'])
    masks = np.hstack(predictions['mask'])
    is_real = np.equal(masks, 1)
    is_correct = np.equal(np.argmax(logits, axis=-1), labels)
    return int(np.sum(is_correct & is_real)), int(np.sum(is_real))