Code example #1 (score: 0)
def main(_):
    """Trains the Felix tagging or insertion model, optionally on TPU.

    Reads all configuration from absl FLAGS. Selects the insertion or
    tagging model directory / BERT config based on --train_insertion, then
    runs training either inside a TPUStrategy scope (when --tpu is set) or
    with the default strategy.

    Raises:
      app.UsageError: if --use_open_vocab is False (only open-vocab models
        are supported).
    """
    if not FLAGS.use_open_vocab:
        raise app.UsageError('Currently only use_open_vocab=True is supported')
    if FLAGS.train_insertion:
        model_dir = FLAGS.model_dir_insertion
        bert_config = configs.BertConfig.from_json_file(
            FLAGS.bert_config_insertion)
    else:
        model_dir = FLAGS.model_dir_tagging
        bert_config = configs.BertConfig.from_json_file(
            FLAGS.bert_config_tagging)

    def _train():
        # Single call site shared by both branches. NOTE(review): the
        # original TPU branch omitted FLAGS.mini_epochs_per_epoch while the
        # non-TPU branch passed it; that asymmetry looked accidental, so it
        # is now passed on both paths — confirm against run_train's default.
        return run_train(bert_config, FLAGS.max_seq_length,
                         FLAGS.max_predictions_per_seq, model_dir,
                         FLAGS.num_train_epochs, FLAGS.learning_rate,
                         FLAGS.warmup_steps, 1.0, FLAGS.train_file,
                         FLAGS.eval_file, FLAGS.train_batch_size,
                         FLAGS.eval_batch_size, FLAGS.train_insertion,
                         FLAGS.use_pointing, FLAGS.pointing_weight,
                         FLAGS.mini_epochs_per_epoch)

    if FLAGS.tpu is not None:
        # TPU training must build the model inside the strategy scope.
        cluster_resolver = distribute_utils.tpu_initialize(FLAGS.tpu)
        strategy = tf.distribute.TPUStrategy(cluster_resolver)
        with strategy.scope():
            return _train()
    return _train()
Code example #2 (score: 0)
File: training.py  Project: google-research/language
 def _init_strategy(self):
     """Initialize the distribution strategy (e.g. TPU/GPU/Mirrored).

     Lazily populates self._strategy (no-op if already set):
       * self._tpu set            -> TPUStrategy on that TPU cluster
       * strategy None/'default'  -> tf.distribute.get_strategy()
       * 'cpu'                    -> OneDeviceStrategy on /device:cpu:0
       * 'mirrored'               -> MirroredStrategy

     Raises:
       ValueError: for any other self._distribution_strategy value.
     """
     if self._strategy is None:
         if self._tpu is not None:
             resolver = distribute_utils.tpu_initialize(self._tpu)
             # Use the stable TPUStrategy API; the experimental alias is
             # deprecated and the rest of the codebase already uses it.
             self._strategy = tf.distribute.TPUStrategy(resolver)
         elif self._distribution_strategy is None or self._distribution_strategy == 'default':
             self._strategy = tf.distribute.get_strategy()
         elif self._distribution_strategy == 'cpu':
             self._strategy = tf.distribute.OneDeviceStrategy(
                 '/device:cpu:0')
         elif self._distribution_strategy == 'mirrored':
             self._strategy = tf.distribute.MirroredStrategy()
         else:
             raise ValueError(
                 f'Invalid distribution strategy="{self._distribution_strategy}"'
             )
Code example #3 (score: 0)
def main(argv):
  """Runs end-to-end Felix prediction and writes results to a TSV file.

  Builds a FelixPredictor (inside a TPUStrategy scope when --tpu is set),
  then predicts tags and insertions for every batch from batch_generator()
  and writes one `source<TAB>tag<TAB>insertion<TAB>target` line per input.

  Raises:
    app.UsageError: if extra command-line arguments are supplied.
    ValueError: if --use_open_vocab is False.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if not FLAGS.use_open_vocab:
    raise ValueError('Currently only use_open_vocab=True is supported')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  bert_config_tagging = configs.BertConfig.from_json_file(
      FLAGS.bert_config_tagging)
  bert_config_insertion = configs.BertConfig.from_json_file(
      FLAGS.bert_config_insertion)

  def _build_predictor():
    # Shared by the TPU and non-TPU paths. NOTE(review): the original
    # non-TPU branch passed label_map_file=FLAGS.label_map_file and left
    # the parsed `label_map` unused; unified on the already-parsed
    # label_map so both paths agree — confirm FelixPredictor's signature.
    return predict.FelixPredictor(
        bert_config_tagging=bert_config_tagging,
        bert_config_insertion=bert_config_insertion,
        model_tagging_filepath=FLAGS.model_tagging_filepath,
        model_insertion_filepath=FLAGS.model_insertion_filepath,
        vocab_file=FLAGS.vocab_file,
        label_map=label_map,
        sequence_length=FLAGS.max_seq_length,
        max_predictions=FLAGS.max_predictions_per_seq,
        do_lowercase=FLAGS.do_lower_case,
        use_open_vocab=FLAGS.use_open_vocab,
        is_pointing=FLAGS.use_pointing,
        insert_after_token=FLAGS.insert_after_token,
        special_glue_string_for_joining_sources=FLAGS
        .special_glue_string_for_joining_sources)

  if FLAGS.tpu is not None:
    # Model variables must be created inside the TPU strategy scope.
    cluster_resolver = distribute_utils.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.TPUStrategy(cluster_resolver)
    with strategy.scope():
      predictor = _build_predictor()
  else:
    predictor = _build_predictor()

  num_predicted = 0
  with tf.io.gfile.GFile(FLAGS.predict_output_file, 'w') as writer:
    for source_batch, target_batch in batch_generator():
      predicted_tags, predicted_inserts = predictor.predict_end_to_end_batch(
          source_batch)
      num_predicted += len(source_batch)
      logging.log_every_n(logging.INFO, f'{num_predicted} predicted.', 200)
      for source_input, target_output, predicted_tag, predicted_insert in zip(
          source_batch, target_batch, predicted_tags, predicted_inserts):
        writer.write(f'{source_input}\t{predicted_tag}\t{predicted_insert}\t'
                     f'{target_output}\n')