def run_customized_training(custom_callbacks=None):
    """Train the classifier with the shared custom training loop.

    Reads dataset sizes from ``config.meta_data``, builds the train/eval
    input pipelines, creates the model plus an Adam optimizer, and
    delegates to ``model_training_utils.run_customized_training_loop``.

    Args:
        custom_callbacks: Optional list of Keras callbacks forwarded to
            the training loop.

    Returns:
        The result of ``model_training_utils.run_customized_training_loop``
        (the trained model).
    """
    train_data_size = config.meta_data['train_data_size']
    eval_data_size = config.meta_data['eval_data_size']
    strategy = tf.distribute.MirroredStrategy()
    # FLAGS.batch_size is per replica; both datasets below are batched with
    # the global (all-replica) batch size.
    global_batch_size = FLAGS.batch_size * strategy.num_replicas_in_sync
    steps_per_epoch = int(train_data_size / global_batch_size)
    # BUG FIX: eval batches contain global_batch_size examples each, so the
    # eval step count must divide by the global batch size. The previous
    # ceil(eval_data_size / FLAGS.batch_size) overshot the dataset by a
    # factor of num_replicas_in_sync on multi-replica runs.
    eval_steps = int(math.ceil(eval_data_size / global_batch_size))
    train_input_fn = functools.partial(
        input_pipeline.create_classifier_dataset,
        file_path=FLAGS.train_output_path,
        seq_length=config.max_seq_length,
        batch_size=global_batch_size,
        buffer_size=train_data_size,
    )
    eval_input_fn = functools.partial(
        input_pipeline.create_classifier_dataset,
        file_path=FLAGS.eval_output_path,
        seq_length=config.max_seq_length,
        batch_size=global_batch_size,
        buffer_size=eval_data_size,
        is_training=False,
        drop_remainder=False)

    def _get_classifier_model():
        # Invoked by the training loop (under strategy.scope()) to build the
        # model with its optimizer attached.
        core_model = get_model(config=config, float_type=tf.float32)
        core_model.optimizer = tf.keras.optimizers.Adam(
            learning_rate=FLAGS.learning_rate,
            beta_1=FLAGS.adam_beta1,
            beta_2=FLAGS.adam_beta2,
            epsilon=FLAGS.optimizer_adam_epsilon,
        )
        return core_model

    loss_fn = get_loss_fn()
    return model_training_utils.run_customized_training_loop(
        strategy=strategy,
        model_fn=_get_classifier_model,
        loss_fn=loss_fn,
        model_dir=FLAGS.model_dir,
        steps_per_epoch=steps_per_epoch,
        steps_per_loop=FLAGS.steps_per_loop,
        epochs=FLAGS.epoch,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        eval_steps=eval_steps,
        init_checkpoint=FLAGS.init_checkpoint,
        custom_callbacks=custom_callbacks,
    )
def main(_):
    """Train, evaluate, and/or predict with an ALBERT classifier.

    Behaviour is driven entirely by FLAGS: do_train / do_eval / do_predict
    select the phases, strategy_type selects the tf.distribute strategy,
    and input_meta_data_path supplies dataset metadata produced by the
    preprocessing step.
    """
    logging.set_verbosity(logging.INFO)
    if FLAGS.enable_xla:
        set_config_v2(FLAGS.enable_xla)
    strategy = None
    if FLAGS.strategy_type == "one":
        strategy = tf.distribute.OneDeviceStrategy("GPU:0")
    elif FLAGS.strategy_type == "mirror":
        strategy = tf.distribute.MirroredStrategy()
    else:
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy_type)
    # Dataset metadata (sizes, label count, sequence length) written by the
    # data-preparation script.
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
        input_meta_data = json.loads(reader.read().decode('utf-8'))
    num_labels = input_meta_data["num_labels"]
    FLAGS.max_seq_length = input_meta_data["max_seq_length"]
    processor_type = input_meta_data['processor_type']
    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )
    albert_config = AlbertConfig.from_json_file(FLAGS.albert_config_file)
    if FLAGS.max_seq_length > albert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the ALBERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, albert_config.max_position_embeddings))
    tf.io.gfile.makedirs(FLAGS.output_dir)
    num_train_steps = None
    num_warmup_steps = None
    steps_per_epoch = None
    if FLAGS.do_train:
        len_train_examples = input_meta_data['train_data_size']
        steps_per_epoch = int(len_train_examples / FLAGS.train_batch_size)
        num_train_steps = int(len_train_examples / FLAGS.train_batch_size *
                              FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
    # Scale each replica's loss so the summed global loss matches
    # single-replica training.
    loss_multiplier = 1.0 / strategy.num_replicas_in_sync
    with strategy.scope():
        model = get_model(albert_config=albert_config,
                          max_seq_length=FLAGS.max_seq_length,
                          num_labels=num_labels,
                          init_checkpoint=FLAGS.init_checkpoint,
                          learning_rate=FLAGS.learning_rate,
                          num_train_steps=num_train_steps,
                          num_warmup_steps=num_warmup_steps,
                          loss_multiplier=loss_multiplier)
    model.summary()
    if FLAGS.do_train:
        logging.info("***** Running training *****")
        logging.info(" Num examples = %d", len_train_examples)
        logging.info(" Batch size = %d", FLAGS.train_batch_size)
        logging.info(" Num steps = %d", num_train_steps)
        train_input_fn = functools.partial(create_classifier_dataset,
                                           FLAGS.train_data_path,
                                           seq_length=FLAGS.max_seq_length,
                                           batch_size=FLAGS.train_batch_size,
                                           drop_remainder=False)
        # NOTE(review): eval_input_fn is only defined when do_train is set;
        # running with --do_eval but without --do_train hits a NameError in
        # the do_eval branch below — confirm, and hoist the definition if
        # eval-only runs are needed.
        eval_input_fn = functools.partial(create_classifier_dataset,
                                          FLAGS.eval_data_path,
                                          seq_length=FLAGS.max_seq_length,
                                          batch_size=FLAGS.eval_batch_size,
                                          is_training=False,
                                          drop_remainder=False)
        with strategy.scope():
            summary_dir = os.path.join(FLAGS.output_dir, 'summaries')
            summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
            checkpoint_path = os.path.join(FLAGS.output_dir, 'checkpoint')
            checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
                checkpoint_path, save_weights_only=True)
            custom_callbacks = [summary_callback, checkpoint_callback]

            def metric_fn():
                # STS-B is a regression task (mean squared error); the other
                # tasks handled here are classification.
                if FLAGS.task_name.lower() == "sts":
                    return tf.keras.metrics.MeanSquaredError(dtype=tf.float32)
                else:
                    return tf.keras.metrics.SparseCategoricalAccuracy(
                        dtype=tf.float32)

            if FLAGS.custom_training_loop:
                if FLAGS.task_name.lower() == "sts":
                    loss_fn = get_loss_fn_v2(loss_factor=loss_multiplier)
                else:
                    loss_fn = get_loss_fn(num_labels,
                                          loss_factor=loss_multiplier)
                model = run_customized_training_loop(
                    strategy=strategy,
                    model=model,
                    loss_fn=loss_fn,
                    # NOTE(review): checkpoint_path is a checkpoint file
                    # prefix, not a directory, yet it is passed as model_dir
                    # — confirm this is what the loop expects.
                    model_dir=checkpoint_path,
                    train_input_fn=train_input_fn,
                    steps_per_epoch=steps_per_epoch,
                    epochs=FLAGS.num_train_epochs,
                    eval_input_fn=eval_input_fn,
                    eval_steps=int(input_meta_data['eval_data_size'] /
                                   FLAGS.eval_batch_size),
                    metric_fn=metric_fn,
                    custom_callbacks=custom_callbacks)
            else:
                # Plain Keras fit() path with eval data used for validation.
                training_dataset = train_input_fn()
                evaluation_dataset = eval_input_fn()
                model.fit(x=training_dataset,
                          validation_data=evaluation_dataset,
                          epochs=FLAGS.num_train_epochs,
                          callbacks=custom_callbacks)
    if FLAGS.do_eval:
        len_eval_examples = input_meta_data['eval_data_size']
        logging.info("***** Running evaluation *****")
        logging.info(" Num examples = %d", len_eval_examples)
        logging.info(" Batch size = %d", FLAGS.eval_batch_size)
        evaluation_dataset = eval_input_fn()
        with strategy.scope():
            loss, accuracy = model.evaluate(evaluation_dataset)
        print(f"loss : {loss} , Accuracy : {accuracy}")
    if FLAGS.do_predict:
        logging.info("***** Running prediction*****")
        flags.mark_flag_as_required("input_data_dir")
        flags.mark_flag_as_required("predict_data_path")
        tokenizer = tokenization.FullTokenizer(
            vocab_file=None,
            spm_model_file=FLAGS.spm_model_file,
            do_lower_case=FLAGS.do_lower_case)
        processors = {
            "cola": classifier_data_lib.ColaProcessor,
            "sts": classifier_data_lib.StsbProcessor,
            "sst": classifier_data_lib.Sst2Processor,
            "mnli": classifier_data_lib.MnliProcessor,
            "qnli": classifier_data_lib.QnliProcessor,
            "qqp": classifier_data_lib.QqpProcessor,
            "rte": classifier_data_lib.RteProcessor,
            "mrpc": classifier_data_lib.MrpcProcessor,
            "wnli": classifier_data_lib.WnliProcessor,
            "xnli": classifier_data_lib.XnliProcessor,
        }
        task_name = FLAGS.task_name.lower()
        if task_name not in processors:
            raise ValueError("Task not found: %s" % (task_name))
        processor = processors[task_name]()
        predict_examples = processor.get_test_examples(FLAGS.input_data_dir)
        label_list = processor.get_labels()
        # index -> label string, used when writing the submission file.
        label_map = {i: label for i, label in enumerate(label_list)}
        classifier_data_lib.file_based_convert_examples_to_features(
            predict_examples, label_list, input_meta_data['max_seq_length'],
            tokenizer, FLAGS.predict_data_path)
        predict_input_fn = functools.partial(
            create_classifier_dataset,
            FLAGS.predict_data_path,
            seq_length=input_meta_data['max_seq_length'],
            batch_size=FLAGS.eval_batch_size,
            is_training=False,
            drop_remainder=False)
        prediction_dataset = predict_input_fn()
        with strategy.scope():
            logits = model.predict(prediction_dataset)
        if FLAGS.task_name.lower() == "sts":
            # Regression: the raw model outputs are both the predictions and
            # the values written to the results file.
            predictions = logits
            probabilities = logits
        else:
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
            probabilities = tf.nn.softmax(logits, axis=-1)
        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        output_submit_file = os.path.join(FLAGS.output_dir,
                                          "submit_results.tsv")
        with tf.io.gfile.GFile(output_predict_file, "w") as pred_writer,\
                tf.io.gfile.GFile(output_submit_file, "w") as sub_writer:
            logging.info("***** Predict results *****")
            for (example, probability, prediction) in zip(
                    predict_examples, probabilities, predictions):
                # test_results.tsv: one tab-separated probability row per
                # example; submit_results.tsv: "<guid>\t<label>" rows.
                output_line = "\t".join(
                    str(class_probability.numpy())
                    for class_probability in probability) + "\n"
                pred_writer.write(output_line)
                actual_label = label_map[int(prediction)]
                sub_writer.write(
                    six.ensure_str(example.guid) + "\t" + actual_label + "\n")
def run_customized_training(strategy, albert_config, max_seq_length,
                            max_predictions_per_seq, model_dir,
                            steps_per_epoch, steps_per_loop, epochs,
                            initial_lr, warmup_steps, input_files,
                            train_batch_size):
    """Run BERT pretrain model training using low-level API.

    Builds the ALBERT pretraining model under the distribution strategy,
    trains it with the custom training loop, then rebuilds the core
    (encoder) model outside the strategy scope, restores it from the latest
    checkpoint in model_dir, and exports it as an H5 weights file.

    Args:
        strategy: tf.distribute strategy under whose scope the model and
            optimizer are created.
        albert_config: AlbertConfig used to build the pretraining model.
        max_seq_length: maximum sequence length of the pretraining inputs.
        max_predictions_per_seq: maximum number of masked-LM predictions.
        model_dir: directory the training loop writes checkpoints to.
        steps_per_epoch: training steps per epoch.
        steps_per_loop: steps per inner (host) training loop iteration.
        epochs: number of training epochs.
        initial_lr: peak learning rate of the polynomial decay schedule.
        warmup_steps: linear warmup steps; falsy disables warmup.
        input_files: pretraining input files passed to the input pipeline.
        train_batch_size: training batch size for the input pipeline.

    Returns:
        The trained pretraining model returned by
        run_customized_training_loop.
    """
    train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                       max_seq_length,
                                       max_predictions_per_seq,
                                       train_batch_size, strategy)
    with strategy.scope():
        pretrain_model, core_model = albert_model.pretrain_model(
            albert_config, max_seq_length, max_predictions_per_seq)
        if FLAGS.init_checkpoint:
            logging.info(
                f"pre-trained weights loaded from {FLAGS.init_checkpoint}")
            pretrain_model.load_weights(FLAGS.init_checkpoint)
        # Learning rate decays polynomially to 0 over the full run length.
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=initial_lr,
            decay_steps=int(steps_per_epoch * epochs),
            end_learning_rate=0.0)
        if warmup_steps:
            learning_rate_fn = WarmUp(initial_learning_rate=initial_lr,
                                      decay_schedule_fn=learning_rate_fn,
                                      warmup_steps=warmup_steps)
        # NOTE(review): "lamp" looks like a typo for "lamb" (LAMB optimizer);
        # confirm against the flag definition before changing, since a
        # mismatch would silently fall through to AdamWeightDecay.
        if FLAGS.optimizer == "lamp":
            optimizer_fn = LAMB
        else:
            optimizer_fn = AdamWeightDecay
        optimizer = optimizer_fn(
            learning_rate=learning_rate_fn,
            weight_decay_rate=FLAGS.weight_decay,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=FLAGS.adam_epsilon,
            exclude_from_weight_decay=['layer_norm', 'bias'])
        pretrain_model.optimizer = optimizer
        trained_model = run_customized_training_loop(
            strategy=strategy,
            model=pretrain_model,
            loss_fn=get_loss_fn(
                loss_factor=1.0 / strategy.num_replicas_in_sync),
            model_dir=model_dir,
            train_input_fn=train_input_fn,
            steps_per_epoch=steps_per_epoch,
            steps_per_loop=steps_per_loop,
            epochs=epochs)
    # Creates the BERT core model outside distribution strategy scope.
    _, core_model = albert_model.pretrain_model(albert_config, max_seq_length,
                                                max_predictions_per_seq)
    # Restores the core model from model checkpoints and save weights only
    # contains the core model.
    checkpoint = tf.train.Checkpoint(model=core_model)
    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    assert latest_checkpoint_file
    logging.info('Checkpoint file %s found and restoring from '
                 'checkpoint', latest_checkpoint_file)
    status = checkpoint.restore(latest_checkpoint_file)
    status.assert_existing_objects_matched().expect_partial()
    # Export only the encoder weights in Keras H5 format.
    core_model.save_weights(f"{model_dir}/tf2_model.h5")
    return trained_model
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
    """Run bert squad training.

    Builds the SQuAD model (v1 or v2 depending on
    FLAGS.version_2_with_negative), wires up the matching input pipeline
    and loss, and trains with the custom training loop.

    Args:
        strategy: tf.distribute strategy the model is built and trained
            under.
        input_meta_data: dict with at least 'train_data_size' and
            'max_seq_length'.
        custom_callbacks: optional list of Keras callbacks forwarded to the
            training loop.
        run_eagerly: whether the custom training loop runs eagerly.

    Returns:
        The trained model returned by run_customized_training_loop.
    """
    if strategy:
        logging.info('Training using customized training loop with distribution'
                     ' strategy.')
    # Enables XLA in Session Config. Should not be set for TPU.
    if FLAGS.enable_xla:
        set_config_v2(FLAGS.enable_xla)

    num_train_examples = input_meta_data['train_data_size']
    max_seq_length = input_meta_data['max_seq_length']
    steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
    num_train_steps = int(num_train_examples / FLAGS.train_batch_size *
                          FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    with strategy.scope():
        albert_config = AlbertConfig.from_json_file(FLAGS.albert_config_file)
        if FLAGS.version_2_with_negative:
            # SQuAD v2 model with answerability (start/end top-k) heads.
            model = get_model_v2(albert_config,
                                 input_meta_data['max_seq_length'],
                                 FLAGS.init_checkpoint, FLAGS.learning_rate,
                                 FLAGS.start_n_top, FLAGS.end_n_top,
                                 FLAGS.squad_dropout, num_train_steps,
                                 num_warmup_steps)
        else:
            model = get_model_v1(albert_config,
                                 input_meta_data['max_seq_length'],
                                 FLAGS.init_checkpoint, FLAGS.learning_rate,
                                 num_train_steps, num_warmup_steps)

    if FLAGS.version_2_with_negative:
        train_input_fn = functools.partial(
            input_pipeline.create_squad_dataset_v2,
            FLAGS.train_data_path,
            max_seq_length,
            FLAGS.train_batch_size,
            is_training=True)
    else:
        train_input_fn = functools.partial(
            input_pipeline.create_squad_dataset,
            FLAGS.train_data_path,
            max_seq_length,
            FLAGS.train_batch_size,
            is_training=True)

    # The original BERT model does not scale the loss by
    # 1/num_replicas_in_sync. It could be an accident. So, in order to use
    # the same hyper parameter, we do the same thing here by keeping each
    # replica loss as it is.
    if FLAGS.version_2_with_negative:
        loss_fn = get_loss_fn_v2(
            loss_factor=1.0 / strategy.num_replicas_in_sync)
    else:
        loss_fn = get_loss_fn(loss_factor=1.0 / strategy.num_replicas_in_sync)

    trained_model = run_customized_training_loop(
        strategy=strategy,
        model=model,
        loss_fn=loss_fn,
        model_dir=FLAGS.model_dir,
        train_input_fn=train_input_fn,
        steps_per_epoch=steps_per_epoch,
        epochs=FLAGS.num_train_epochs,
        run_eagerly=run_eagerly,
        custom_callbacks=custom_callbacks)
    # BUG FIX: the trained model was previously computed but never returned,
    # so callers always received None.
    return trained_model
def run_customized_training(strategy, albert_config, tinybert_config,
                            max_seq_length, max_predictions_per_seq,
                            model_dir, steps_per_epoch, steps_per_loop,
                            epochs, initial_lr, warmup_steps, input_files,
                            train_batch_size, use_mlm_loss):
    """Run BERT pretrain model training using low-level API.

    Trains the TinyBERT distillation model (teacher ALBERT + student
    TinyBERT wrapped in a single training model) under the distribution
    strategy, then rebuilds the models outside the scope, restores from the
    latest checkpoint in model_dir, and exports teacher and student weights
    as H5 files.

    Args:
        strategy: tf.distribute strategy for model/optimizer creation.
        albert_config: config for the teacher ALBERT model.
        tinybert_config: config for the student TinyBERT model.
        max_seq_length: maximum input sequence length.
        max_predictions_per_seq: maximum masked-LM predictions per sequence.
        model_dir: checkpoint/output directory (see NOTE below).
        steps_per_epoch: training steps per epoch.
        steps_per_loop: steps per inner training loop iteration.
        epochs: number of training epochs.
        initial_lr: peak learning rate of the polynomial decay schedule.
        warmup_steps: linear warmup steps; falsy disables warmup.
        input_files: pretraining input files for the input pipeline.
        train_batch_size: training batch size.
        use_mlm_loss: NOTE(review): accepted but never read in this body —
            confirm whether it should be forwarded to the training loop.
    """
    train_input_fn = functools.partial(get_pretrain_input_data, input_files,
                                       max_seq_length,
                                       max_predictions_per_seq,
                                       train_batch_size, strategy)
    with strategy.scope():
        # albert, albert_encoder = albert_model.pretrain_model(
        #     albert_config, max_seq_length, max_predictions_per_seq)
        train_model, albert, tinybert = tinybert_model.train_tinybert_model(
            tinybert_config, albert_config, max_seq_length,
            max_predictions_per_seq)
        albert.summary()
        tinybert.summary()
        train_model.summary()
        if FLAGS.init_checkpoint:
            logging.info(
                f"model pre-trained weights loaded from {FLAGS.init_checkpoint}"
            )
            train_model.load_weights(FLAGS.init_checkpoint)
        # Learning rate decays polynomially to 0 over the full run length.
        learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=initial_lr,
            decay_steps=int(steps_per_epoch * epochs),
            end_learning_rate=0.0)
        if warmup_steps:
            learning_rate_fn = WarmUp(initial_learning_rate=initial_lr,
                                      decay_schedule_fn=learning_rate_fn,
                                      warmup_steps=warmup_steps)
        # NOTE(review): "lamp" looks like a typo for "lamb" (LAMB optimizer);
        # confirm against the flag definition — a mismatch silently falls
        # through to AdamWeightDecay.
        if FLAGS.optimizer == "lamp":
            optimizer_fn = LAMB
        else:
            optimizer_fn = AdamWeightDecay
        optimizer = optimizer_fn(
            learning_rate=learning_rate_fn,
            weight_decay_rate=FLAGS.weight_decay,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=FLAGS.adam_epsilon,
            exclude_from_weight_decay=['layer_norm', 'bias'])
        train_model.optimizer = optimizer
        # NOTE: model_dir here is shared between albert and tinybert; this
        # needs to be fixed (separate directories per model).
        if FLAGS.do_train:
            trained_model = run_customized_training_loop(
                strategy=strategy,
                models=[albert, tinybert, train_model],
                model=train_model,
                albert=albert,
                tinybert=tinybert,
                start_wtih_trained_model=FLAGS.start_with_train_model,
                loss_fn=get_loss_fn(
                    loss_factor=1.0 / strategy.num_replicas_in_sync),
                model_dir=model_dir,
                train_input_fn=train_input_fn,
                steps_per_epoch=steps_per_epoch,
                steps_per_loop=steps_per_loop,
                epochs=epochs,
            )
    # Creates the BERT core model outside distribution strategy scope.
    training, albert, tinybert = tinybert_model.train_tinybert_model(
        tinybert_config, albert_config, max_seq_length,
        max_predictions_per_seq)
    # Restores the core model from model checkpoints and save weights only
    # contains the core model.
    # During training the loop saves ckpt files; after training finishes the
    # model is read back from the latest checkpoint and re-saved as H5 files.
    # Locate the (combined training) model checkpoint.
    checkpoint_model = tf.train.Checkpoint(model=training)
    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    assert latest_checkpoint_file
    logging.info('Checkpoint file %s found and restoring from '
                 'checkpoint', latest_checkpoint_file)
    status = checkpoint_model.restore(latest_checkpoint_file)
    status.assert_existing_objects_matched().expect_partial()
    # Locate the tinybert model checkpoint (currently disabled).
    # checkpoint_tinybert = tf.train.Checkpoint(model=tinybert)
    # latest_tinybert_checkpoint_file = tf.train.latest_checkpoint(tinybert_model_dir)
    # assert latest_tinybert_checkpoint_file
    # logging.info('Checkpoint_Tinybert file %s found and restoring from '
    #              'checkpoint', latest_tinybert_checkpoint_file)
    # status_tinybert = checkpoint_albert.restore(latest_tinybert_checkpoint_file)
    # status_tinybert.assert_existing_objects_matched().expect_partial()
    # Create the export directory for the H5 weight files.
    if not os.path.exists(model_dir + '/models/'):
        os.makedirs(model_dir + '/models/')
    albert.save_weights(f"{model_dir}/models/albert_model.h5")
    tinybert.save_weights(f"{model_dir}/models/tinybert_model.h5")