def main(_):
  """Trains the Felix tagging model or the Felix insertion model.

  Selects the model directory and BERT config from flags, optionally sets up
  a TPU distribution strategy, and delegates to `run_train`.

  Raises:
    app.UsageError: If `use_open_vocab` is False (only open-vocab supported).
  """
  if not FLAGS.use_open_vocab:
    raise app.UsageError('Currently only use_open_vocab=True is supported')
  if FLAGS.train_insertion:
    model_dir = FLAGS.model_dir_insertion
    bert_config = configs.BertConfig.from_json_file(
        FLAGS.bert_config_insertion)
  else:
    model_dir = FLAGS.model_dir_tagging
    bert_config = configs.BertConfig.from_json_file(
        FLAGS.bert_config_tagging)

  def _train():
    # Single call site so the TPU and non-TPU paths cannot drift apart.
    # BUG FIX: the original TPU branch omitted FLAGS.mini_epochs_per_epoch,
    # which the non-TPU branch passed (run_train demonstrably accepts it).
    return run_train(bert_config, FLAGS.max_seq_length,
                     FLAGS.max_predictions_per_seq, model_dir,
                     FLAGS.num_train_epochs, FLAGS.learning_rate,
                     FLAGS.warmup_steps, 1.0, FLAGS.train_file,
                     FLAGS.eval_file, FLAGS.train_batch_size,
                     FLAGS.eval_batch_size, FLAGS.train_insertion,
                     FLAGS.use_pointing, FLAGS.pointing_weight,
                     FLAGS.mini_epochs_per_epoch)

  if FLAGS.tpu is not None:
    cluster_resolver = distribute_utils.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.TPUStrategy(cluster_resolver)
    with strategy.scope():
      return _train()
  return _train()
def _init_strategy(self):
  """Lazily initializes the tf.distribute strategy (TPU/GPU/CPU/Mirrored).

  No-op if `self._strategy` is already set. Otherwise picks, in order:
  a TPU strategy when `self._tpu` is set; the default strategy when
  `self._distribution_strategy` is None or 'default'; a single-device CPU
  strategy for 'cpu'; a mirrored strategy for 'mirrored'.

  Raises:
    ValueError: If `self._distribution_strategy` is an unrecognized value.
  """
  if self._strategy is not None:
    return  # Already initialized; guard clause replaces the outer `if`.
  if self._tpu is not None:
    resolver = distribute_utils.tpu_initialize(self._tpu)
    # CONSISTENCY FIX: use the stable tf.distribute.TPUStrategy, as the
    # other entry points in this file do, instead of the deprecated
    # tf.distribute.experimental.TPUStrategy alias.
    self._strategy = tf.distribute.TPUStrategy(resolver)
  elif (self._distribution_strategy is None or
        self._distribution_strategy == 'default'):
    self._strategy = tf.distribute.get_strategy()
  elif self._distribution_strategy == 'cpu':
    self._strategy = tf.distribute.OneDeviceStrategy('/device:cpu:0')
  elif self._distribution_strategy == 'mirrored':
    self._strategy = tf.distribute.MirroredStrategy()
  else:
    raise ValueError(
        f'Invalid distribution strategy="{self._distribution_strategy}"')
def main(argv):
  """Runs end-to-end Felix prediction and writes TSV results.

  Builds a FelixPredictor (under a TPU strategy scope when `FLAGS.tpu` is
  set), then streams batches from `batch_generator()` and writes one
  `source<TAB>tag<TAB>insertion<TAB>target` line per example to
  `FLAGS.predict_output_file`.

  Raises:
    app.UsageError: If extra command-line arguments are supplied.
    ValueError: If `use_open_vocab` is False (only open-vocab supported).
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  if not FLAGS.use_open_vocab:
    raise ValueError('Currently only use_open_vocab=True is supported')
  label_map = utils.read_label_map(FLAGS.label_map_file)
  bert_config_tagging = configs.BertConfig.from_json_file(
      FLAGS.bert_config_tagging)
  bert_config_insertion = configs.BertConfig.from_json_file(
      FLAGS.bert_config_insertion)

  def _build_predictor():
    # Single construction path. CONSISTENCY FIX: the original duplicated this
    # call in both branches and the copies had drifted — the non-TPU branch
    # passed `label_map_file=FLAGS.label_map_file`, leaving the `label_map`
    # computed above unused on that path. Both now pass the parsed map, as
    # the TPU branch demonstrably did.
    return predict.FelixPredictor(
        bert_config_tagging=bert_config_tagging,
        bert_config_insertion=bert_config_insertion,
        model_tagging_filepath=FLAGS.model_tagging_filepath,
        model_insertion_filepath=FLAGS.model_insertion_filepath,
        vocab_file=FLAGS.vocab_file,
        label_map=label_map,
        sequence_length=FLAGS.max_seq_length,
        max_predictions=FLAGS.max_predictions_per_seq,
        do_lowercase=FLAGS.do_lower_case,
        use_open_vocab=FLAGS.use_open_vocab,
        is_pointing=FLAGS.use_pointing,
        insert_after_token=FLAGS.insert_after_token,
        special_glue_string_for_joining_sources=FLAGS
        .special_glue_string_for_joining_sources)

  if FLAGS.tpu is not None:
    cluster_resolver = distribute_utils.tpu_initialize(FLAGS.tpu)
    strategy = tf.distribute.TPUStrategy(cluster_resolver)
    with strategy.scope():
      predictor = _build_predictor()
  else:
    predictor = _build_predictor()

  num_predicted = 0
  # Removed dead pre-initialization of source_batch/target_batch; the loop
  # variables are always bound by batch_generator() before first use.
  with tf.io.gfile.GFile(FLAGS.predict_output_file, 'w') as writer:
    for source_batch, target_batch in batch_generator():
      predicted_tags, predicted_inserts = predictor.predict_end_to_end_batch(
          source_batch)
      num_predicted += len(source_batch)
      logging.log_every_n(logging.INFO, f'{num_predicted} predicted.', 200)
      for source_input, target_output, predicted_tag, predicted_insert in zip(
          source_batch, target_batch, predicted_tags, predicted_inserts):
        writer.write(f'{source_input}\t{predicted_tag}\t{predicted_insert}\t'
                     f'{target_output}\n')