Example #1
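# Helper from an LXMERT model tester class (as in the transformers test suite):
# builds an LxmertConfig from the tester's attributes. LxmertConfig is importable
# as `from transformers import LxmertConfig`.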
def get_config(self):
    return LxmertConfig(
        vocab_size=self.vocab_size,
        hidden_size=self.hidden_size,
        num_attention_heads=self.num_attention_heads,
        num_labels=self.num_labels,
        intermediate_size=self.intermediate_size,
        hidden_act=self.hidden_act,
        hidden_dropout_prob=self.hidden_dropout_prob,
        attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        type_vocab_size=self.type_vocab_size,
        initializer_range=self.initializer_range,
        layer_norm_eps=self.layer_norm_eps,
        pad_token_id=self.pad_token_id,
        num_qa_labels=self.num_qa_labels,
        num_object_labels=self.num_object_labels,
        num_attr_labels=self.num_attr_labels,
        l_layers=self.l_layers,
        x_layers=self.x_layers,
        r_layers=self.r_layers,
        visual_feat_dim=self.visual_feat_dim,
        visual_pos_dim=self.visual_pos_dim,
        visual_loss_normalizer=self.visual_loss_normalizer,
        task_matched=self.task_matched,
        task_mask_lm=self.task_mask_lm,
        task_obj_predict=self.task_obj_predict,
        task_qa=self.task_qa,
        visual_obj_loss=self.visual_obj_loss,
        visual_attr_loss=self.visual_attr_loss,
        visual_feat_loss=self.visual_feat_loss,
        output_attentions=self.output_attentions,
        output_hidden_states=self.output_hidden_states,
    )
Example #2
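# Standalone conversion script; it needs torch and the LXMERT classes from transformers.
import torch

from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
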
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = LxmertConfig.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = LxmertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_lxmert(model, config, tf_checkpoint_path)

    # Save the PyTorch model
    print(f"Saving PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)
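
# A minimal sketch of a command-line entry point for the converter above; the flag
# names are illustrative, not a documented interface.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", type=str, required=True,
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--config_file", type=str, required=True,
                        help="Path to the LxmertConfig JSON file.")
    parser.add_argument("--pytorch_dump_path", type=str, required=True,
                        help="Where to write the PyTorch weights.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
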
Example #3

    def prepare_config_and_inputs(self):
        output_attentions = self.output_attentions
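        # ids_tensor and torch_device are test helpers from the transformers test
        # suite (the shared modeling-test utilities and transformers.testing_utils).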
        input_ids = ids_tensor([self.batch_size, self.seq_length],
                               vocab_size=self.vocab_size)
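        # Random region-of-interest features: (batch, num_regions, feature_dim).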
        visual_feats = torch.rand(self.batch_size,
                                  self.num_visual_features,
                                  self.visual_feat_dim,
                                  device=torch_device)
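        # Random normalized box coordinates in [0, 1): (batch, num_regions, 4).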
        bounding_boxes = torch.rand(self.batch_size,
                                    self.num_visual_features,
                                    4,
                                    device=torch_device)

        input_mask = None
        if self.use_lang_mask:
            input_mask = ids_tensor([self.batch_size, self.seq_length],
                                    vocab_size=2)
        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length],
                                        self.type_vocab_size)
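        # For the object-prediction pretraining task, map each visual loss name
        # ("attr", "feat", "obj") to a pair of label/weight tensors.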
        obj_labels = None
        if self.task_obj_predict:
            obj_labels = {}
        if self.visual_attr_loss and self.task_obj_predict:
            obj_labels["attr"] = (
                ids_tensor([self.batch_size, self.num_visual_features],
                           self.num_attr_labels),
                ids_tensor([self.batch_size, self.num_visual_features],
                           self.num_attr_labels),
            )
        if self.visual_feat_loss and self.task_obj_predict:
            obj_labels["feat"] = (
                ids_tensor([
                    self.batch_size, self.num_visual_features,
                    self.visual_feat_dim
                ], self.num_visual_features),
                ids_tensor([self.batch_size, self.num_visual_features],
                           self.num_visual_features),
            )
        if self.visual_obj_loss and self.task_obj_predict:
            obj_labels["obj"] = (
                ids_tensor([self.batch_size, self.num_visual_features],
                           self.num_object_labels),
                ids_tensor([self.batch_size, self.num_visual_features],
                           self.num_object_labels),
            )
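        # One answer label per example for the question-answering head.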
        ans = None
        if self.task_qa:
            ans = ids_tensor([self.batch_size], self.num_qa_labels)
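        # Labels for the masked language modeling and cross-modality matching tasks.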
        masked_lm_labels = None
        if self.task_mask_lm:
            masked_lm_labels = ids_tensor([self.batch_size, self.seq_length],
                                          self.vocab_size)
        matched_label = None
        if self.task_matched:
            matched_label = ids_tensor([self.batch_size], self.num_labels)

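        # Mirrors the config built by get_config in Example #1.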
        config = LxmertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_labels=self.num_labels,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            layer_norm_eps=self.layer_norm_eps,
            pad_token_id=self.pad_token_id,
            num_qa_labels=self.num_qa_labels,
            num_object_labels=self.num_object_labels,
            num_attr_labels=self.num_attr_labels,
            l_layers=self.l_layers,
            x_layers=self.x_layers,
            r_layers=self.r_layers,
            visual_feat_dim=self.visual_feat_dim,
            visual_pos_dim=self.visual_pos_dim,
            visual_loss_normalizer=self.visual_loss_normalizer,
            task_matched=self.task_matched,
            task_mask_lm=self.task_mask_lm,
            task_obj_predict=self.task_obj_predict,
            task_qa=self.task_qa,
            visual_obj_loss=self.visual_obj_loss,
            visual_attr_loss=self.visual_attr_loss,
            visual_feat_loss=self.visual_feat_loss,
            output_attentions=self.output_attentions,
            output_hidden_states=self.output_hidden_states,
        )

        return (
            config,
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids,
            input_mask,
            obj_labels,
            masked_lm_labels,
            matched_label,
            ans,
            output_attentions,
        )