def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Build a ``TFRemBertModel`` and check its output shape for each input format.

        The model is called three ways:

        1. with a dict of named inputs (``input_ids``, ``attention_mask``, ``token_type_ids``);
        2. with a positional list ``[input_ids, input_mask]``;
        3. with bare ``input_ids``.

        The ``last_hidden_state`` of every call must have shape
        ``(batch_size, seq_length, hidden_size)``.

        ``sequence_labels``, ``token_labels`` and ``choice_labels`` are not used here;
        presumably they are accepted so this shares a signature with the other
        ``create_and_check_*`` helpers.
        """
        model = TFRemBertModel(config=config)
        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)

        # Dict input: every tensor passed under its keyword name. Previously this dict
        # was built but overwritten before use, so this path was never exercised.
        inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
        result = model(inputs)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        # List input: tensors passed positionally (presumably mapped to input_ids,
        # attention_mask in call-signature order).
        inputs = [input_ids, input_mask]
        result = model(inputs)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        # Single-tensor input: input_ids only.
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)
    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check ``TFRemBertModel`` output shapes with cross-attention enabled.

        Sets ``config.add_cross_attention = True``, which mutates the caller's config
        object, and then calls the model three ways:

        1. a dict that includes ``encoder_hidden_states`` and ``encoder_attention_mask``;
        2. a positional list ``[input_ids, input_mask]`` plus keyword
           ``token_type_ids`` and ``encoder_hidden_states`` (no encoder mask);
        3. without any encoder outputs.

        The ``last_hidden_state`` of every call must have shape
        ``(batch_size, seq_length, hidden_size)``.

        ``sequence_labels``, ``token_labels`` and ``choice_labels`` are not used here.
        """
        config.add_cross_attention = True

        model = TFRemBertModel(config=config)
        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)

        # Dict input with full encoder outputs (hidden states + mask).
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }
        result = model(inputs)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        # List input mixed with keyword arguments; encoder mask deliberately omitted.
        inputs = [input_ids, input_mask]
        result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        # Also check the case where encoder outputs are not passed
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)
    def test_inference_model(self):
        """Compare a slice of ``google/rembert`` hidden states against reference values.

        Feeds a 5-token sequence with two segments (token types 0 and 1) through the
        pretrained checkpoint. It then checks the output shape and the first three
        hidden features at every position, to an absolute tolerance of 1e-4.
        """
        hidden_size = 1152
        model = TFRemBertModel.from_pretrained("google/rembert")

        token_ids = tf.constant([[312, 56498, 313, 2125, 313]])
        token_type_ids = tf.constant([[0, 0, 0, 1, 1]])
        outputs = model(token_ids, token_type_ids=token_type_ids, output_hidden_states=True)
        last_hidden = outputs["last_hidden_state"]

        self.assertEqual(last_hidden.shape, [1, 5, hidden_size])

        # Reference values: features [0:3] for each of the 5 positions.
        reference_slice = tf.constant(
            [
                [
                    [0.0754, -0.2022, 0.1904],
                    [-0.3354, -0.3692, -0.4791],
                    [-0.2314, -0.6729, -0.0749],
                    [-0.0396, -0.3105, -0.4234],
                    [-0.1571, -0.0525, 0.5353],
                ]
            ]
        )
        tf.debugging.assert_near(last_hidden[:, :, :3], reference_slice, atol=1e-4)
 def get_encoder_decoder_model(self, config, decoder_config):
     """Return ``(encoder, decoder)`` built from the two configs.

     The encoder is a ``TFRemBertModel`` named "encoder". The decoder is a
     ``TFRemBertForCausalLM`` named "decoder". The encoder is constructed first.
     """
     return (
         TFRemBertModel(config, name="encoder"),
         TFRemBertForCausalLM(decoder_config, name="decoder"),
     )
 def test_model_from_pretrained(self):
     """Check that the ``google/rembert`` checkpoint loads into a ``TFRemBertModel``."""
     loaded_model = TFRemBertModel.from_pretrained("google/rembert")
     self.assertIsNotNone(loaded_model)