Exemplo n.º 1
0
def run_bert(strategy, input_meta_data):
    """Run BERT training and evaluation, or only export, depending on FLAGS.mode.

    In 'export_only' mode the classifier is exported via `export_classifier`
    and nothing is trained. In 'train_and_eval' mode training runs through
    `run_customized_training`. If FLAGS.model_export_path is set, the
    resulting model is then exported.

    Args:
      strategy: distribution strategy forwarded to `run_customized_training`.
        Must be truthy unless FLAGS.mode == 'export_only'.
      input_meta_data: dict of input metadata. 'train_data_size' and
        'eval_data_size' are read here. The dict is also forwarded to
        `run_customized_training` (and to `export_classifier` in export mode).

    Returns:
      The model returned by `run_customized_training`, or None when
      FLAGS.mode == 'export_only'.

    Raises:
      ValueError: if FLAGS.mode is neither 'export_only' nor 'train_and_eval',
        or if `strategy` is falsy in training mode.
    """
    if FLAGS.mode == 'export_only':
        export_classifier(FLAGS.model_export_path, input_meta_data)
        return None

    if FLAGS.mode != 'train_and_eval':
        raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
    # Validate all inputs before touching global config or reading files,
    # so a missing strategy fails fast without side effects.
    if not strategy:
        raise ValueError('Distribution strategy has not been specified.')

    # Enables XLA in Session Config. Should not be set for TPU.
    keras_utils.set_config_v2(FLAGS.enable_xla)

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    epochs = FLAGS.num_train_epochs
    train_data_size = input_meta_data['train_data_size']
    steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
    # Warm up over the first 10% of the total number of training steps.
    warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
    # Round up so the final partial eval batch is still evaluated.
    eval_steps = int(
        math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

    # Runs customized training loop.
    logging.info(
        'Training using customized training loop TF 2.0 with distributed '
        'strategy.')
    use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
    trained_model = run_customized_training(strategy,
                                            bert_config,
                                            input_meta_data,
                                            FLAGS.model_dir,
                                            epochs,
                                            steps_per_epoch,
                                            FLAGS.steps_per_loop,
                                            eval_steps,
                                            warmup_steps,
                                            FLAGS.learning_rate,
                                            FLAGS.init_checkpoint,
                                            use_remote_tpu=use_remote_tpu,
                                            run_eagerly=FLAGS.run_eagerly)

    if FLAGS.model_export_path:
        # Export from the device returned by get_primary_cpu_task.
        with tf.device(
                model_training_utils.get_primary_cpu_task(use_remote_tpu)):
            model_saving_utils.export_bert_model(FLAGS.model_export_path,
                                                 model=trained_model)
    return trained_model
Exemplo n.º 2
0
def export_squad(model_export_path, input_meta_data):
  """Save a trained SQuAD model as a `SavedModel` for inference.

  The model is rebuilt from FLAGS.bert_config_file with
  float_type=tf.float32. It is then handed to
  `model_saving_utils.export_bert_model` together with FLAGS.model_dir as
  the checkpoint directory.

  Args:
    model_export_path: path of the SavedModel directory to write.
    input_meta_data: dict of input metadata; its 'max_seq_length' entry
      sizes the model.

  Raises:
    ValueError: if `model_export_path` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)

  config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  seq_len = input_meta_data['max_seq_length']
  # squad_model yields a pair; only the first element is exported.
  model_to_export, _unused = bert_models.squad_model(
      config, seq_len, float_type=tf.float32)
  model_saving_utils.export_bert_model(
      model_export_path,
      model=model_to_export,
      checkpoint_dir=FLAGS.model_dir)
Exemplo n.º 3
0
def export_classifier(model_export_path, input_meta_data):
  """Save a trained BERT classifier as a `SavedModel` for inference.

  Rather than passing a built model, this passes a model-building callable
  to `model_saving_utils.export_bert_model`, along with FLAGS.model_dir as
  the checkpoint directory.

  Args:
    model_export_path: path of the SavedModel directory to write.
    input_meta_data: dict of input metadata. The 'num_labels' and
      'max_seq_length' entries are read when the builder is called.

  Raises:
    ValueError: if `model_export_path` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  def _build_classifier():
    # Metadata is looked up lazily, only when the exporter invokes this.
    outputs = bert_models.classifier_model(
        config,
        tf.float32,
        input_meta_data['num_labels'],
        input_meta_data['max_seq_length'])
    # Only the first returned element is the model to export.
    return outputs[0]

  model_saving_utils.export_bert_model(
      model_export_path,
      model_fn=_build_classifier,
      checkpoint_dir=FLAGS.model_dir)