Example no. 1
    def run_classifier(self, model, epochs, train_ds, train_steps,
                       validation_ds, validation_steps, **kwargs):
        initial_lr = 0.1
        total_steps = train_steps * epochs
        warmup_steps = int(epochs * train_steps * 0.1)
        optimizer = optimization.create_optimizer(initial_lr, total_steps,
                                                  warmup_steps)
        model.compile(optimizer=optimizer,
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

        hist = model.fit(train_ds,
                         steps_per_epoch=train_steps,
                         validation_data=validation_ds,
                         validation_steps=validation_steps,
                         epochs=epochs,
                         **kwargs)
        return hist
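
Most of the examples in this collection compute the schedule for optimization.create_optimizer (from the TensorFlow Model Garden package, official.nlp.optimization) the same way: total steps = steps_per_epoch * epochs, with roughly the first 10% used for linear warmup. A minimal standalone sketch of that convention follows; the dataset size, batch size, and learning rate are hypothetical and only illustrate the step math.

import tensorflow as tf
from official.nlp import optimization  # provided by tf-models-official

# Hypothetical values, used only to show how the schedule is derived.
batch_size = 32
train_examples = 10000
epochs = 3

steps_per_epoch = train_examples // batch_size
total_steps = steps_per_epoch * epochs
warmup_steps = int(0.1 * total_steps)  # warm up over the first ~10% of steps

optimizer = optimization.create_optimizer(
    init_lr=3e-5,
    num_train_steps=total_steps,
    num_warmup_steps=warmup_steps,
    optimizer_type='adamw')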
Example no. 2
 def _get_classifier_model():
   """Gets a classifier model."""
   classifier_model, core_model = (
       bert_models.classifier_model(
           bert_config,
           num_classes,
           max_seq_length))
   classifier_model.optimizer = optimization.create_optimizer(
       initial_lr, steps_per_epoch * epochs, warmup_steps)
   if FLAGS.fp16_implementation == 'graph_rewrite':
     # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
     # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
     # which will ensure tf.compat.v2.keras.mixed_precision and
     # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
     # up.
     classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
         classifier_model.optimizer)
   return classifier_model, core_model
Example no. 3
 def _get_classifier_model():
   """Gets a classifier model."""
   classifier_model, core_model = (
       bert_models.classifier_model(
           bert_config,
           num_classes,
           max_seq_length,
           hub_module_url=FLAGS.hub_module_url,
           hub_module_trainable=FLAGS.hub_module_trainable))
   optimizer = optimization.create_optimizer(initial_lr,
                                             steps_per_epoch * epochs,
                                             warmup_steps, FLAGS.end_lr,
                                             FLAGS.optimizer_type)
   classifier_model.optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=common_flags.use_float16(),
       use_graph_rewrite=common_flags.use_graph_rewrite())
   return classifier_model, core_model
Example no. 4
  def _get_classifier_model():
    """Gets a classifier model."""
    bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
    classifier_model, core_model = (
        bert_models.classifier_model(
            bert_config,
            num_classes,
            input_meta_data['max_seq_length'],
            hub_module_url=FLAGS.hub_module_url,
            hub_module_trainable=False))

    epochs = FLAGS.num_train_epochs
    # Hard-coded schedule (presumably ~70,000 training examples and batch size 32);
    # note it fixes 2 epochs regardless of FLAGS.num_train_epochs.
    steps_per_epoch = int(70000 / 32)
    warmup_steps = int(2 * 70000 * 0.1 / 32)  # 10% of the two-epoch total

    classifier_model.optimizer = optimization.create_optimizer(
        2e-5, steps_per_epoch * 2, warmup_steps)
    return classifier_model, core_model
Example no. 5
def bert(train_ds: tf.data.Dataset, epochs: int, no_classes) -> tf.keras.Model:
    """
    Build a BERT text classification model.

    :param tf.data.Dataset train_ds: training dataset
    :param int epochs: number of training epochs
    :param int no_classes: number of classes / output layer size
    :return: model object
    :rtype: tf.keras.Model
    """
    tfhub_handle_encoder = "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1"
    tfhub_handle_preprocess = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess,
                                         name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder,
                             trainable=True,
                             name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    x = tf.keras.layers.Dense(512, activation='relu')(net)
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    output = tf.keras.layers.Dense(no_classes,
                                   activation='sigmoid',
                                   name="output0")(x)
    model = tf.keras.Model(text_input, output)
    loss = tf.keras.losses.BinaryCrossentropy()
    metrics = tf.metrics.BinaryAccuracy()
    steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
    num_train_steps = steps_per_epoch * epochs
    num_warmup_steps = int(0.1 * num_train_steps)

    init_lr = 3e-5
    optimizer = optimization.create_optimizer(
        init_lr=init_lr,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        optimizer_type='adamw')
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    model.summary()
    return model
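
A hedged usage sketch for bert() above, assuming the imports it relies on (tensorflow as tf, tensorflow_hub as hub, official.nlp.optimization) are in scope; the texts and multi-hot labels are purely illustrative and only show the expected shapes (scalar strings in, no_classes-wide targets out).

# Hypothetical data, illustrative only.
texts = tf.constant(["great product", "terrible service", "okay experience", "loved it"])
labels = tf.constant([[1., 0.], [0., 1.], [0., 1.], [1., 0.]])  # no_classes = 2
train_ds = tf.data.Dataset.from_tensor_slices((texts, labels)).batch(2)

epochs = 3
model = bert(train_ds, epochs=epochs, no_classes=2)
model.fit(train_ds, epochs=epochs)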
Example no. 6
  def run_classifier(self, train_ds, validation_ds, epochs, steps_per_epoch,
                     num_classes, **kwargs):
    """Creates classifier and runs the classifier training.

    Args:
      train_ds: tf.data.Dataset, training data to be fed in
        tf.keras.Model.fit().
      validation_ds: tf.data.Dataset, validation data to be fed in
        tf.keras.Model.fit().
      epochs: Integer, training epochs.
      steps_per_epoch: Integer or None. Total number of steps (batches of
        samples) before declaring one epoch finished and starting the next
        epoch. If `steps_per_epoch` is None, the epoch will run until the input
        dataset is exhausted.
      num_classes: Integer, number of classes.
      **kwargs: Other parameters used in the tf.keras.Model.fit().

    Returns:
      tf.keras.Model, the keras model that's already trained.
    """
    if steps_per_epoch is None:
      logging.info(
          'steps_per_epoch is None, use %d as the estimated steps_per_epoch',
          model_util.ESTIMITED_STEPS_PER_EPOCH)
      steps_per_epoch = model_util.ESTIMITED_STEPS_PER_EPOCH
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(total_steps * 0.1)
    initial_lr = self.learning_rate

    with distribute_utils.get_strategy_scope(self.strategy):
      optimizer = optimization.create_optimizer(initial_lr, total_steps,
                                                warmup_steps)
      bert_model = self.create_model(num_classes, optimizer)

    for i in range(epochs):
      bert_model.fit(
          x=train_ds,
          initial_epoch=i,
          epochs=i + 1,
          validation_data=validation_ds,
          **kwargs)

    return bert_model
Example no. 7
def train_evaluate(hparams):
    """Train and evaluate TensorFlow BERT sentiment classifier.
    Args:
      hparams(dict): A dictionary containing model training arguments.
    Returns:
      history(tf.keras.callbacks.History): Keras callback that records training event history.
    """
    dataset_dir = download_data(data_url=DATA_URL, 
                                local_data_dir=LOCAL_DATA_DIR)
    
    raw_train_ds, raw_val_ds, raw_test_ds = load_datasets(dataset_dir=dataset_dir,
                                                          hparams=hparams)
    
    train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)
    val_ds = raw_val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    test_ds = raw_test_ds.cache().prefetch(buffer_size=AUTOTUNE)     
    
    epochs = hparams['epochs']
    steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
    n_train_steps = steps_per_epoch * epochs
    n_warmup_steps = int(0.1 * n_train_steps)    
    
    optimizer = optimization.create_optimizer(init_lr=hparams['initial-learning-rate'],
                                              num_train_steps=n_train_steps,
                                              num_warmup_steps=n_warmup_steps,
                                              optimizer_type='adamw')    
    
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = build_text_classifier(hparams=hparams, optimizer=optimizer)
        logging.info(model.summary())
        
    history = model.fit(x=train_ds,
                        validation_data=val_ds,
                        epochs=epochs)  
    
    logging.info("Test accuracy: %s", model.evaluate(test_ds))

    # Export Keras model in TensorFlow SavedModel format.
    model.save(hparams['model-dir'])
    
    return history
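
The hparams dictionary consumed by train_evaluate uses at least the keys referenced in its body; a hypothetical call could look like the following (values are illustrative, and load_datasets/build_text_classifier may read further keys not visible in this excerpt).

hparams = {
    'epochs': 3,
    'initial-learning-rate': 3e-5,
    'model-dir': './bert-sentiment-model',
}
history = train_evaluate(hparams)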
Example no. 8
    def get_optimizer(optimizer: str, learning_rate: float, **kwargs):
        """

        :param optimizer:
        :param learning_rate:
        :param kwargs: any additional argument which might be needed
         two options includes: num_train_steps, num_warmup_steps for adamw optimizer
        :return:
        """
        if optimizer == 'adam':
            optimizer = Adam(learning_rate=learning_rate)
        elif optimizer == 'adamw':
            num_train_steps = kwargs['num_train_steps']
            num_warmup_steps = kwargs['num_warmup_steps']
            optimizer = optimization.create_optimizer(
                init_lr=learning_rate,
                num_train_steps=num_train_steps,
                num_warmup_steps=num_warmup_steps,
                optimizer_type='adamw')
        return optimizer
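
A short usage sketch for get_optimizer (assuming it is exposed as a static helper); the 'adamw' branch needs num_train_steps and num_warmup_steps passed through **kwargs, while plain 'adam' only needs a learning rate. The numbers are illustrative.

adam_opt = get_optimizer('adam', learning_rate=1e-4)
adamw_opt = get_optimizer('adamw', learning_rate=3e-5,
                          num_train_steps=10000, num_warmup_steps=1000)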
Example no. 9
def _run_model_with_strategy(strategy, config, bert_config, dataset_fn):
    dataset_iterator = iter(
        strategy.experimental_distribute_datasets_from_function(dataset_fn))
    batch_size = int(config["batch_size"] / strategy.num_replicas_in_sync)
    with strategy.scope():
        model = nqg_model.Model(batch_size,
                                config,
                                bert_config,
                                training=True,
                                verbose=False)
        optimizer = optimization.create_optimizer(config["learning_rate"],
                                                  config["training_steps"],
                                                  config["warmup_steps"])
        train_for_n_steps_fn = training_utils.get_train_for_n_steps_fn(
            strategy, optimizer, model)
        mean_loss = train_for_n_steps_fn(
            dataset_iterator,
            tf.convert_to_tensor(config["steps_per_iteration"],
                                 dtype=tf.int32))
        return mean_loss
Example no. 10
    def train(self, train_input_fn, epochs, steps_per_epoch):
        """Run bert QA training."""
        warmup_steps = int(epochs * steps_per_epoch * 0.1)

        def _loss_fn(positions, logits):
            """Get losss function for QA model."""
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                positions, logits, from_logits=True)
            return tf.reduce_mean(loss)

        with distribute_utils.get_strategy_scope(self.strategy):
            training_dataset = train_input_fn()
            bert_model = self.create_model()
            optimizer = optimization.create_optimizer(self.learning_rate,
                                                      steps_per_epoch * epochs,
                                                      warmup_steps)

            bert_model.compile(optimizer=optimizer,
                               loss=_loss_fn,
                               loss_weights=[0.5, 0.5])

        summary_dir = os.path.join(self.model_dir, 'summaries')
        summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
        checkpoint_path = os.path.join(self.model_dir, 'checkpoint')
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, save_weights_only=True)

        if not bert_model.trainable_variables:
            tf.compat.v1.logging.warning(
                'Trainable variables in the model are empty.')
            return bert_model

        bert_model.fit(x=training_dataset,
                       steps_per_epoch=steps_per_epoch,
                       epochs=epochs,
                       callbacks=[summary_callback, checkpoint_callback])

        return bert_model
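
In the trainer above, _loss_fn is applied to each of the model's outputs (presumably the start- and end-position logits of the QA head) and loss_weights=[0.5, 0.5] averages the two. Below is a minimal sketch of the same Keras multi-output compile pattern with a toy stand-in model; the layer names and sizes are hypothetical.

import tensorflow as tf

def _loss_fn(positions, logits):
    # Same loss as in the trainer above.
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        positions, logits, from_logits=True)
    return tf.reduce_mean(loss)

# Toy two-output model standing in for the BERT QA model.
inputs = tf.keras.Input(shape=(16,))
start_logits = tf.keras.layers.Dense(16, name='start')(inputs)
end_logits = tf.keras.layers.Dense(16, name='end')(inputs)
toy_model = tf.keras.Model(inputs, [start_logits, end_logits])

# The single loss is applied to each output; loss_weights averages them.
toy_model.compile(optimizer='adam', loss=_loss_fn, loss_weights=[0.5, 0.5])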
Example no. 11
 def _get_squad_model():
     """Get Squad model and optimizer."""
     squad_model, core_model = bert_models.squad_model(
         bert_config, max_seq_length, hub_module_url=FLAGS.hub_module_url)
     squad_model.optimizer = optimization.create_optimizer(
         FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
     if use_float16:
         # Wraps optimizer with a LossScaleOptimizer. This is done automatically
         # in compile() with the "mixed_float16" policy, but since we do not call
         # compile(), we must wrap the optimizer manually.
         squad_model.optimizer = (
             tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                 squad_model.optimizer,
                 loss_scale=common_flags.get_loss_scale()))
     if FLAGS.fp16_implementation == 'graph_rewrite':
         # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
         # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
         # which will ensure tf.compat.v2.keras.mixed_precision and
         # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
         # up.
         squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
             squad_model.optimizer)
     return squad_model, core_model
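
The comment above notes that compile() with the "mixed_float16" policy wraps the optimizer in a LossScaleOptimizer automatically, so the manual wrap is only needed when compile() is not called. A small sketch of both routes using the current (non-experimental) Keras mixed-precision API, offered as an assumed equivalent of the older experimental calls used in the example:

import tensorflow as tf

# Route 1: set the global policy and let compile() handle loss scaling.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Route 2: wrap the optimizer by hand when compile() is not used.
optimizer = tf.keras.optimizers.Adam(1e-4)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)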
Example no. 12
            y_train_c = tf.keras.utils.to_categorical(y_train)
            y_test_c = tf.keras.utils.to_categorical(y_test)

            # BERT model
            clf = build_classifier_model()
            loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
            metrics = tf.metrics.CategoricalAccuracy()
            epochs = 5
            # Steps per epoch heuristically set to sqrt(n_samples); cast to int
            # so the optimizer's decay schedule gets an integer step count.
            steps_per_epoch = int(np.sqrt(X.shape[0]))
            num_train_steps = steps_per_epoch * epochs
            num_warmup_steps = int(0.1*num_train_steps)

            init_lr = 3e-5
            optimizer = optimization.create_optimizer(
                init_lr=init_lr,
                num_train_steps=num_train_steps,
                num_warmup_steps=num_warmup_steps,
                optimizer_type='adamw'
            )

            clf.compile(optimizer=optimizer, loss=loss, metrics=metrics)

            history = clf.fit(
                x=X_train,
                y = y_train_c,
                # validation_data=None,
                epochs=epochs
            )

            proba = clf.predict(X_test)
            # pred = np.argmax(proba, axis=1)
            # score = balanced_accuracy_score(y_test, pred)
Example no. 13
def main(args: argparse.Namespace) -> None:
    """Main entrypoint for the script."""
    set_seed(args.seed)

    # Create output directory
    args.output_dir.mkdir(parents=True, exist_ok=True)
    run_name = args.run_name or args.dataset_directory.stem
    logdir = args.output_dir / get_next_run_id(args.output_dir, run_name)

    # Load datasets
    train_dataset, validation_dataset, test_dataset, num_classes = load_dataset(
        args.dataset_directory,
        validation_split=args.validation_split,
        batch_size=args.batch_size,
        train_folder_name=args.train_folder_name,
        test_folder_name=args.test_folder_name
    )

    # Build the model
    model = build_classifier_model(
        num_classes,
        dropout_rate=args.dropout_rate,
        bert_preprocess_handle=args.bert_preprocess_handle,
        bert_model_handle=args.bert_model_handle
    )

    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metrics = tf.metrics.SparseCategoricalAccuracy()

    steps_per_epoch = tf.data.experimental.cardinality(train_dataset).numpy()
    total_train_steps = steps_per_epoch * args.epochs
    warmup_steps = int(0.1 * total_train_steps)

    # Load the optimizer
    optimizer = optimization.create_optimizer(
        init_lr=args.initial_lr,
        num_train_steps=total_train_steps,
        num_warmup_steps=warmup_steps,
        optimizer_type='adamw'
    )

    # Compile the model with the optimizer, loss, and metrics
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # Create training callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=logdir / 'model.{epoch:02d}-{val_loss:.2f}.h5',
        monitor='val_loss'
    )

    # Train the model
    logger.info('Starting training with {} (for {} epochs).'.format(
        args.bert_model_handle, args.epochs
    ))

    model.fit(
        x=train_dataset,
        validation_data=validation_dataset,
        epochs=args.epochs,
        callbacks=[
            tensorboard_callback,
            checkpoint_callback
        ]
    )

    # Evaluate the model
    logger.info('Evaluating the model on the testing dataset')
    loss, accuracy = model.evaluate(test_dataset)
    logger.info('Loss: {}'.format(loss))
    logger.info('Accuracy: {}'.format(accuracy))

    # Save final model
    model.save(logdir / 'model_final')
Example no. 14
    # (This excerpt begins mid-expression: the opening of the
    # tf.keras.Model(inputs={...}, ...) construction was truncated.)
        'input_mask': input_mask,
        'input_type_ids': segment_ids
    },
                           outputs=output)

    epochs = 3
    batch_size = 4
    eval_batch_size = batch_size

    train_data_size = len(y_train)
    steps_per_epoch = int(train_data_size / batch_size)
    num_train_steps = steps_per_epoch * epochs
    warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)

    optimizer = opti.create_optimizer(1e-3,
                                      num_train_steps=num_train_steps,
                                      num_warmup_steps=warmup_steps)

    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()
    history = model.fit(X_train,
                        y_train,
                        epochs=epochs,
                        batch_size=batch_size,
                        validation_data=(X_val, y_val),
                        verbose=1)
    model.save(os.path.join(".", "offensive_model_final"))
    print(model.evaluate(X_test, y_test))
Example no. 15
def run_train(bert_config,
              seq_length,
              max_predictions_per_seq,
              model_dir,
              epochs,
              initial_lr,
              warmup_steps,
              loss_scale,
              train_file,
              eval_file,
              train_batch_size,
              eval_batch_size,
              use_insertion=True,
              use_pointing=True,
              pointing_weight=1.0,
              mini_epochs_per_epoch=1):
    """Runs BERT pre-training using Keras `fit()` API."""

    mini_epochs_per_epoch = max(1, mini_epochs_per_epoch)

    if use_insertion:
        pretrain_model, bert_encoder = felix_models.get_insertion_model(
            bert_config, seq_length, max_predictions_per_seq)
    else:
        pretrain_model, bert_encoder = felix_models.get_tagging_model(
            bert_config,
            seq_length,
            use_pointing=use_pointing,
            pointing_weight=pointing_weight)
    # The original BERT model does not scale the loss by 1/num_replicas_in_sync.
    # It could be an accident. So, in order to use the same hyper parameter,
    # we do the same thing here.
    loss_fn = _get_loss_fn(loss_scale=loss_scale)

    steps_per_mini_epoch = int(FLAGS.num_train_examples / train_batch_size /
                               mini_epochs_per_epoch)
    eval_steps = max(1, int(FLAGS.num_eval_examples / eval_batch_size))

    optimizer = optimization.create_optimizer(
        init_lr=initial_lr,
        num_train_steps=steps_per_mini_epoch * mini_epochs_per_epoch * epochs,
        num_warmup_steps=warmup_steps)

    pretrain_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        experimental_steps_per_execution=FLAGS.steps_per_loop)
    train_dataset = _get_input_data_fn(
        train_file,
        seq_length,
        max_predictions_per_seq,
        train_batch_size,
        is_training=True,
        use_insertion=use_insertion,
        use_pointing=use_pointing,
        use_weighted_labels=FLAGS.use_weighted_labels)
    eval_dataset = _get_input_data_fn(
        eval_file,
        seq_length,
        max_predictions_per_seq,
        eval_batch_size,
        is_training=False,
        use_insertion=use_insertion,
        use_pointing=use_pointing,
        use_weighted_labels=FLAGS.use_weighted_labels)

    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file is not None:
        checkpoint = tf.train.Checkpoint(model=pretrain_model,
                                         optimizer=optimizer)
        # Since some model components(e.g. optimizer slot variables)
        # are loaded lazily for some components, we do not add any asserts
        # before model.call() is invoked.
        checkpoint.restore(latest_checkpoint_file)
        checkpoint_iteration = tf.keras.backend.get_value(
            pretrain_model.optimizer.iterations)
        current_mini_epoch = checkpoint_iteration // steps_per_mini_epoch
    else:
        # No latest checkpoint found so load a pre-trained checkpoint.
        if FLAGS.init_checkpoint:
            if _CHECKPOINT_FILE_NAME not in FLAGS.init_checkpoint:
                logging.info('Initializing from a BERT checkpoint...')
                checkpoint = tf.train.Checkpoint(model=bert_encoder)
                checkpoint.restore(
                    FLAGS.init_checkpoint).assert_existing_objects_matched()
            else:
                logging.info('Initializing from a Felix checkpoint...')
                # Initialize from a previously trained checkpoint.
                checkpoint = tf.train.Checkpoint(model=pretrain_model)
                checkpoint.restore(
                    FLAGS.init_checkpoint).assert_existing_objects_matched()
                # Reset the iteration number to have the learning rate adapt correctly.
                tf.keras.backend.set_value(pretrain_model.optimizer.iterations,
                                           0)

        checkpoint = tf.train.Checkpoint(model=pretrain_model,
                                         optimizer=optimizer)
        checkpoint_iteration = 0
        current_mini_epoch = 0

    logging.info('Starting training from iteration: %s.', checkpoint_iteration)
    summary_dir = os.path.join(model_dir, 'summaries')
    summary_cb = tf.keras.callbacks.TensorBoard(summary_dir, update_freq=1000)

    manager = tf.train.CheckpointManager(checkpoint,
                                         directory=model_dir,
                                         max_to_keep=FLAGS.keep_checkpoint_max,
                                         checkpoint_name=_CHECKPOINT_FILE_NAME)
    checkpoint_cb = CheckPointSaver(manager, current_mini_epoch)
    time_history_cb = keras_utils.TimeHistory(FLAGS.train_batch_size,
                                              FLAGS.log_steps)
    training_callbacks = [summary_cb, checkpoint_cb, time_history_cb]
    pretrain_model.fit(train_dataset,
                       initial_epoch=current_mini_epoch,
                       epochs=mini_epochs_per_epoch * epochs,
                       verbose=1,
                       steps_per_epoch=steps_per_mini_epoch,
                       validation_data=eval_dataset,
                       validation_steps=eval_steps,
                       callbacks=training_callbacks)
Example no. 16
with strategy.scope():
  training_dataset = train_input_fn()
  evaluation_dataset = eval_input_fn()
  bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)
  classifier_model, encoder = bert_models.classifier_model(
      bert_config, num_classes, max_seq_length)

  checkpoint = tf.train.Checkpoint(model=encoder)
  #checkpoint.restore(ckpt_path).assert_consumed()

  epochs = 3
  train_data_size = input_meta_data['train_data_size']
  eval_data_size = input_meta_data['eval_data_size']
  steps_per_epoch = int(train_data_size / batch_size)
  warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)
  optimizer = optimization.create_optimizer(
      2e-5, num_train_steps=steps_per_epoch * epochs, num_warmup_steps=warmup_steps)

  def metric_fn():
    return tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

  classifier_model.compile(optimizer=optimizer,
                           loss=run_classifier.get_loss_fn(num_classes=2),
                           metrics=[metric_fn()])
  classifier_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=int(eval_data_size / eval_batch_size))
Example no. 17

    train_dataset = create_squad_dataset(
        cf.INPUTS_FILE_DEV,
        input_meta_data['max_seq_length'], # 384
        cf.BATCH_SIZE,
        is_training=True
    )

    train_dataset_light = train_dataset.take(cf.NB_BATCHES_TRAIN)

    bert_squad = BERTSquad()

    optimizer = optimization.create_optimizer(
        init_lr=cf.INIT_LR,
        num_train_steps=cf.NB_BATCHES_TRAIN,
        num_warmup_steps=cf.WARMUP_STEPS
    )

    train_loss = tf.keras.metrics.Mean(name="train_loss")

    bert_squad.compile(
        optimizer,
        squad_loss_fn
    )

    ckpt = tf.train.Checkpoint(bert_squad=bert_squad)

    ckpt_manager = tf.train.CheckpointManager(ckpt, cf.CHECKPOINT_PATH, max_to_keep=1)

    if ckpt_manager.latest_checkpoint:
        pass  # (restore logic truncated in the original excerpt)
Example no. 18
def train_model(strategy):
    """Run model training."""
    config = config_utils.json_file_to_dict(FLAGS.config)
    dataset_fn = input_utils.get_dataset_fn(FLAGS.input, config)

    writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.model_dir, "train"))

    dataset_iterator = iter(
        strategy.experimental_distribute_datasets_from_function(dataset_fn))

    bert_config = configs.BertConfig.from_json_file(
        os.path.join(FLAGS.bert_dir, "bert_config.json"))
    logging.info("Loaded BERT config: %s", bert_config.to_dict())
    batch_size = int(config["batch_size"] / strategy.num_replicas_in_sync)
    logging.info("num_replicas: %s.", strategy.num_replicas_in_sync)
    logging.info("per replica batch_size: %s.", batch_size)

    with strategy.scope():
        model = nqg_model.Model(batch_size,
                                config,
                                bert_config,
                                training=True,
                                verbose=FLAGS.verbose)
        optimizer = optimization.create_optimizer(config["learning_rate"],
                                                  config["training_steps"],
                                                  config["warmup_steps"])
        train_for_n_steps_fn = training_utils.get_train_for_n_steps_fn(
            strategy, optimizer, model)

        if FLAGS.init_bert_checkpoint:
            bert_checkpoint = tf.train.Checkpoint(model=model.bert_encoder)
            bert_checkpoint_path = os.path.join(FLAGS.bert_dir,
                                                "bert_model.ckpt")
            logging.info("Restoring bert checkpoint: %s", bert_checkpoint_path)
            logging.info("Bert vars: %s",
                         model.bert_encoder.trainable_variables)
            logging.info("Checkpoint vars: %s",
                         tf.train.list_variables(bert_checkpoint_path))
            status = bert_checkpoint.restore(bert_checkpoint_path)
            status.assert_existing_objects_matched()

        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        current_step = 0

        if FLAGS.restore_checkpoint:
            latest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
            # TODO(petershaw): This is a hacky way to read current step.
            current_step = int(latest_checkpoint.split("-")[-2])
            logging.info("Restoring %s at step %s.", latest_checkpoint,
                         current_step)
            status = checkpoint.restore(latest_checkpoint)
            status.assert_existing_objects_matched()

        with writer.as_default():
            while current_step < config["training_steps"]:
                logging.info("current_step: %s.", current_step)
                mean_loss = train_for_n_steps_fn(
                    dataset_iterator,
                    tf.convert_to_tensor(config["steps_per_iteration"],
                                         dtype=tf.int32))
                tf.summary.scalar("loss", mean_loss, step=current_step)
                current_step += config["steps_per_iteration"]

                if current_step and current_step % config[
                        "save_checkpoint_every"] == 0:
                    checkpoint_prefix = os.path.join(FLAGS.model_dir,
                                                     "ckpt-%s" % current_step)
                    logging.info("Saving checkpoint to %s.", checkpoint_prefix)
                    checkpoint.save(file_prefix=checkpoint_prefix)