# Imports assumed by these examples (module paths as in the TensorFlow Model
# Garden, where export_tfhub_lib lives under official/nlp/tools/; flag
# definitions and test-local helpers such as _get_vocab_or_sp_model_dummy
# are defined elsewhere in the same file and elided here):
import os

import gin
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from absl import app
from absl import flags

from official.legacy.bert import configs
from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.tools import export_tfhub_lib

FLAGS = flags.FLAGS

Example #1
    def _export_bert_tfhub(self):
        bert_config = configs.BertConfig(vocab_size=30522,
                                         hidden_size=16,
                                         intermediate_size=32,
                                         max_position_embeddings=128,
                                         num_attention_heads=2,
                                         num_hidden_layers=4)
        encoder = export_tfhub_lib.get_bert_encoder(bert_config)
        model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")

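        # Save the freshly built encoder so export_model() below has a
        # checkpoint from which to restore weights.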
        checkpoint = tf.train.Checkpoint(encoder=encoder)
        checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
        model_checkpoint_path = tf.train.latest_checkpoint(
            model_checkpoint_dir)

        vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
        with tf.io.gfile.GFile(vocab_file, "w") as f:
            f.write("dummy content")

        export_path = os.path.join(self.get_temp_dir(), "hub")
        export_tfhub_lib.export_model(
            export_path,
            bert_config=bert_config,
            encoder_config=None,
            model_checkpoint_path=model_checkpoint_path,
            vocab_file=vocab_file,
            do_lower_case=True,
            with_mlm=False)
        return export_path
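
The returned path loads back as a TF-Hub layer; a minimal usage sketch (mirroring the later examples and their (2, 10) dummy batch):

    export_path = self._export_bert_tfhub()  # called from within the test class
    hub_layer = hub.KerasLayer(export_path, trainable=True)
    dummy_ids = np.zeros((2, 10), dtype=np.int32)
    outputs = hub_layer(dict(input_word_ids=dummy_ids,
                             input_mask=dummy_ids,
                             input_type_ids=dummy_ids))
    print(outputs["pooled_output"].shape)  # (2, 16) for hidden_size=16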
Example #2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

    if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
        raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
                         "can be specified, but got %s and %s." %
                         (FLAGS.vocab_file, FLAGS.sp_model_file))
    do_lower_case = export_tfhub_lib.get_do_lower_case(FLAGS.do_lower_case,
                                                       FLAGS.vocab_file,
                                                       FLAGS.sp_model_file)

    if FLAGS.export_type in ("model", "model_with_mlm"):
        if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
            raise ValueError(
                "Exactly one of `bert_config_file` and "
                "`encoder_config_file` must be specified, but got "
                "%s and %s." %
                (FLAGS.bert_config_file, FLAGS.encoder_config_file))
        if FLAGS.bert_config_file:
            bert_config = configs.BertConfig.from_json_file(
                FLAGS.bert_config_file)
            encoder_config = None
        else:
            bert_config = None
            encoder_config = encoders.EncoderConfig()
            encoder_config = hyperparams.override_params_dict(
                encoder_config, FLAGS.encoder_config_file, is_strict=True)
        export_tfhub_lib.export_model(
            FLAGS.export_path,
            bert_config=bert_config,
            encoder_config=encoder_config,
            model_checkpoint_path=FLAGS.model_checkpoint_path,
            vocab_file=FLAGS.vocab_file,
            sp_model_file=FLAGS.sp_model_file,
            do_lower_case=do_lower_case,
            with_mlm=FLAGS.export_type == "model_with_mlm",
            copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)

    elif FLAGS.export_type == "preprocessing":
        export_tfhub_lib.export_preprocessing(
            FLAGS.export_path,
            vocab_file=FLAGS.vocab_file,
            sp_model_file=FLAGS.sp_model_file,
            do_lower_case=do_lower_case,
            default_seq_length=FLAGS.default_seq_length,
            tokenize_with_offsets=FLAGS.tokenize_with_offsets,
            experimental_disable_assert=(
                FLAGS.experimental_disable_assert_in_preprocessing))

    else:
        raise app.UsageError("Unknown value '%s' for flag --export_type" %
                             FLAGS.export_type)
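
A complete script wrapping this main() would end with the standard absl entry point (flag definitions are elided from this excerpt):

    if __name__ == "__main__":
        app.run(main)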
Example #3
    def test_copy_pooler_dense_to_encoder(self):
        encoder_config = encoders.EncoderConfig(
            type="bert",
            bert=encoders.BertEncoderConfig(hidden_size=24,
                                            intermediate_size=48,
                                            num_layers=2))
        cls_heads = [
            layers.ClassificationHead(inner_dim=24,
                                      num_classes=2,
                                      name="next_sentence")
        ]
        encoder = encoders.build_encoder(encoder_config)
        pretrainer = models.BertPretrainerV2(
            encoder_network=encoder,
            classification_heads=cls_heads,
            mlm_activation=tf_utils.get_activation(
                encoder_config.get().hidden_activation))
        # Makes sure the pretrainer variables are created.
        _ = pretrainer(pretrainer.inputs)
        checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
        model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
        checkpoint.save(os.path.join(model_checkpoint_dir, "test"))

        vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
            self.get_temp_dir(), use_sp_model=True)
        export_path = os.path.join(self.get_temp_dir(), "hub")
        export_tfhub_lib.export_model(
            export_path=export_path,
            encoder_config=encoder_config,
            model_checkpoint_path=tf.train.latest_checkpoint(
                model_checkpoint_dir),
            with_mlm=True,
            copy_pooler_dense_to_encoder=True,
            vocab_file=vocab_file,
            sp_model_file=sp_model_file,
            do_lower_case=True)
        # Restores a hub KerasLayer.
        hub_layer = hub.KerasLayer(export_path, trainable=True)
        dummy_ids = np.zeros((2, 10), dtype=np.int32)
        input_dict = dict(input_word_ids=dummy_ids,
                          input_mask=dummy_ids,
                          input_type_ids=dummy_ids)
        hub_pooled_output = hub_layer(input_dict)["pooled_output"]
        encoder_outputs = encoder(input_dict)
        # Verify that hub_layer's pooled_output equals the output of the
        # next-sentence head's dense layer.
        pretrained_pooled_output = cls_heads[0].dense(
            encoder_outputs["sequence_output"][:, 0, :])
        self.assertAllClose(hub_pooled_output, pretrained_pooled_output)
        # But the encoder's own pooled_output differs from the hub_layer's.
        encoder_pooled_output = encoder_outputs["pooled_output"]
        self.assertNotAllClose(hub_pooled_output, encoder_pooled_output)
Example #4
    def test_export_model_with_mlm(self, use_bert):
        # Create the encoder and export it.
        hidden_size = 16
        num_hidden_layers = 2
        bert_config, encoder_config = _get_bert_config_or_encoder_config(
            use_bert, hidden_size, num_hidden_layers)
        bert_model, pretrainer = export_tfhub_lib._create_model(
            bert_config=bert_config,
            encoder_config=encoder_config,
            with_mlm=True)
        self.assertEmpty(
            _find_lambda_layers(bert_model),
            "Lambda layers are non-portable since they serialize Python bytecode."
        )
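        # With with_mlm=True, the returned core model carries the MLM head as
        # a `.mlm` sub-model, which the assertions below rely on.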
        bert_model_with_mlm = bert_model.mlm
        model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")

        checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)

        checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
        model_checkpoint_path = tf.train.latest_checkpoint(
            model_checkpoint_dir)

        vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
            self.get_temp_dir(), use_sp_model=not use_bert)
        export_path = os.path.join(self.get_temp_dir(), "hub")
        export_tfhub_lib.export_model(
            export_path=export_path,
            bert_config=bert_config,
            encoder_config=encoder_config,
            model_checkpoint_path=model_checkpoint_path,
            with_mlm=True,
            vocab_file=vocab_file,
            sp_model_file=sp_model_file,
            do_lower_case=True)

        # Restore the exported model.
        hub_layer = hub.KerasLayer(export_path, trainable=True)

        # Check legacy tokenization data.
        if use_bert:
            self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
            self.assertEqual("dummy content",
                             _read_asset(hub_layer.resolved_object.vocab_file))
            self.assertFalse(
                hasattr(hub_layer.resolved_object, "sp_model_file"))
        else:
            self.assertFalse(
                hasattr(hub_layer.resolved_object, "do_lower_case"))
            self.assertFalse(hasattr(hub_layer.resolved_object, "vocab_file"))
            self.assertEqual(
                "dummy content",
                _read_asset(hub_layer.resolved_object.sp_model_file))

        # Check restored weights.
        # Note that we set `_auto_track_sub_layers` to False when exporting the
        # SavedModel, so hub_layer has the same number of weights as bert_model;
        # otherwise, hub_layer will have extra weights from its `mlm` subobject.
        self.assertEqual(len(bert_model.trainable_weights),
                         len(hub_layer.trainable_weights))
        for source_weight, hub_weight in zip(bert_model.trainable_weights,
                                             hub_layer.trainable_weights):
            self.assertAllClose(source_weight, hub_weight)

        # Check computation.
        seq_length = 10
        dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
        input_dict = dict(input_word_ids=dummy_ids,
                          input_mask=dummy_ids,
                          input_type_ids=dummy_ids)
        hub_outputs_dict = hub_layer(input_dict)
        source_outputs_dict = bert_model(input_dict)
        encoder_outputs_dict = pretrainer.encoder_network(
            [dummy_ids, dummy_ids, dummy_ids])
        self.assertEqual(hub_outputs_dict["pooled_output"].shape,
                         (2, hidden_size))
        self.assertEqual(hub_outputs_dict["sequence_output"].shape,
                         (2, seq_length, hidden_size))
        for output_key in ("pooled_output", "sequence_output",
                           "encoder_outputs"):
            self.assertAllClose(source_outputs_dict[output_key],
                                hub_outputs_dict[output_key])
            self.assertAllClose(source_outputs_dict[output_key],
                                encoder_outputs_dict[output_key])

        # The "default" output of BERT as a text representation is pooled_output.
        self.assertAllClose(hub_outputs_dict["pooled_output"],
                            hub_outputs_dict["default"])

        # Test that training=True makes a difference (activates dropout).
        def _dropout_mean_stddev(training, num_runs=20):
            input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
            input_dict = dict(input_word_ids=input_ids,
                              input_mask=np.ones_like(input_ids),
                              input_type_ids=np.zeros_like(input_ids))
            outputs = np.concatenate([
                hub_layer(input_dict, training=training)["pooled_output"]
                for _ in range(num_runs)
            ])
            return np.mean(np.std(outputs, axis=0))

        self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
        self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)

        # Checks sub-object `mlm`.
        self.assertTrue(hasattr(hub_layer.resolved_object, "mlm"))

        self.assertLen(hub_layer.resolved_object.mlm.trainable_variables,
                       len(bert_model_with_mlm.trainable_weights))
        self.assertLen(hub_layer.resolved_object.mlm.trainable_variables,
                       len(pretrainer.trainable_weights))
        for source_weight, hub_weight, pretrainer_weight in zip(
                bert_model_with_mlm.trainable_weights,
                hub_layer.resolved_object.mlm.trainable_variables,
                pretrainer.trainable_weights):
            self.assertAllClose(source_weight, hub_weight)
            self.assertAllClose(source_weight, pretrainer_weight)

        max_predictions_per_seq = 4
        mlm_positions = np.zeros((2, max_predictions_per_seq), dtype=np.int32)
        input_dict = dict(input_word_ids=dummy_ids,
                          input_mask=dummy_ids,
                          input_type_ids=dummy_ids,
                          masked_lm_positions=mlm_positions)
        hub_mlm_outputs_dict = hub_layer.resolved_object.mlm(input_dict)
        source_mlm_outputs_dict = bert_model_with_mlm(input_dict)
        for output_key in ("pooled_output", "sequence_output", "mlm_logits",
                           "encoder_outputs"):
            self.assertAllClose(hub_mlm_outputs_dict[output_key],
                                source_mlm_outputs_dict[output_key])

        pretrainer_mlm_logits_output = pretrainer(input_dict)["mlm_logits"]
        self.assertAllClose(hub_mlm_outputs_dict["mlm_logits"],
                            pretrainer_mlm_logits_output)

        # Test that training=True makes a difference (activates dropout).
        def _dropout_mean_stddev_mlm(training, num_runs=20):
            input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
            mlm_position_ids = np.array([[1, 2, 3, 4]], np.int32)
            input_dict = dict(input_word_ids=input_ids,
                              input_mask=np.ones_like(input_ids),
                              input_type_ids=np.zeros_like(input_ids),
                              masked_lm_positions=mlm_position_ids)
            outputs = np.concatenate([
                hub_layer.resolved_object.mlm(
                    input_dict, training=training)["pooled_output"]
                for _ in range(num_runs)
            ])
            return np.mean(np.std(outputs, axis=0))

        self.assertLess(_dropout_mean_stddev_mlm(training=False), 1e-6)
        self.assertGreater(_dropout_mean_stddev_mlm(training=True), 1e-3)

        # Test propagation of seq_length in shape inference.
        input_word_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                               dtype=tf.int32)
        input_mask = tf.keras.layers.Input(shape=(seq_length, ),
                                           dtype=tf.int32)
        input_type_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                               dtype=tf.int32)
        input_dict = dict(input_word_ids=input_word_ids,
                          input_mask=input_mask,
                          input_type_ids=input_type_ids)
        hub_outputs_dict = hub_layer(input_dict)
        self.assertEqual(hub_outputs_dict["pooled_output"].shape.as_list(),
                         [None, hidden_size])
        self.assertEqual(hub_outputs_dict["sequence_output"].shape.as_list(),
                         [None, seq_length, hidden_size])
Example #5
    def test_export_model(self, use_bert):
        # Create the encoder and export it.
        hidden_size = 16
        num_hidden_layers = 1
        bert_config, encoder_config = _get_bert_config_or_encoder_config(
            use_bert, hidden_size, num_hidden_layers)
        bert_model, encoder = export_tfhub_lib._create_model(
            bert_config=bert_config,
            encoder_config=encoder_config,
            with_mlm=False)
        self.assertEmpty(
            _find_lambda_layers(bert_model),
            "Lambda layers are non-portable since they serialize Python bytecode."
        )
        model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
        checkpoint = tf.train.Checkpoint(encoder=encoder)
        checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
        model_checkpoint_path = tf.train.latest_checkpoint(
            model_checkpoint_dir)

        vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
            self.get_temp_dir(), use_sp_model=not use_bert)
        export_path = os.path.join(self.get_temp_dir(), "hub")
        export_tfhub_lib.export_model(
            export_path=export_path,
            bert_config=bert_config,
            encoder_config=encoder_config,
            model_checkpoint_path=model_checkpoint_path,
            with_mlm=False,
            vocab_file=vocab_file,
            sp_model_file=sp_model_file,
            do_lower_case=True)

        # Restore the exported model.
        hub_layer = hub.KerasLayer(export_path, trainable=True)

        # Check legacy tokenization data.
        if use_bert:
            self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
            self.assertEqual("dummy content",
                             _read_asset(hub_layer.resolved_object.vocab_file))
            self.assertFalse(
                hasattr(hub_layer.resolved_object, "sp_model_file"))
        else:
            self.assertFalse(
                hasattr(hub_layer.resolved_object, "do_lower_case"))
            self.assertFalse(hasattr(hub_layer.resolved_object, "vocab_file"))
            self.assertEqual(
                "dummy content",
                _read_asset(hub_layer.resolved_object.sp_model_file))

        # Check restored weights.
        self.assertEqual(len(bert_model.trainable_weights),
                         len(hub_layer.trainable_weights))
        for source_weight, hub_weight in zip(bert_model.trainable_weights,
                                             hub_layer.trainable_weights):
            self.assertAllClose(source_weight.numpy(), hub_weight.numpy())

        # Check computation.
        seq_length = 10
        dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
        input_dict = dict(input_word_ids=dummy_ids,
                          input_mask=dummy_ids,
                          input_type_ids=dummy_ids)
        hub_output = hub_layer(input_dict)
        source_output = bert_model(input_dict)
        encoder_output = encoder(input_dict)
        self.assertEqual(hub_output["pooled_output"].shape, (2, hidden_size))
        self.assertEqual(hub_output["sequence_output"].shape,
                         (2, seq_length, hidden_size))
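        # "encoder_outputs" is a list with one sequence tensor per layer.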
        self.assertLen(hub_output["encoder_outputs"], num_hidden_layers)

        for key in ("pooled_output", "sequence_output", "encoder_outputs"):
            self.assertAllClose(source_output[key], hub_output[key])
            self.assertAllClose(source_output[key], encoder_output[key])

        # The "default" output of BERT as a text representation is pooled_output.
        self.assertAllClose(hub_output["pooled_output"], hub_output["default"])

        # Test that training=True makes a difference (activates dropout).
        def _dropout_mean_stddev(training, num_runs=20):
            input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
            input_dict = dict(input_word_ids=input_ids,
                              input_mask=np.ones_like(input_ids),
                              input_type_ids=np.zeros_like(input_ids))
            outputs = np.concatenate([
                hub_layer(input_dict, training=training)["pooled_output"]
                for _ in range(num_runs)
            ])
            return np.mean(np.std(outputs, axis=0))

        self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
        self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)

        # Test propagation of seq_length in shape inference.
        input_word_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                               dtype=tf.int32)
        input_mask = tf.keras.layers.Input(shape=(seq_length, ),
                                           dtype=tf.int32)
        input_type_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                               dtype=tf.int32)
        input_dict = dict(input_word_ids=input_word_ids,
                          input_mask=input_mask,
                          input_type_ids=input_type_ids)
        output_dict = hub_layer(input_dict)
        pooled_output = output_dict["pooled_output"]
        sequence_output = output_dict["sequence_output"]
        encoder_outputs = output_dict["encoder_outputs"]

        self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
        self.assertEqual(sequence_output.shape.as_list(),
                         [None, seq_length, hidden_size])
        self.assertLen(encoder_outputs, num_hidden_layers)
Example #6
    def test_preprocessing_for_mlm(self, use_bert):
        """Combines both SavedModel types and TF.text helpers for MLM."""
        # Create the preprocessing SavedModel with a [MASK] token.
        non_special_tokens = [
            "hello", "world", "nice", "movie", "great", "actors", "quick",
            "fox", "lazy", "dog"
        ]
        preprocess = tf.saved_model.load(
            self._do_export(
                non_special_tokens,
                do_lower_case=True,
                tokenize_with_offsets=use_bert,  # TODO(b/181866850): drop this.
                experimental_disable_assert=True,  # TODO(b/175369555): drop this.
                add_mask_token=True,
                use_sp_model=not use_bert))
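        # The 5 extra BERT ids are the specials (presumably [PAD], [UNK],
        # [CLS], [SEP], [MASK]; the id assertions below confirm 0, 2, 3, 4);
        # the sentencepiece model reserves 7 special pieces.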
        vocab_size = len(non_special_tokens) + (5 if use_bert else 7)

        # Create the encoder SavedModel with an .mlm subobject.
        hidden_size = 16
        num_hidden_layers = 2
        bert_config, encoder_config = _get_bert_config_or_encoder_config(
            use_bert, hidden_size, num_hidden_layers, vocab_size)
        _, pretrainer = export_tfhub_lib._create_model(
            bert_config=bert_config,
            encoder_config=encoder_config,
            with_mlm=True)
        model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
        checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
        checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
        model_checkpoint_path = tf.train.latest_checkpoint(
            model_checkpoint_dir)
        # The dummy assets only satisfy export_model()'s signature; the
        # tokenization below uses the separately exported `preprocess` model.
        vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
            self.get_temp_dir(), use_sp_model=not use_bert)
        encoder_export_path = os.path.join(self.get_temp_dir(),
                                           "encoder_export")
        export_tfhub_lib.export_model(
            export_path=encoder_export_path,
            bert_config=bert_config,
            encoder_config=encoder_config,
            model_checkpoint_path=model_checkpoint_path,
            with_mlm=True,
            vocab_file=vocab_file,
            sp_model_file=sp_model_file,
            do_lower_case=True)
        encoder = tf.saved_model.load(encoder_export_path)

        # Get special tokens from the vocab (and vocab size).
        special_tokens_dict = preprocess.tokenize.get_special_tokens_dict()
        self.assertEqual(int(special_tokens_dict["vocab_size"]), vocab_size)
        padding_id = int(special_tokens_dict["padding_id"])
        self.assertEqual(padding_id, 0)
        start_of_sequence_id = int(special_tokens_dict["start_of_sequence_id"])
        self.assertEqual(start_of_sequence_id, 2)
        end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
        self.assertEqual(end_of_segment_id, 3)
        mask_id = int(special_tokens_dict["mask_id"])
        self.assertEqual(mask_id, 4)

        # A batch of 3 segment pairs.
        raw_segments = [
            tf.constant(["hello", "nice movie", "quick fox"]),
            tf.constant(["world", "great actors", "lazy dog"])
        ]
        batch_size = 3

        # Misc hyperparameters.
        seq_length = 10
        max_selections_per_seq = 2

        # Tokenize inputs.
        tokenized_segments = [preprocess.tokenize(s) for s in raw_segments]
        # Trim inputs so the combined sequence fits seq_length, reserving room
        # for the special tokens: one [CLS] plus a [SEP] per segment.
        num_special_tokens = len(raw_segments) + 1
        trimmed_segments = text.WaterfallTrimmer(
            seq_length - num_special_tokens).trim(tokenized_segments)
        # Combine input segments into one input sequence.
        input_ids, segment_ids = text.combine_segments(
            trimmed_segments,
            start_of_sequence_id=start_of_sequence_id,
            end_of_segment_id=end_of_segment_id)
        # Apply random masking controlled by policy objects.
        (masked_input_ids, masked_lm_positions,
         masked_ids) = text.mask_language_model(
             input_ids=input_ids,
             item_selector=text.RandomItemSelector(
                 max_selections_per_seq,
                 selection_rate=0.5,  # Adjusted for the short test examples.
                 unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
             mask_values_chooser=text.MaskValuesChooser(
                 vocab_size=vocab_size,
                 mask_token=mask_id,
                 # Always put [MASK] to have a predictable result.
                 mask_token_rate=1.0,
                 random_token_rate=0.0))
        # Pad to fixed-length Transformer encoder inputs.
        input_word_ids, _ = text.pad_model_inputs(masked_input_ids,
                                                  seq_length,
                                                  pad_value=padding_id)
        input_type_ids, input_mask = text.pad_model_inputs(segment_ids,
                                                           seq_length,
                                                           pad_value=0)
        masked_lm_positions, _ = text.pad_model_inputs(masked_lm_positions,
                                                       max_selections_per_seq,
                                                       pad_value=0)
        masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
        num_predictions = int(tf.shape(masked_lm_positions)[1])

        # Test transformer inputs.
        self.assertEqual(num_predictions, max_selections_per_seq)
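        # After the specials, tokens get consecutive ids in vocab order:
        # hello=5, world=6, ..., quick=11, fox=12, lazy=13, dog=14.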
        expected_word_ids = np.array([
            # [CLS] hello [SEP] world [SEP]
            [2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
            # [CLS] nice movie [SEP] great actors [SEP]
            [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
            # [CLS] quick fox [SEP] lazy dog [SEP]
            [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]
        ])
        for i in range(batch_size):
            for j in range(num_predictions):
                k = int(masked_lm_positions[i, j])
                if k != 0:
                    expected_word_ids[i, k] = 4  # [MASK]
        self.assertAllEqual(input_word_ids, expected_word_ids)

        # Call the MLM head of the Transformer encoder.
        mlm_inputs = dict(
            input_word_ids=input_word_ids,
            input_mask=input_mask,
            input_type_ids=input_type_ids,
            masked_lm_positions=masked_lm_positions,
        )
        mlm_outputs = encoder.mlm(mlm_inputs)
        self.assertEqual(mlm_outputs["pooled_output"].shape,
                         (batch_size, hidden_size))
        self.assertEqual(mlm_outputs["sequence_output"].shape,
                         (batch_size, seq_length, hidden_size))
        self.assertEqual(mlm_outputs["mlm_logits"].shape,
                         (batch_size, num_predictions, vocab_size))
        self.assertLen(mlm_outputs["encoder_outputs"], num_hidden_layers)

        # A real trainer would now compute the loss of mlm_logits
        # trying to predict the masked_ids.
        del masked_ids  # Unused.
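
A minimal sketch of that loss computation (illustration only, not part of the test; it pads `masked_ids` with the same text.pad_model_inputs helper in place of the del, and the variable names are assumptions):

    masked_lm_ids, masked_lm_weights = text.pad_model_inputs(
        masked_ids, max_selections_per_seq, pad_value=0)
    per_position_loss = tf.keras.losses.sparse_categorical_crossentropy(
        masked_lm_ids, mlm_outputs["mlm_logits"], from_logits=True)
    weights = tf.cast(masked_lm_weights, tf.float32)
    # Average only over real (non-padding) prediction positions.
    mlm_loss = tf.reduce_sum(per_position_loss * weights) / (
        tf.reduce_sum(weights) + 1e-5)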