def export_squad(model_export_path, input_meta_data):
  """Exports a trained SQuAD model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel
      directory.
    input_meta_data: dictionary containing meta data about input and model;
      its 'max_seq_length' entry is passed to the SQuAD model builder.

  Raises:
    ValueError: if `model_export_path` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)

  config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  seq_len = input_meta_data['max_seq_length']
  # Only the first element returned by squad_model is exported.
  model, _ = bert_models.squad_model(config, seq_len, float_type=tf.float32)
  model_saving_utils.export_bert_model(
      model_export_path, model=model, checkpoint_dir=FLAGS.model_dir)
def run_bert(strategy, input_meta_data):
  """Runs BERT training, or only exports a model, depending on FLAGS.mode.

  In 'export_only' mode this delegates to `export_classifier` and returns
  None. In 'train_and_eval' mode it trains with a customized training loop
  and, if FLAGS.model_export_path is set, exports the trained model.

  NOTE(review): this module appears to define `run_bert` a second time
  further down; the later definition is the one bound at import time.
  Confirm which version is intended.

  Args:
    strategy: distribution strategy used for training; must be truthy in
      'train_and_eval' mode.
    input_meta_data: dictionary containing meta data about input and model.
      This function reads 'train_data_size' and 'eval_data_size'; it is also
      forwarded to the training/export helpers.

  Returns:
    The trained model in 'train_and_eval' mode, None in 'export_only' mode.

  Raises:
    ValueError: if FLAGS.mode is unsupported, or `strategy` is not set.
  """
  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data)
    return
  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  # Fail fast: validate the strategy before reading config files or
  # deriving step counts.
  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  train_data_size = input_meta_data['train_data_size']
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  # Warm up over the first 10% of the total number of training steps.
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  # Round up so a final partial eval batch is still evaluated.
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  # Runs customized training loop.
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')
  use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
  trained_model = run_customized_training(
      strategy,
      bert_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      FLAGS.init_checkpoint,
      use_remote_tpu=use_remote_tpu)

  if FLAGS.model_export_path:
    # Export is placed on the device returned by get_primary_cpu_task.
    with tf.device(
        model_training_utils.get_primary_cpu_task(use_remote_tpu)):
      model_saving_utils.export_bert_model(
          FLAGS.model_export_path, model=trained_model)
  return trained_model
def run_bert(strategy, input_meta_data):
  """Runs BERT training, or only exports a model, depending on FLAGS.mode.

  In 'export_only' mode this delegates to `export_classifier` and returns
  None. In 'train_and_eval' mode it configures XLA, trains with a customized
  training loop and, if FLAGS.model_export_path is set, exports the trained
  model.

  Args:
    strategy: distribution strategy used for training; must be truthy in
      'train_and_eval' mode.
    input_meta_data: dictionary containing meta data about input and model.
      This function reads 'train_data_size' and 'eval_data_size'; it is also
      forwarded to the training/export helpers.

  Returns:
    The trained model in 'train_and_eval' mode, None in 'export_only' mode.

  Raises:
    ValueError: if FLAGS.mode is unsupported, or `strategy` is not set.
  """
  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data)
    return
  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  # Validate before touching global config (set_config_v2 below), so an
  # invalid call leaves no side effects behind.
  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  epochs = FLAGS.num_train_epochs
  train_data_size = input_meta_data['train_data_size']
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  # Warm up over the first 10% of the total number of training steps.
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  # Round up so a final partial eval batch is still evaluated.
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  # Runs customized training loop.
  logging.info('Training using customized training loop TF 2.0 with '
               'distributed strategy.')
  trained_model = run_customized_training(
      strategy,
      bert_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      FLAGS.init_checkpoint,
      run_eagerly=FLAGS.run_eagerly)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model
def export_classifier(model_export_path, input_meta_data):
  """Exports a trained classifier model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel
      directory.
    input_meta_data: dictionary containing meta data about input and model;
      its 'num_labels' and 'max_seq_length' entries are passed to the
      classifier model builder.

  Raises:
    ValueError: if `model_export_path` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  def build_model():
    # Only the first element of classifier_model's result is exported.
    built = bert_models.classifier_model(
        config,
        tf.float32,
        input_meta_data['num_labels'],
        input_meta_data['max_seq_length'])
    return built[0]

  model_saving_utils.export_bert_model(
      model_export_path, model_fn=build_model, checkpoint_dir=FLAGS.model_dir)