def create_and_check_lxmert_model(
    self,
    config,
    input_ids,
    visual_feats,
    bounding_boxes,
    token_type_ids,
    input_mask,
    obj_labels,
    masked_lm_labels,
    matched_label,
    ans,
    output_attentions,
):
    """Build a ``TFLxmertModel`` from ``config`` and check its output shapes.

    The model is called with ``output_attentions`` set to the given value and
    to its negation, then with ``return_dict=False`` and ``return_dict=True``.
    Only the result of the final (``return_dict=True``) call is inspected.

    ``obj_labels``, ``masked_lm_labels``, ``matched_label`` and ``ans`` are
    accepted but not used by this check.
    """
    model = TFLxmertModel(config=config)
    core_inputs = (input_ids, visual_feats, bounding_boxes)

    # Run the forward pass with the attention flag both ways; outputs are discarded.
    for attentions_flag in (output_attentions, not output_attentions):
        model(
            *core_inputs,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=attentions_flag,
        )

    # Run once with return_dict=False (output discarded), then keep the
    # return_dict=True output for the shape checks below.
    model(*core_inputs, return_dict=False)
    result = model(*core_inputs, return_dict=True)

    check_equal = self.parent.assertEqual
    check_equal(
        result.language_output.shape,
        (self.batch_size, self.seq_length, self.hidden_size),
    )
    check_equal(
        result.vision_output.shape,
        (self.batch_size, self.num_visual_features, self.hidden_size),
    )
    check_equal(
        result.pooled_output.shape,
        (self.batch_size, self.hidden_size),
    )
def test_inference_masked_lm(self):
    """Compare pretrained ``TFLxmertModel`` output against reference values.

    Loads ``unc-nlp/lxmert-base-uncased``, feeds it a fixed token sequence
    plus seeded random visual inputs, and checks both the shape of the first
    output and its leading 3x3 slice (tolerance ``1e-4``).
    """
    model = TFLxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")
    input_ids = tf.constant([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]])
    num_visual_features = 10

    # Reseed before each draw so both arrays are generated from the same seed.
    np.random.seed(0)
    raw_feats = np.random.rand(1, num_visual_features, model.config.visual_feat_dim)
    np.random.seed(0)
    raw_pos = np.random.rand(1, num_visual_features, 4)

    visual_feats = tf.convert_to_tensor(raw_feats, dtype=tf.float32)
    visual_pos = tf.convert_to_tensor(raw_pos, dtype=tf.float32)

    model_outputs = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)
    output = model_outputs[0]

    expected_shape = [1, 11, 768]
    self.assertEqual(expected_shape, output.shape)

    # Reference values for output[:, :3, :3].
    reference_rows = [
        [0.24170142, -0.98075, 0.14797261],
        [1.2540525, -0.83198136, 0.5112344],
        [1.4070463, -1.1051831, 0.6990401],
    ]
    expected_slice = tf.constant([reference_rows])
    tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
def test_model_from_pretrained(self):
    """Check that ``TFLxmertModel.from_pretrained`` returns a model for each listed checkpoint."""
    checkpoint_names = ["unc-nlp/lxmert-base-uncased"]
    for checkpoint in checkpoint_names:
        self.assertIsNotNone(TFLxmertModel.from_pretrained(checkpoint))