def train(params, strategy, dataset=None):
  """Runs training.

  Builds the model, optimizer and (optionally) the input pipeline under
  `strategy`, resumes from the newest checkpoint in FLAGS.model_dir when one
  exists, and drives Keras `fit` with TensorBoard and checkpoint callbacks.

  Args:
    params: Configuration object forwarded to model, optimizer and input
      pipeline construction.
    strategy: `tf.distribute.Strategy` controlling variable placement and
      dataset distribution.
    dataset: Optional pre-built training dataset. Falsy values trigger
      construction from FLAGS.train_file_pattern.

  Returns:
    A dict holding the final recorded training loss under "training_loss".
  """
  if not dataset:
    dataset = input_pipeline.get_input_dataset(
        FLAGS.train_file_pattern,
        FLAGS.train_batch_size,
        params,
        is_training=True,
        strategy=strategy)

  # Everything that creates variables must live inside the strategy scope.
  with strategy.scope():
    model = models.create_model(
        FLAGS.model_type, params, init_checkpoint=FLAGS.init_checkpoint)
    train_opt = optimizer.create_optimizer(params)
    trainer = Trainer(model, params)
    trainer.compile(
        optimizer=train_opt,
        experimental_steps_per_execution=FLAGS.steps_per_loop)

    summary_callback = tf.keras.callbacks.TensorBoard(
        os.path.join(FLAGS.model_dir, "summaries"),
        update_freq=max(100, FLAGS.steps_per_loop))

    ckpt = tf.train.Checkpoint(
        model=model, optimizer=train_opt, global_step=train_opt.iterations)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt,
        directory=FLAGS.model_dir,
        max_to_keep=10,
        step_counter=train_opt.iterations,
        checkpoint_interval=FLAGS.checkpoint_interval)
    if ckpt_manager.restore_or_initialize():
      logging.info("Training restored from the checkpoints in: %s",
                   FLAGS.model_dir)
    checkpoint_callback = keras_utils.SimpleCheckpoint(ckpt_manager)

  # One Keras "epoch" per checkpoint interval so a checkpoint lands on every
  # epoch boundary; integer division drops any trailing partial epoch.
  steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
  epochs = FLAGS.train_steps // steps_per_epoch
  history = trainer.fit(
      x=dataset,
      steps_per_epoch=steps_per_epoch,
      epochs=epochs,
      callbacks=[summary_callback, checkpoint_callback],
      verbose=2)

  # Report only the last training loss.
  return dict(training_loss=float(history.history["training_loss"][-1]))
def run_keras_compile_fit(model_dir, strategy, model_fn, train_input_fn,
                          eval_input_fn, loss_fn, metric_fn, init_checkpoint,
                          epochs, steps_per_epoch, steps_per_loop, eval_steps,
                          training_callbacks=True, custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API."""
  with strategy.scope():
    train_ds = train_input_fn()
    eval_ds = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    # Warm-start the sub model from a pretrained checkpoint, if given.
    if init_checkpoint:
      tf.train.Checkpoint(model=sub_model).restore(
          init_checkpoint).assert_existing_objects_matched()

    # Normalize a single metric factory to a list before instantiating.
    if not isinstance(metric_fn, (list, tuple)):
      metric_fn = [metric_fn]
    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[fn() for fn in metric_fn],
        experimental_steps_per_execution=steps_per_loop)

    summary_callback = tf.keras.callbacks.TensorBoard(
        os.path.join(model_dir, 'summaries'))
    manager = tf.train.CheckpointManager(
        tf.train.Checkpoint(model=bert_model, optimizer=optimizer),
        directory=model_dir,
        max_to_keep=None,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(manager)

    if training_callbacks:
      default_callbacks = [summary_callback, checkpoint_callback]
      if custom_callbacks is None:
        custom_callbacks = default_callbacks
      else:
        # Extend the caller's list in place.
        custom_callbacks += default_callbacks

    history = bert_model.fit(
        x=train_ds,
        validation_data=eval_ds,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=custom_callbacks)

    hist = history.history
    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in hist:
      stats['train_loss'] = hist['loss'][-1]
    if 'val_accuracy' in hist:
      stats['eval_metrics'] = hist['val_accuracy'][-1]
  return bert_model, stats
def main(_) -> None:
  """Train and evaluate the Ranking model.

  Parses the experiment configuration from FLAGS, builds the ranking task and
  model under the selected distribution strategy, then dispatches to either
  the Orbit-based training loop (params.trainer.use_orbit) or the Keras
  compile/fit path.
  """
  params = train_utils.parse_configuration(FLAGS)
  mode = FLAGS.mode
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  if FLAGS.seed is not None:
    logging.info('Setting tf seed.')
    tf.random.set_seed(FLAGS.seed)

  task = RankingTask(
      params=params.task,
      optimizer_config=params.trainer.optimizer_config,
      logging_dir=model_dir,
      steps_per_execution=params.trainer.steps_per_loop,
      name='RankingTask')

  enable_tensorboard = params.trainer.callbacks.enable_tensorboard

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  # Model variables must be created inside the strategy scope.
  with strategy.scope():
    model = task.build_model()

  def get_dataset_fn(params):
    # Returns a per-replica dataset factory suitable for
    # strategy.distribute_datasets_from_function.
    return lambda input_context: task.build_inputs(params, input_context)

  # Datasets are built lazily per mode; fetch_to_device is disabled so the
  # input pipeline keeps data on the host.
  train_dataset = None
  if 'train' in mode:
    train_dataset = strategy.distribute_datasets_from_function(
        get_dataset_fn(params.task.train_data),
        options=tf.distribute.InputOptions(experimental_fetch_to_device=False))

  validation_dataset = None
  if 'eval' in mode:
    validation_dataset = strategy.distribute_datasets_from_function(
        get_dataset_fn(params.task.validation_data),
        options=tf.distribute.InputOptions(experimental_fetch_to_device=False))

  if params.trainer.use_orbit:
    # Orbit-based custom training loop.
    with strategy.scope():
      checkpoint_exporter = train_utils.maybe_create_best_ckpt_exporter(
          params, model_dir)
      trainer = RankingTrainer(
          config=params,
          task=task,
          model=model,
          optimizer=model.optimizer,
          train='train' in mode,
          evaluate='eval' in mode,
          train_dataset=train_dataset,
          validation_dataset=validation_dataset,
          checkpoint_exporter=checkpoint_exporter)

    train_lib.run_experiment(
        distribution_strategy=strategy,
        task=task,
        mode=mode,
        params=params,
        model_dir=model_dir,
        trainer=trainer)

  else:  # Compile/fit
    checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)

    # Resume from the newest checkpoint in model_dir, if any.
    latest_checkpoint = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint:
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=model.optimizer.iterations,
        checkpoint_interval=params.trainer.checkpoint_interval)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    time_callback = keras_utils.TimeHistory(
        params.task.train_data.global_batch_size,
        params.trainer.time_history.log_steps,
        logdir=model_dir if enable_tensorboard else None)
    callbacks = [checkpoint_callback, time_callback]

    if enable_tensorboard:
      tensorboard_callback = tf.keras.callbacks.TensorBoard(
          log_dir=model_dir,
          update_freq=min(1000, params.trainer.validation_interval),
          profile_batch=FLAGS.profile_steps)
      callbacks.append(tensorboard_callback)

    # One Keras "epoch" per validation interval; epoch counting resumes from
    # the optimizer's current step so restarts continue where they left off.
    num_epochs = (params.trainer.train_steps //
                  params.trainer.validation_interval)
    current_step = model.optimizer.iterations.numpy()
    initial_epoch = current_step // params.trainer.validation_interval

    eval_steps = params.trainer.validation_steps if 'eval' in mode else None

    if mode in ['train', 'train_and_eval']:
      logging.info('Training started')
      history = model.fit(
          train_dataset,
          initial_epoch=initial_epoch,
          epochs=num_epochs,
          steps_per_epoch=params.trainer.validation_interval,
          validation_data=validation_dataset,
          validation_steps=eval_steps,
          callbacks=callbacks,
      )
      model.summary()
      logging.info('Train history: %s', history.history)
    elif mode == 'eval':
      logging.info('Evaluation started')
      validation_output = model.evaluate(validation_dataset, steps=eval_steps)
      logging.info('Evaluation output: %s', validation_output)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
def run_keras_compile_fit(model_dir, strategy, model_fn, train_input_fn,
                          eval_input_fn, loss_fn, metric_fn, init_checkpoint,
                          epochs, steps_per_epoch, steps_per_loop, eval_steps,
                          monitor='val_loss', training_callbacks=True,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API.

  Args:
    model_dir: Directory for summaries and checkpoints.
    strategy: `tf.distribute.Strategy` under which model and datasets are
      built.
    model_fn: Callable returning `(bert_model, sub_model)`.
    train_input_fn: Callable returning the training dataset.
    eval_input_fn: Optional callable returning the evaluation dataset.
    loss_fn: Loss passed to `Model.compile`.
    metric_fn: A metric factory or list/tuple of factories; each is called to
      instantiate the compile-time metrics. May be falsy for no metrics.
    init_checkpoint: Optional checkpoint path used to warm-start `sub_model`.
    epochs: Number of epochs for `fit`.
    steps_per_epoch: Steps per epoch for `fit`.
    steps_per_loop: Unused; kept for interface compatibility. Forwarding it
      to `experimental_steps_per_execution` previously caused init_value
      errors during training, so the model is compiled without it.
    eval_steps: `validation_steps` for `fit`.
    monitor: Metric name watched by the best-model checkpoint; metrics whose
      name contains 'loss' or 'error' are minimized, others maximized.
    training_callbacks: Whether to append the summary/checkpoint callbacks.
    custom_callbacks: Optional list of extra callbacks; extended in place
      when `training_callbacks` is true.

  Returns:
    `(bert_model, stats)` where stats holds 'total_training_steps' and, when
    present in the fit history, 'train_loss' and 'eval_metrics'
    (final val_accuracy).
  """
  del steps_per_loop  # Intentionally not forwarded to compile(); see Args.
  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    # Warm-start the sub model from a pretrained checkpoint, if given.
    if init_checkpoint:
      # Lazy %-args: the message is rendered only if INFO logging is enabled.
      logging.info('Restore from %s', init_checkpoint)
      checkpoint = tf.train.Checkpoint(model=sub_model)
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

    # Normalize a single metric factory to a list before instantiating.
    if metric_fn and not isinstance(metric_fn, (list, tuple)):
      metric_fn = [metric_fn]
    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[fn() for fn in metric_fn] if metric_fn else None)

    summary_dir = os.path.join(model_dir, 'summaries')
    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=3,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    # Track the single best model (by `monitor`) under <model_dir>/best/model.
    best_dir = os.path.join(model_dir, 'best/model')
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=best_dir,
        save_weights_only=True,
        monitor=monitor,
        mode='min' if ('loss' in monitor or 'error' in monitor) else 'max',
        save_best_only=True)

    if training_callbacks:
      default_callbacks = [
          summary_callback, checkpoint_callback, model_checkpoint_callback
      ]
      if custom_callbacks is not None:
        # Extend the caller's list in place.
        custom_callbacks += default_callbacks
      else:
        custom_callbacks = default_callbacks

    logging.info('start to train')
    history = bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=custom_callbacks)

    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in history.history:
      stats['train_loss'] = history.history['loss'][-1]
    if 'val_accuracy' in history.history:
      stats['eval_metrics'] = history.history['val_accuracy'][-1]
  return bert_model, stats