Example 1
    def _load_model(self, is_insertion=True):
        """Loads either an insertion or tagging model for inference."""
        def _get_fake_loss_fn():
            """"A fake loss function for inference."""
            def _fake_loss_fn(unused_labels, unused_losses, **unused_args):
                return 0.0

            return _fake_loss_fn

        model_filepath = None
        if is_insertion:
            model, _ = felix_models.get_insertion_model(
                self._bert_config_insertion,
                self._sequence_length,
                max_predictions_per_seq=self._max_predictions,
                is_training=False)
            model_filepath = self._model_insertion_filepath
        else:
            model, _ = felix_models.get_tagging_model(
                self._bert_config_tagging,
                self._sequence_length,
                use_pointing=self._is_pointing,
                pointing_weight=0,
                is_training=False)
            model_filepath = self._model_tagging_filepath
        if model_filepath:
            checkpoint = tf.train.Checkpoint(model=model)
            latest_checkpoint_file = tf.train.latest_checkpoint(model_filepath)
            checkpoint.restore(latest_checkpoint_file)
        model.compile(loss=_get_fake_loss_fn())
        if is_insertion:
            self._insertion_model = model
        else:
            self._tagging_model = model
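
A minimal standalone sketch of the restore-for-inference pattern used above, assuming a generic Keras model and a checkpoint directory (the function and argument names here are placeholders, not part of the Felix code):

import tensorflow as tf

def restore_for_inference(model, checkpoint_dir):
    """Restores the latest checkpoint in `checkpoint_dir` into `model`."""
    checkpoint = tf.train.Checkpoint(model=model)
    latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint_file is not None:
        # expect_partial() silences warnings about objects that exist in the
        # checkpoint (e.g. optimizer slots) but are unused at inference time.
        checkpoint.restore(latest_checkpoint_file).expect_partial()
    # Compile with a no-op loss: the model is only used for prediction.
    model.compile(loss=lambda unused_labels, unused_losses, **unused_args: 0.0)
    return model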
Example 2
  def test_pretrain_model(self, use_insertion=True):
    if use_insertion:
      model, encoder = felix_models.get_insertion_model(
          self._bert_test_config, seq_length=5, max_predictions_per_seq=2)
    else:
      model, encoder = felix_models.get_tagging_model(
          self._bert_test_config, seq_length=5, use_pointing=True)
    self.assertIsInstance(model, tf.keras.Model)
    self.assertIsInstance(encoder, networks.BertEncoder)

    # The model has a single scalar output: the loss value.
    self.assertEqual(model.output.shape.as_list(), [])

    # Expect two outputs from the encoder: the sequence output and the
    # classification output.
    self.assertIsInstance(encoder.output, list)
    self.assertLen(encoder.output, 2)
    # The classification output shape should be [batch_size, hidden_size].
    self.assertEqual(encoder.output[1].shape.as_list(), [None, 16])
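
The `use_insertion=True` default suggests this test is meant to cover both model types. A minimal sketch of how the two branches could be exercised with absl's parameterized tests (the class name and parameter labels are assumptions, not taken from the original test file):

from absl.testing import parameterized
import tensorflow as tf

class FelixModelsTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(('insertion', True), ('tagging', False))
  def test_pretrain_model(self, use_insertion):
    # Body identical to the test above: build the requested model and check
    # the shapes of the model and encoder outputs.
    ...


if __name__ == '__main__':
  tf.test.main()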
Example 3
def run_train(bert_config,
              seq_length,
              max_predictions_per_seq,
              model_dir,
              epochs,
              initial_lr,
              warmup_steps,
              loss_scale,
              train_file,
              eval_file,
              train_batch_size,
              eval_batch_size,
              use_insertion=True,
              use_pointing=True,
              pointing_weight=1.0,
              mini_epochs_per_epoch=1):
    """Runs BERT pre-training using Keras `fit()` API."""

    mini_epochs_per_epoch = max(1, mini_epochs_per_epoch)

    if use_insertion:
        pretrain_model, bert_encoder = felix_models.get_insertion_model(
            bert_config, seq_length, max_predictions_per_seq)
    else:
        pretrain_model, bert_encoder = felix_models.get_tagging_model(
            bert_config,
            seq_length,
            use_pointing=use_pointing,
            pointing_weight=pointing_weight)
    # The original BERT model does not scale the loss by 1/num_replicas_in_sync.
    # That may be unintentional, but in order to use the same hyperparameters we
    # do the same thing here.
    loss_fn = _get_loss_fn(loss_scale=loss_scale)

    steps_per_mini_epoch = int(FLAGS.num_train_examples / train_batch_size /
                               mini_epochs_per_epoch)
    eval_steps = max(1, int(FLAGS.num_eval_examples / eval_batch_size))

    optimizer = optimization.create_optimizer(
        init_lr=initial_lr,
        num_train_steps=steps_per_mini_epoch * mini_epochs_per_epoch * epochs,
        num_warmup_steps=warmup_steps)

    pretrain_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        experimental_steps_per_execution=FLAGS.steps_per_loop)
    train_dataset = _get_input_data_fn(
        train_file,
        seq_length,
        max_predictions_per_seq,
        train_batch_size,
        is_training=True,
        use_insertion=use_insertion,
        use_pointing=use_pointing,
        use_weighted_labels=FLAGS.use_weighted_labels)
    eval_dataset = _get_input_data_fn(
        eval_file,
        seq_length,
        max_predictions_per_seq,
        eval_batch_size,
        is_training=False,
        use_insertion=use_insertion,
        use_pointing=use_pointing,
        use_weighted_labels=FLAGS.use_weighted_labels)

    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file is not None:
        checkpoint = tf.train.Checkpoint(model=pretrain_model,
                                         optimizer=optimizer)
        # Some model components (e.g. optimizer slot variables) are loaded
        # lazily, so we do not add any asserts before model.call() is invoked.
        checkpoint.restore(latest_checkpoint_file)
        checkpoint_iteration = tf.keras.backend.get_value(
            pretrain_model.optimizer.iterations)
        current_mini_epoch = checkpoint_iteration // steps_per_mini_epoch
    else:
        # No latest checkpoint found, so initialize from a pre-trained
        # checkpoint if one is given.
        if FLAGS.init_checkpoint:
            if _CHECKPOINT_FILE_NAME not in FLAGS.init_checkpoint:
                logging.info('Initializing from a BERT checkpoint...')
                checkpoint = tf.train.Checkpoint(model=bert_encoder)
                checkpoint.restore(
                    FLAGS.init_checkpoint).assert_existing_objects_matched()
            else:
                logging.info('Initializing from a Felix checkpoint...')
                # Initialize from a previously trained checkpoint.
                checkpoint = tf.train.Checkpoint(model=pretrain_model)
                checkpoint.restore(
                    FLAGS.init_checkpoint).assert_existing_objects_matched()
                # Reset the iteration number to have the learning rate adapt correctly.
                tf.keras.backend.set_value(pretrain_model.optimizer.iterations,
                                           0)

        checkpoint = tf.train.Checkpoint(model=pretrain_model,
                                         optimizer=optimizer)
        checkpoint_iteration = 0
        current_mini_epoch = 0

    logging.info('Starting training from iteration: %s.', checkpoint_iteration)
    summary_dir = os.path.join(model_dir, 'summaries')
    summary_cb = tf.keras.callbacks.TensorBoard(summary_dir, update_freq=1000)

    manager = tf.train.CheckpointManager(checkpoint,
                                         directory=model_dir,
                                         max_to_keep=FLAGS.keep_checkpoint_max,
                                         checkpoint_name=_CHECKPOINT_FILE_NAME)
    checkpoint_cb = CheckPointSaver(manager, current_mini_epoch)
    time_history_cb = keras_utils.TimeHistory(FLAGS.train_batch_size,
                                              FLAGS.log_steps)
    training_callbacks = [summary_cb, checkpoint_cb, time_history_cb]
    pretrain_model.fit(train_dataset,
                       initial_epoch=current_mini_epoch,
                       epochs=mini_epochs_per_epoch * epochs,
                       verbose=1,
                       steps_per_epoch=steps_per_mini_epoch,
                       validation_data=eval_dataset,
                       validation_steps=eval_steps,
                       callbacks=training_callbacks)
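
`_get_loss_fn`, called near the top of `run_train`, is not shown in this excerpt. A plausible sketch, assuming the pre-training model emits its loss as its only output (consistent with the scalar model output checked in Example 2), so the labels argument is unused and the loss function only averages and scales that output:

import tensorflow as tf

def _get_loss_fn(loss_scale=1.0):
    """Returns a loss function that averages and scales the model's loss output."""
    def _loss_fn(unused_labels, losses, **unused_args):
        return tf.reduce_mean(losses) * loss_scale
    return _loss_fn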