def export_classifier(model_export_path, input_meta_data,
                      restore_model_using_load_weights):
  """Exports a trained classifier model as a `SavedModel` for inference.

  Args:
    model_export_path: Path of the SavedModel directory to write.
    input_meta_data: Dictionary of metadata about the input and model; its
      'num_labels' and 'max_seq_length' entries are read here.
    restore_model_using_load_weights: Whether to restore weights with the
      Keras `model.load_weights()` API instead of `checkpoint.restore()`.
      Checkpoints can be written in two different ways: a custom training
      loop saves them with `tf.train.Checkpoint`, while the Keras
      `ModelCheckpoint` callback uses `model.save_weights()` internally.
      Because these two APIs cannot be used together, the loading logic must
      take into account how the checkpoint was saved.

  Raises:
    ValueError: If `model_export_path` is not specified (empty or None).
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)

  config_file = FLAGS.bert_config_file
  bert_config = modeling.BertConfig.from_json_file(config_file)
  # classifier_model() returns an indexable result; only its first element
  # (the classifier model) is exported.
  model_outputs = bert_models.classifier_model(
      bert_config,
      tf.float32,
      input_meta_data['num_labels'],
      input_meta_data['max_seq_length'])
  classifier_model = model_outputs[0]
  model_saving_utils.export_bert_model(
      model_export_path,
      model=classifier_model,
      checkpoint_dir=FLAGS.model_dir,
      restore_model_using_load_weights=restore_model_using_load_weights)
def _get_classifier_model():
  """Gets a classifier model with its training optimizer attached.

  Uses names that are not defined in this function (bert_config,
  num_classes, max_seq_length, initial_lr, steps_per_epoch, epochs,
  warmup_steps, FLAGS), presumably supplied by the enclosing scope.

  Returns:
    The `(classifier_model, core_model)` pair unpacked from
    `bert_models.classifier_model`, with `classifier_model.optimizer` set.
  """
  classifier_model, core_model = bert_models.classifier_model(
      bert_config, tf.float32, num_classes, max_seq_length)

  total_train_steps = steps_per_epoch * epochs
  optimizer = optimization.create_optimizer(
      initial_lr, total_train_steps, warmup_steps)

  if FLAGS.fp16_implementation == 'graph_rewrite':
    # Per the original authors: with fp16_implementation == "graph_rewrite",
    # the dtype chosen by flags_core.get_tf_dtype() is 'float32', so
    # tf.compat.v2.keras.mixed_precision and
    # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
    # double up.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)

  classifier_model.optimizer = optimizer
  return classifier_model, core_model
def export_classifier(model_export_path, input_meta_data):
  """Exports a trained classifier model as a `SavedModel` for inference.

  Args:
    model_export_path: Path of the SavedModel directory to write.
    input_meta_data: Dictionary of metadata about the input and model; its
      'num_labels' and 'max_seq_length' entries are read here.

  Raises:
    ValueError: If `model_export_path` is not specified (empty or None).
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)

  config_file = FLAGS.bert_config_file
  bert_config = modeling.BertConfig.from_json_file(config_file)
  # Only the first element of classifier_model()'s result is exported.
  model_outputs = bert_models.classifier_model(
      bert_config,
      tf.float32,
      input_meta_data['num_labels'],
      input_meta_data['max_seq_length'])
  classifier_model = model_outputs[0]
  model_saving_utils.export_bert_model(
      model_export_path,
      model=classifier_model,
      checkpoint_dir=FLAGS.model_dir)
def predict_customized(strategy, input_meta_data, bert_config,
                       eval_data_path, num_steps):
  """Runs distributed prediction and logs accuracy over real examples.

  Builds an evaluation dataset from `eval_data_path`, restores the latest
  checkpoint in `FLAGS.model_dir` into a freshly built classifier, runs
  `num_steps` prediction steps and logs the accuracy computed only over
  examples whose `is_real_example` mask equals 1.

  Args:
    strategy: Distribution strategy used to distribute the dataset, build
      the model in its scope and run the replicated prediction step.
    input_meta_data: Dictionary containing at least the 'max_seq_length' and
      'num_labels' entries.
    bert_config: BERT config passed to `bert_models.classifier_model`.
    eval_data_path: Path of the evaluation data read by
      `input_pipeline.create_classifier_dataset`.
    num_steps: Number of prediction steps to run.

  Returns:
    A list containing every item yielded by `get_raw_results` for each
    step's predictions, in order.
  """
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']
  predict_dataset = input_pipeline.create_classifier_dataset(
      eval_data_path,
      max_seq_length,
      FLAGS.eval_batch_size,
      is_training=False)
  predict_iterator = iter(
      strategy.experimental_distribute_dataset(predict_dataset))

  with strategy.scope():
    tf.keras.mixed_precision.experimental.set_policy('float32')
    classifier_model, _ = bert_models.classifier_model(
        bert_config, tf.float32, num_classes, max_seq_length)

  checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
  logging.info('Restoring checkpoints from %s', checkpoint_path)
  checkpoint = tf.train.Checkpoint(model=classifier_model)
  checkpoint.restore(checkpoint_path).expect_partial()

  @tf.function
  def predict_step(iterator):
    """Predicts on distributed devices."""

    def _replicated_step(inputs_d):
      """Replicated prediction calculation."""
      inputs, label = inputs_d
      logits = classifier_model(inputs, training=False)
      # NOTE(review): assumes the features dict carries an 'is_real_example'
      # entry (1 for real examples) -- confirm against input_pipeline.
      return dict(logits=logits, label_ids=label,
                  mask=inputs['is_real_example'])

    outputs = strategy.experimental_run_v2(
        _replicated_step, args=(next(iterator),))
    return tf.nest.map_structure(strategy.experimental_local_results, outputs)

  correct = 0
  total = 0
  all_results = []
  for _ in range(num_steps):
    predictions = predict_step(predict_iterator)
    for result in get_raw_results(predictions):
      all_results.append(result)
      if len(all_results) % 100 == 0:
        logging.info('Made predictions for %d records.', len(all_results))

    # Merge this step's per-replica outputs. The sequences go straight to
    # vstack/hstack: wrapping them in np.array() first fails when replicas
    # return batches of different sizes.
    step_logits = np.vstack(list(predictions['logits']))
    step_labels = np.hstack(list(predictions['label_ids']))
    step_masks = np.hstack(list(predictions['mask']))

    is_real = np.equal(step_masks, 1)
    is_correct = np.equal(np.argmax(step_logits, axis=-1), step_labels)
    # Restrict the numerator to real examples as well; otherwise examples
    # with mask != 1 (presumably padding) add to `correct` but not `total`.
    correct += int(np.sum(np.logical_and(is_correct, is_real)))
    total += int(np.sum(is_real))

  if total:
    accuracy = float(correct) / float(total)
    logging.info('Predict steps: %d / acc = %d/%d = %f', num_steps, correct,
                 total, accuracy)
  else:
    # Avoid ZeroDivisionError when no real example was evaluated.
    logging.warning('No real examples were predicted; accuracy is undefined.')
  return all_results