def _get_classifier_model():
  """Builds the classification model together with its core BERT encoder.

  Returns:
    A `(classifier_model, core_model)` pair from `bert_models.classifier_model`,
    with an optimizer attached to the classifier for the training loop.
  """
  model, encoder = bert_models.classifier_model(
      bert_config, tf.float32, num_classes, max_seq_length)
  # Hang the optimizer off the model so the caller's training loop can use it.
  model.optimizer = optimization.create_optimizer(
      initial_lr, steps_per_epoch * epochs, warmup_steps)
  return model, encoder
def _get_classifier_model():
  """Gets a classifier model."""
  model, encoder = bert_models.classifier_model(
      bert_config, tf.float32, num_classes, max_seq_length)
  optimizer = optimization.create_optimizer(
      initial_lr, steps_per_epoch * epochs, warmup_steps)
  if FLAGS.fp16_implementation == 'graph_rewrite':
    # Note: when flags_obj.fp16_implementation == "graph_rewrite", the dtype
    # reported by flags_core.get_tf_dtype(flags_obj) is 'float32', which
    # guarantees tf.compat.v2.keras.mixed_precision and
    # tf.train.experimental.enable_mixed_precision_graph_rewrite are never
    # doubled up.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)
  model.optimizer = optimizer
  return model, encoder
def export_classifier(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel
      directory.
    input_meta_data: dictionary containing meta data about input and model.

  Raises:
    ValueError: if the export path is not specified (empty string or None).
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  # Only the classifier head is served; the core encoder is discarded.
  classifier_model, _ = bert_models.classifier_model(
      bert_config, tf.float32, input_meta_data['num_labels'],
      input_meta_data['max_seq_length'])
  model_saving_utils.export_bert_model(
      model_export_path,
      model=classifier_model,
      checkpoint_dir=FLAGS.model_dir)
def _model_fn():
  """Builds and returns only the classifier head, dropping the core model."""
  classifier, _ = bert_models.classifier_model(
      bert_config, tf.float32, input_meta_data['num_labels'],
      input_meta_data['max_seq_length'])
  return classifier