def main(_):
    """Runs ALBERT classifier fine-tuning or prediction, configured via FLAGS.

    Loads the preprocessing metadata, builds a distribution strategy and the
    train/eval dataset functions, then dispatches on FLAGS.mode.

    Raises:
        ValueError: if FLAGS.mode is neither 'train_and_eval' nor 'predict'.
    """
    # Metadata produced by the data-preprocessing step (JSON on GFile).
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    # Fall back to a scratch directory when no model dir was given.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)

    seq_len = input_meta_data['max_seq_length']
    train_input_fn = run_classifier_bert.get_dataset_fn(
        FLAGS.train_data_path, seq_len, FLAGS.train_batch_size,
        is_training=True)
    eval_input_fn = run_classifier_bert.get_dataset_fn(
        FLAGS.eval_data_path, seq_len, FLAGS.eval_batch_size,
        is_training=False)

    albert_config = albert_configs.AlbertConfig.from_json_file(
        FLAGS.bert_config_file)

    mode = FLAGS.mode
    if mode == 'train_and_eval':
        run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
                                     train_input_fn, eval_input_fn)
    elif mode == 'predict':
        predict(strategy, albert_config, input_meta_data, eval_input_fn)
    else:
        raise ValueError('Unsupported mode is specified: %s' % mode)
# ---- Example 2 ----
def main(_):
    """Runs ALBERT classifier training-and-evaluation, configured via FLAGS.

    Loads the preprocessing metadata, builds a distribution strategy and the
    train/eval dataset functions, then launches run_bert.

    Raises:
        RuntimeError: if the installed TensorFlow is not a 2.x release.
    """
    # Users should always run this script under TF 2.x.
    # NOTE: was a bare `assert`, which is silently stripped under `python -O`;
    # an explicit raise keeps the guard active in optimized runs.
    if not tf.version.VERSION.startswith('2.'):
        raise RuntimeError(
            'This script requires TensorFlow 2.x, found %s' %
            tf.version.VERSION)

    # Metadata produced by the data-preprocessing step (JSON on GFile).
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))

    # Fall back to a scratch directory when no model dir was given.
    if not FLAGS.model_dir:
        FLAGS.model_dir = '/tmp/bert20/'

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        num_gpus=FLAGS.num_gpus,
        tpu_address=FLAGS.tpu)
    max_seq_length = input_meta_data['max_seq_length']
    train_input_fn = run_classifier_bert.get_dataset_fn(FLAGS.train_data_path,
                                                        max_seq_length,
                                                        FLAGS.train_batch_size,
                                                        is_training=True)
    eval_input_fn = run_classifier_bert.get_dataset_fn(FLAGS.eval_data_path,
                                                       max_seq_length,
                                                       FLAGS.eval_batch_size,
                                                       is_training=False)

    albert_config = albert_configs.AlbertConfig.from_json_file(
        FLAGS.bert_config_file)
    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
                                 train_input_fn, eval_input_fn)
# ---- Example 3 ----
  def _run_bert_classifier(self, callbacks=None, use_ds=True):
    """Starts BERT classification task.

    Args:
      callbacks: optional Keras callbacks forwarded to the training loop.
      use_ds: when no TPU is configured, selects 'mirrored' (True) or 'off'
        (False) as the GPU distribution strategy.
    """
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
      input_meta_data = json.loads(reader.read().decode('utf-8'))

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    # Overrides on `self` take precedence over FLAGS / dataset-derived values.
    epochs = self.num_epochs or FLAGS.num_train_epochs
    if self.num_steps_per_epoch:
      steps_per_epoch = self.num_steps_per_epoch
    else:
      steps_per_epoch = int(
          input_meta_data['train_data_size'] / FLAGS.train_batch_size)

    # Warm up over the first 10% of total training steps.
    warmup_steps = int(epochs * steps_per_epoch * 0.1)
    eval_steps = int(
        math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

    if self.tpu:
      strategy = distribution_utils.get_distribution_strategy(
          distribution_strategy='tpu', tpu_address=self.tpu)
    else:
      gpu_strategy = 'mirrored' if use_ds else 'off'
      strategy = distribution_utils.get_distribution_strategy(
          distribution_strategy=gpu_strategy, num_gpus=self.num_gpus)

    steps_per_loop = 1

    seq_len = input_meta_data['max_seq_length']
    train_input_fn = run_classifier.get_dataset_fn(
        FLAGS.train_data_path, seq_len, FLAGS.train_batch_size,
        is_training=True)
    eval_input_fn = run_classifier.get_dataset_fn(
        FLAGS.eval_data_path, seq_len, FLAGS.eval_batch_size,
        is_training=False)

    run_classifier.run_bert_classifier(
        strategy,
        bert_config,
        input_meta_data,
        FLAGS.model_dir,
        epochs,
        steps_per_epoch,
        steps_per_loop,
        eval_steps,
        warmup_steps,
        FLAGS.learning_rate,
        FLAGS.init_checkpoint,
        train_input_fn,
        eval_input_fn,
        custom_callbacks=callbacks)
# ---- Example 4 ----
# Fine-tune a cased BERT-Base classifier on the MRPC TFRecord data.
train_data_path = "fine_data/MRPC_train.tf_record"
eval_data_path = "fine_data/MRPC_eval.tf_record"
input_meta_path = "fine_data/MRPC_meta_data"

bert_config_file = "cased_L-12_H-768_A-12/bert_config.json"
# NOTE(review): tf.train.Checkpoint.restore usually takes the checkpoint
# prefix ('bert_model.ckpt'), not the '.index' file — confirm before enabling
# the commented restore below.
ckpt_path = 'cased_L-12_H-768_A-12/bert_model.ckpt.index'

# Metadata (JSON) produced by the data-preprocessing step.
with tf.io.gfile.GFile(input_meta_path, 'rb') as reader:
  input_meta_data = json.loads(reader.read().decode('utf-8'))

max_seq_length = input_meta_data['max_seq_length']
num_classes = input_meta_data['num_labels']
batch_size = 32
eval_batch_size = 32
train_input_fn = run_classifier.get_dataset_fn(train_data_path, max_seq_length, batch_size, is_training=True)
eval_input_fn = run_classifier.get_dataset_fn(eval_data_path, max_seq_length, eval_batch_size, is_training=False)

# NOTE(review): 'one_device' places everything on a single device; num_gpus=2
# looks inconsistent with that — 'mirrored' may have been intended. Verify.
strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy='one_device', num_gpus=2)

# Build datasets and model under the strategy scope so variables are placed
# according to the chosen distribution strategy.
with strategy.scope():
  training_dataset = train_input_fn()
  evaluation_dataset = eval_input_fn()
  bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)
  classifier_model, encoder = bert_models.classifier_model(
      bert_config, num_classes, max_seq_length)

  # Checkpoint tracks only the encoder (pretrained weights), not the head.
  checkpoint = tf.train.Checkpoint(model=encoder)
  #checkpoint.restore(ckpt_path).assert_consumed()