def get_config(self):
    """Return an ``LxmertConfig`` populated from this object's attributes.

    Every keyword passed to ``LxmertConfig`` is read from the attribute of
    the same name on ``self``.
    """
    # Keyword names handed to LxmertConfig; each one is read from the
    # identically named attribute on self, in this order.
    config_fields = (
        "vocab_size",
        "hidden_size",
        "num_attention_heads",
        "num_labels",
        "intermediate_size",
        "hidden_act",
        "hidden_dropout_prob",
        "attention_probs_dropout_prob",
        "max_position_embeddings",
        "type_vocab_size",
        "initializer_range",
        "layer_norm_eps",
        "pad_token_id",
        "num_qa_labels",
        "num_object_labels",
        "num_attr_labels",
        "l_layers",
        "x_layers",
        "r_layers",
        "visual_feat_dim",
        "visual_pos_dim",
        "visual_loss_normalizer",
        "task_matched",
        "task_mask_lm",
        "task_obj_predict",
        "task_qa",
        "visual_obj_loss",
        "visual_attr_loss",
        "visual_feat_loss",
        "output_attentions",
        "output_hidden_states",
    )
    config_kwargs = {field: getattr(self, field) for field in config_fields}
    return LxmertConfig(**config_kwargs)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    """Convert a TensorFlow LXMERT checkpoint into a saved PyTorch state dict.

    Args:
        tf_checkpoint_path: Checkpoint path forwarded to
            ``load_tf_weights_in_lxmert``.
        config_file: JSON file read with ``LxmertConfig.from_json_file``.
        pytorch_dump_path: Destination passed to ``torch.save``.
    """
    # Instantiate the PyTorch model described by the JSON configuration.
    lxmert_config = LxmertConfig.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(lxmert_config)))
    pytorch_model = LxmertForPreTraining(lxmert_config)

    # Load the weights from the TensorFlow checkpoint into the model.
    load_tf_weights_in_lxmert(pytorch_model, lxmert_config, tf_checkpoint_path)

    # Only the state dict is written out, not the module object itself.
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    state_dict = pytorch_model.state_dict()
    torch.save(state_dict, pytorch_dump_path)
def prepare_config_and_inputs(self):
    """Build an ``LxmertConfig`` together with randomly generated inputs.

    Optional inputs and labels are ``None`` when the flag on ``self`` that
    guards them is falsy.

    Returns:
        The tuple ``(config, input_ids, visual_feats, bounding_boxes,
        token_type_ids, input_mask, obj_labels, masked_lm_labels,
        matched_label, ans, output_attentions)``.
    """
    batch = self.batch_size
    seq_len = self.seq_length
    num_feats = self.num_visual_features
    output_attentions = self.output_attentions

    # NOTE: tensors are created in a fixed order so that any seeded random
    # draws (torch.rand, and presumably ids_tensor) stay reproducible.
    input_ids = ids_tensor([batch, seq_len], vocab_size=self.vocab_size)
    visual_feats = torch.rand(batch, num_feats, self.visual_feat_dim, device=torch_device)
    bounding_boxes = torch.rand(batch, num_feats, 4, device=torch_device)

    input_mask = ids_tensor([batch, seq_len], vocab_size=2) if self.use_lang_mask else None
    token_type_ids = (
        ids_tensor([batch, seq_len], self.type_vocab_size) if self.use_token_type_ids else None
    )

    # Visual labels only exist when object prediction is enabled; each
    # enabled visual loss contributes one pair of tensors under its own key.
    obj_labels = None
    if self.task_obj_predict:
        obj_labels = {}
        if self.visual_attr_loss:
            obj_labels["attr"] = (
                ids_tensor([batch, num_feats], self.num_attr_labels),
                ids_tensor([batch, num_feats], self.num_attr_labels),
            )
        if self.visual_feat_loss:
            obj_labels["feat"] = (
                ids_tensor([batch, num_feats, self.visual_feat_dim], num_feats),
                ids_tensor([batch, num_feats], num_feats),
            )
        if self.visual_obj_loss:
            obj_labels["obj"] = (
                ids_tensor([batch, num_feats], self.num_object_labels),
                ids_tensor([batch, num_feats], self.num_object_labels),
            )

    ans = ids_tensor([batch], self.num_qa_labels) if self.task_qa else None
    masked_lm_labels = ids_tensor([batch, seq_len], self.vocab_size) if self.task_mask_lm else None
    matched_label = ids_tensor([batch], self.num_labels) if self.task_matched else None

    # Keyword names handed to LxmertConfig; each one is read from the
    # identically named attribute on self, in this order.
    config_fields = (
        "vocab_size",
        "hidden_size",
        "num_attention_heads",
        "num_labels",
        "intermediate_size",
        "hidden_act",
        "hidden_dropout_prob",
        "attention_probs_dropout_prob",
        "max_position_embeddings",
        "type_vocab_size",
        "initializer_range",
        "layer_norm_eps",
        "pad_token_id",
        "num_qa_labels",
        "num_object_labels",
        "num_attr_labels",
        "l_layers",
        "x_layers",
        "r_layers",
        "visual_feat_dim",
        "visual_pos_dim",
        "visual_loss_normalizer",
        "task_matched",
        "task_mask_lm",
        "task_obj_predict",
        "task_qa",
        "visual_obj_loss",
        "visual_attr_loss",
        "visual_feat_loss",
        "output_attentions",
        "output_hidden_states",
    )
    config = LxmertConfig(**{field: getattr(self, field) for field in config_fields})

    return (
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    )