Example #1
import torch

from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise the PyTorch model from the JSON config
    config = LxmertConfig.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = LxmertForPreTraining(config)

    # Load weights from the TensorFlow checkpoint
    load_tf_weights_in_lxmert(model, config, tf_checkpoint_path)

    # Save the PyTorch model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
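
# The function above only performs the conversion; a minimal command-line entry point
# could look like the sketch below. The flag names simply mirror the function's
# parameters and are assumptions, not necessarily the upstream script's exact flags.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", type=str, required=True,
                        help="Path to the TensorFlow checkpoint to convert.")
    parser.add_argument("--config_file", type=str, required=True,
                        help="JSON file holding the LxmertConfig of the pre-trained model.")
    parser.add_argument("--pytorch_dump_path", type=str, required=True,
                        help="Where to write the converted PyTorch weights.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
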
Example #2
    def __init__(self, dummy_config):
        super(LXMERT, self).__init__(dummy_config)

        # Visual front end: a Faster R-CNN config plus a backbone / ROI heads pair
        # (the pretrained GeneralizedRCNN alternative is kept commented out).
        frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
        # self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=frcnn_cfg)
        self.backbone, self.roi_heads = build_image_encoder()

        # Cross-modal encoder, tokenizer and image preprocessing.
        self.lxmert_vqa = LxmertForPreTraining.from_pretrained("unc-nlp/lxmert-base-uncased")
        # self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased")
        self.tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
        self.image_preprocess = Preprocess(frcnn_cfg)

        hid_dim = self.lxmert_vqa.config.hidden_size
        # transform = BertPredictionHeadTransform(self.config.NETWORK.VLBERT)

        # Answer classifier head applied to the pooled cross-modality representation.
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim),
            GELU(),
            BertLayerNorm(hid_dim),
            nn.Dropout(self.config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            nn.Linear(hid_dim, self.config.NETWORK.CLASSIFIER_CLASS),
        )
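
    # The constructor above only builds the modules. As a rough illustration, a forward
    # pass could apply logit_fc to LXMERT's pooled cross-modality output as sketched
    # below; this is a hypothetical method, assuming the standard transformers
    # LxmertModel inputs (input_ids, visual_feats, visual_pos).
    def forward(self, input_ids, visual_feats, visual_pos, attention_mask=None):
        lxmert_out = self.lxmert_vqa.lxmert(
            input_ids=input_ids,
            visual_feats=visual_feats,
            visual_pos=visual_pos,
            attention_mask=attention_mask,
        )
        pooled = lxmert_out.pooled_output  # cross-modality [CLS] representation
        return self.logit_fc(pooled)       # (batch_size, CLASSIFIER_CLASS) answer logits
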
    def resize_lxmert_num_qa_labels(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):

        start_labels = config.num_qa_labels
        num_large_labels = config.num_qa_labels * 2
        # Half the original label count, so the answer head is exercised both when it is
        # grown and when it is shrunk.
        num_small_labels = int(config.num_qa_labels / 2)
        less_labels_ans = ids_tensor([self.batch_size], num_small_labels)
        more_labels_ans = ids_tensor([self.batch_size], num_large_labels)
        model_pretrain = LxmertForPreTraining(config=config).to(torch_device)
        model_qa = LxmertForQuestionAnswering(config=config).to(torch_device)
        config.num_labels = num_small_labels
        end_labels = config.num_labels

        result_pretrain = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=ans,
        )

        result_qa = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_pretrain.resize_num_qa_labels(num_small_labels)
        model_qa.resize_num_qa_labels(num_small_labels)

        result_pretrain_less = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=less_labels_ans,
        )

        result_qa_less = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=less_labels_ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_pretrain.resize_num_qa_labels(num_large_labels)
        model_qa.resize_num_qa_labels(num_large_labels)

        result_pretrain_more = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=more_labels_ans,
        )

        result_qa_more = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=more_labels_ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_qa_labels = model_qa.num_qa_labels

        self.parent.assertNotEqual(start_labels, end_labels)
        self.parent.assertNotEqual(model_qa_labels, start_labels)
        self.parent.assertEqual(result_qa.question_answering_score.shape,
                                (self.batch_size, start_labels))
        self.parent.assertEqual(result_pretrain.question_answering_score.shape,
                                (self.batch_size, start_labels))
        self.parent.assertEqual(result_qa_less.question_answering_score.shape,
                                (self.batch_size, num_small_labels))
        self.parent.assertEqual(
            result_pretrain_less.question_answering_score.shape,
            (self.batch_size, num_small_labels))
        self.parent.assertEqual(result_qa_more.question_answering_score.shape,
                                (self.batch_size, num_large_labels))
        self.parent.assertEqual(
            result_pretrain_more.question_answering_score.shape,
            (self.batch_size, num_large_labels))
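
# Outside the test harness, the same resizing API can be exercised directly. A minimal
# standalone sketch (the label counts are arbitrary, and it assumes resize_num_qa_labels
# updates num_qa_labels, as the assertions above rely on):
from transformers import LxmertConfig, LxmertForQuestionAnswering

config = LxmertConfig(num_qa_labels=100)
model = LxmertForQuestionAnswering(config)
model.resize_num_qa_labels(50)   # the answer head now scores 50 labels
assert model.num_qa_labels == 50
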
    def create_and_check_lxmert_for_pretraining(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
        model = LxmertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            obj_labels=obj_labels,
            matched_label=matched_label,
            ans=ans,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            output_attentions=not output_attentions,
            return_dict=False,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            obj_labels=obj_labels,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            matched_label=matched_label,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=ans,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            obj_labels=obj_labels,
            matched_label=matched_label,
            ans=ans,
            output_attentions=not output_attentions,
        )

        self.parent.assertEqual(
            result.prediction_logits.shape,
            (self.batch_size, self.seq_length, self.vocab_size))
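
# For completeness, the pre-training head checked above can also be probed outside the
# test suite. A minimal sketch with dummy visual inputs (36 regions with 2048-dim
# features match LXMERT's default visual_feat_dim, but are otherwise arbitrary
# assumptions) prints the same prediction_logits shape the final assertion verifies:
import torch
from transformers import LxmertTokenizer, LxmertForPreTraining

tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
model = LxmertForPreTraining.from_pretrained("unc-nlp/lxmert-base-uncased")
model.eval()

inputs = tokenizer("What is on the table?", return_tensors="pt")
visual_feats = torch.rand(1, 36, 2048)  # dummy region features
visual_pos = torch.rand(1, 36, 4)       # dummy normalized bounding boxes

with torch.no_grad():
    outputs = model(**inputs, visual_feats=visual_feats, visual_pos=visual_pos)
print(outputs.prediction_logits.shape)  # torch.Size([1, seq_len, vocab_size])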