Example #1
0
def train(params, strategy, dataset=None):
    """Trains the model under the given distribution strategy.

    Builds the input pipeline when no dataset is supplied, constructs the
    model/optimizer/trainer inside the strategy scope, restores from any
    existing checkpoint, then runs Keras ``fit`` with TensorBoard and
    checkpoint callbacks.

    Args:
        params: Hyperparameters passed through to the model, optimizer and
            input pipeline factories.
        strategy: A `tf.distribute.Strategy` the training runs under.
        dataset: Optional pre-built training dataset; built from FLAGS when
            not provided.

    Returns:
        A dict with the final training loss under key ``training_loss``.
    """
    if not dataset:
        dataset = input_pipeline.get_input_dataset(FLAGS.train_file_pattern,
                                                   FLAGS.train_batch_size,
                                                   params,
                                                   is_training=True,
                                                   strategy=strategy)

    with strategy.scope():
        keras_model = models.create_model(FLAGS.model_type,
                                          params,
                                          init_checkpoint=FLAGS.init_checkpoint)
        optim = optimizer.create_optimizer(params)
        model_trainer = Trainer(keras_model, params)
        model_trainer.compile(
            optimizer=optim,
            experimental_steps_per_execution=FLAGS.steps_per_loop)

        tensorboard_cb = tf.keras.callbacks.TensorBoard(
            os.path.join(FLAGS.model_dir, "summaries"),
            update_freq=max(100, FLAGS.steps_per_loop))
        ckpt = tf.train.Checkpoint(model=keras_model,
                                   optimizer=optim,
                                   global_step=optim.iterations)
        ckpt_manager = tf.train.CheckpointManager(
            ckpt,
            directory=FLAGS.model_dir,
            max_to_keep=10,
            step_counter=optim.iterations,
            checkpoint_interval=FLAGS.checkpoint_interval)
        if ckpt_manager.restore_or_initialize():
            logging.info("Training restored from the checkpoints in: %s",
                         FLAGS.model_dir)
        ckpt_cb = keras_utils.SimpleCheckpoint(ckpt_manager)

    # An "epoch" here spans at most one checkpoint interval, so checkpoints
    # get a chance to be written between epochs.
    steps_per_epoch = min(FLAGS.train_steps, FLAGS.checkpoint_interval)
    epochs = FLAGS.train_steps // steps_per_epoch
    fit_history = model_trainer.fit(x=dataset,
                                    steps_per_epoch=steps_per_epoch,
                                    epochs=epochs,
                                    callbacks=[tensorboard_cb, ckpt_cb],
                                    verbose=2)
    # Report the loss recorded for the final epoch.
    return dict(training_loss=float(fit_history.history["training_loss"][-1]))
Example #2
0
def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          training_callbacks=True,
                          custom_callbacks=None):
    """Runs BERT classifier model using Keras compile/fit API.

    Args:
        model_dir: Directory for summaries and checkpoints.
        strategy: `tf.distribute.Strategy` to run under.
        model_fn: Zero-arg callable returning `(bert_model, sub_model)`.
        train_input_fn: Zero-arg callable returning the training dataset.
        eval_input_fn: Optional zero-arg callable returning the eval dataset.
        loss_fn: Loss passed to `compile`.
        metric_fn: A metric factory, a list/tuple of factories, or None.
        init_checkpoint: Optional checkpoint path to warm-start `sub_model`.
        epochs: Number of epochs to fit.
        steps_per_epoch: Steps per epoch.
        steps_per_loop: Steps per host training loop.
        eval_steps: Validation steps per epoch.
        training_callbacks: When True, append summary/checkpoint callbacks.
        custom_callbacks: Optional extra callbacks; never mutated in place.

    Returns:
        Tuple of `(bert_model, stats)` where `stats` holds final metrics.
    """
    with strategy.scope():
        training_dataset = train_input_fn()
        evaluation_dataset = eval_input_fn() if eval_input_fn else None
        bert_model, sub_model = model_fn()
        optimizer = bert_model.optimizer

        if init_checkpoint:
            # Restore only the sub-model (encoder) weights; the task head
            # keeps its fresh initialization.
            checkpoint = tf.train.Checkpoint(model=sub_model)
            checkpoint.restore(
                init_checkpoint).assert_existing_objects_matched()

        # Normalize metric_fn to a list; tolerate metric_fn=None instead of
        # crashing on `[None]` in the comprehension below.
        if metric_fn and not isinstance(metric_fn, (list, tuple)):
            metric_fn = [metric_fn]
        bert_model.compile(
            optimizer=optimizer,
            loss=loss_fn,
            metrics=[fn() for fn in metric_fn] if metric_fn else None,
            experimental_steps_per_execution=steps_per_loop)

        summary_dir = os.path.join(model_dir, 'summaries')
        summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
        checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            directory=model_dir,
            max_to_keep=None,
            step_counter=optimizer.iterations,
            checkpoint_interval=0)
        checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

        if training_callbacks:
            # Build a new list instead of `+=` so the caller's list object is
            # never mutated as a side effect.
            builtin_callbacks = [summary_callback, checkpoint_callback]
            if custom_callbacks is not None:
                custom_callbacks = list(custom_callbacks) + builtin_callbacks
            else:
                custom_callbacks = builtin_callbacks

        history = bert_model.fit(x=training_dataset,
                                 validation_data=evaluation_dataset,
                                 steps_per_epoch=steps_per_epoch,
                                 epochs=epochs,
                                 validation_steps=eval_steps,
                                 callbacks=custom_callbacks)
        stats = {'total_training_steps': steps_per_epoch * epochs}
        if 'loss' in history.history:
            stats['train_loss'] = history.history['loss'][-1]
        if 'val_accuracy' in history.history:
            stats['eval_metrics'] = history.history['val_accuracy'][-1]
        return bert_model, stats
Example #3
0
def main(_) -> None:
  """Train and evaluate the Ranking model.

  Parses the experiment configuration from FLAGS, builds the task, model
  and distribution strategy, then either runs an Orbit experiment
  (``params.trainer.use_orbit``) or the Keras compile/fit path with
  checkpointing, timing and optional TensorBoard callbacks.
  """
  params = train_utils.parse_configuration(FLAGS)
  mode = FLAGS.mode
  model_dir = FLAGS.model_dir
  # Use the local `mode` binding consistently (it equals FLAGS.mode).
  if 'train' in mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  if FLAGS.seed is not None:
    logging.info('Setting tf seed.')
    tf.random.set_seed(FLAGS.seed)

  task = RankingTask(
      params=params.task,
      optimizer_config=params.trainer.optimizer_config,
      logging_dir=model_dir,
      steps_per_execution=params.trainer.steps_per_loop,
      name='RankingTask')

  enable_tensorboard = params.trainer.callbacks.enable_tensorboard

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  with strategy.scope():
    model = task.build_model()

  def get_dataset_fn(params):
    # Bind `params` so each call site gets its own dataset factory.
    return lambda input_context: task.build_inputs(params, input_context)

  train_dataset = None
  if 'train' in mode:
    train_dataset = strategy.distribute_datasets_from_function(
        get_dataset_fn(params.task.train_data),
        options=tf.distribute.InputOptions(experimental_fetch_to_device=False))

  validation_dataset = None
  if 'eval' in mode:
    validation_dataset = strategy.distribute_datasets_from_function(
        get_dataset_fn(params.task.validation_data),
        options=tf.distribute.InputOptions(experimental_fetch_to_device=False))

  if params.trainer.use_orbit:
    with strategy.scope():
      checkpoint_exporter = train_utils.maybe_create_best_ckpt_exporter(
          params, model_dir)
      trainer = RankingTrainer(
          config=params,
          task=task,
          model=model,
          optimizer=model.optimizer,
          train='train' in mode,
          evaluate='eval' in mode,
          train_dataset=train_dataset,
          validation_dataset=validation_dataset,
          checkpoint_exporter=checkpoint_exporter)

    train_lib.run_experiment(
        distribution_strategy=strategy,
        task=task,
        mode=mode,
        params=params,
        model_dir=model_dir,
        trainer=trainer)

  else:  # Compile/fit
    checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)

    # Resume from the latest checkpoint if one exists in model_dir.
    latest_checkpoint = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint:
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)

    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=model.optimizer.iterations,
        checkpoint_interval=params.trainer.checkpoint_interval)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    time_callback = keras_utils.TimeHistory(
        params.task.train_data.global_batch_size,
        params.trainer.time_history.log_steps,
        logdir=model_dir if enable_tensorboard else None)
    callbacks = [checkpoint_callback, time_callback]

    if enable_tensorboard:
      tensorboard_callback = tf.keras.callbacks.TensorBoard(
          log_dir=model_dir,
          update_freq=min(1000, params.trainer.validation_interval),
          profile_batch=FLAGS.profile_steps)
      callbacks.append(tensorboard_callback)

    # One Keras "epoch" is one validation interval; resume mid-run by
    # deriving the initial epoch from the optimizer's step counter.
    num_epochs = (params.trainer.train_steps //
                  params.trainer.validation_interval)
    current_step = model.optimizer.iterations.numpy()
    initial_epoch = current_step // params.trainer.validation_interval

    eval_steps = params.trainer.validation_steps if 'eval' in mode else None

    if mode in ('train', 'train_and_eval'):
      logging.info('Training started')
      history = model.fit(
          train_dataset,
          initial_epoch=initial_epoch,
          epochs=num_epochs,
          steps_per_epoch=params.trainer.validation_interval,
          validation_data=validation_dataset,
          validation_steps=eval_steps,
          callbacks=callbacks,
      )
      model.summary()
      logging.info('Train history: %s', history.history)
    elif mode == 'eval':
      logging.info('Evaluation started')
      validation_output = model.evaluate(validation_dataset, steps=eval_steps)
      logging.info('Evaluation output: %s', validation_output)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
Example #4
0
def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          monitor='val_loss',
                          training_callbacks=True,
                          custom_callbacks=None):
    """Runs BERT classifier model using Keras compile/fit API.

    Args:
        model_dir: Directory for summaries, checkpoints and the best model.
        strategy: `tf.distribute.Strategy` to run under.
        model_fn: Zero-arg callable returning `(bert_model, sub_model)`.
        train_input_fn: Zero-arg callable returning the training dataset.
        eval_input_fn: Optional zero-arg callable returning the eval dataset.
        loss_fn: Loss passed to `compile`.
        metric_fn: A metric factory, a list/tuple of factories, or None.
        init_checkpoint: Optional checkpoint path to warm-start `sub_model`.
        epochs: Number of epochs to fit.
        steps_per_epoch: Steps per epoch.
        steps_per_loop: Steps per host training loop (currently unused; see
            the NOTE at the compile call).
        eval_steps: Validation steps per epoch.
        monitor: Metric name watched by the best-model checkpoint.
        training_callbacks: When True, append the built-in callbacks.
        custom_callbacks: Optional extra callbacks; never mutated in place.

    Returns:
        Tuple of `(bert_model, stats)` where `stats` holds final metrics.
    """
    with strategy.scope():
        training_dataset = train_input_fn()
        evaluation_dataset = eval_input_fn() if eval_input_fn else None
        bert_model, sub_model = model_fn()

        optimizer = bert_model.optimizer

        if init_checkpoint:
            logging.info('Restore from {}'.format(init_checkpoint))
            # Restore only the sub-model (encoder) weights; the task head
            # keeps its fresh initialization.
            checkpoint = tf.train.Checkpoint(model=sub_model)
            checkpoint.restore(
                init_checkpoint).assert_existing_objects_matched()

        if metric_fn and not isinstance(metric_fn, (list, tuple)):
            metric_fn = [metric_fn]
        # NOTE: `experimental_steps_per_execution=steps_per_loop` is
        # intentionally omitted here — setting it caused a confusing
        # "init_value" error during training that initially looked like a
        # model bug.
        bert_model.compile(
            optimizer=optimizer,
            loss=loss_fn,
            metrics=[fn() for fn in metric_fn] if metric_fn else None)

        summary_dir = os.path.join(model_dir, 'summaries')
        summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
        checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            directory=model_dir,
            max_to_keep=3,
            step_counter=optimizer.iterations,
            checkpoint_interval=0)
        checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

        # Save the best model (by `monitor`) as weights-only under best/.
        best_dir = os.path.join(model_dir, 'best/model')
        model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=best_dir,
            save_weights_only=True,
            monitor=monitor,
            mode='min' if ('loss' in monitor or 'error' in monitor) else 'max',
            save_best_only=True)

        if training_callbacks:
            # Build a new list instead of `+=` so the caller's list object is
            # never mutated as a side effect.
            builtin_callbacks = [
                summary_callback, checkpoint_callback,
                model_checkpoint_callback
            ]
            if custom_callbacks is not None:
                custom_callbacks = list(custom_callbacks) + builtin_callbacks
            else:
                custom_callbacks = builtin_callbacks
        logging.info('start to train')
        history = bert_model.fit(x=training_dataset,
                                 validation_data=evaluation_dataset,
                                 steps_per_epoch=steps_per_epoch,
                                 epochs=epochs,
                                 validation_steps=eval_steps,
                                 callbacks=custom_callbacks)
        stats = {'total_training_steps': steps_per_epoch * epochs}
        if 'loss' in history.history:
            stats['train_loss'] = history.history['loss'][-1]
        if 'val_accuracy' in history.history:
            stats['eval_metrics'] = history.history['val_accuracy'][-1]
        return bert_model, stats