def test_inference_classification_head(self):
    # note that google/tapas-base-finetuned-tabfact should correspond to tapas_tabfact_inter_masklm_base_reset
    model = TapasForSequenceClassification.from_pretrained("google/tapas-base-finetuned-tabfact").to(torch_device)

    tokenizer = self.default_tokenizer

    table, queries = prepare_tapas_single_inputs_for_inference()
    inputs = tokenizer(table=table, queries=queries, padding="longest", return_tensors="pt")
    inputs = {k: v.to(torch_device) for k, v in inputs.items()}
    outputs = model(**inputs)

    # test the classification logits
    logits = outputs.logits
    expected_shape = torch.Size((1, 2))
    self.assertEqual(logits.shape, expected_shape)
    expected_tensor = torch.tensor(
        [[0.795137286, 9.5572]], device=torch_device
    )  # Note that the PyTorch model outputs [[0.8057, 9.5281]]

    self.assertTrue(torch.allclose(outputs.logits, expected_tensor, atol=TOLERANCE))
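# The test above relies on a prepare_tapas_single_inputs_for_inference() helper defined elsewhere in the test file.
# A minimal sketch of such a helper is given below, assuming the table is handed to TapasTokenizer as a pandas
# DataFrame of string cells and the query as a single string; the table contents here are purely illustrative.
import pandas as pd


def prepare_tapas_single_inputs_for_inference():
    # a single table-question pair; TapasTokenizer expects all cell values to be strings
    data = {
        "Footballer": ["Lionel Messi", "Cristiano Ronaldo"],
        "Age": ["33", "35"],
    }
    queries = "Which footballer is 33 years old?"
    table = pd.DataFrame.from_dict(data)

    return table, queries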
def create_and_check_for_sequence_classification(
    self,
    config,
    input_ids,
    input_mask,
    token_type_ids,
    sequence_labels,
    token_labels,
    labels,
    numeric_values,
    numeric_values_scale,
    float_answer,
    aggregation_labels,
):
    config.num_labels = self.num_labels
    model = TapasForSequenceClassification(config)
    model.to(torch_device)
    model.eval()
    result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
    self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
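# create_and_check_for_sequence_classification() is meant to be driven from the model test class. A short sketch of
# that driver is shown below, assuming the usual Transformers test pattern in which
# self.model_tester.prepare_config_and_inputs() returns the same tuple of arguments the checker above expects.
def test_for_sequence_classification(self):
    # generate a fresh config plus dummy inputs and forward them to the checker defined on the model tester
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)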
def convert_tf_checkpoint_to_pytorch(
    task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path
):
    # Initialise PyTorch model.
    # If you want to convert a checkpoint that uses absolute position embeddings, make sure to set
    # reset_position_index_per_cell of TapasConfig to False.

    # initialize configuration from json file
    config = TapasConfig.from_json_file(tapas_config_file)
    # set absolute/relative position embeddings parameter
    config.reset_position_index_per_cell = reset_position_index_per_cell

    # set remaining parameters of TapasConfig as well as the model based on the task
    if task == "SQA":
        model = TapasForQuestionAnswering(config=config)
    elif task == "WTQ":
        # run_task_main.py hparams
        config.num_aggregation_labels = 4
        config.use_answer_as_supervision = True

        # hparam_utils.py hparams
        config.answer_loss_cutoff = 0.664694
        config.cell_selection_preference = 0.207951
        config.huber_loss_delta = 0.121194
        config.init_cell_selection_weights_to_zero = True
        config.select_one_column = True
        config.allow_empty_column_selection = False
        config.temperature = 0.0352513

        model = TapasForQuestionAnswering(config=config)
    elif task == "WIKISQL_SUPERVISED":
        # run_task_main.py hparams
        config.num_aggregation_labels = 4
        config.use_answer_as_supervision = False

        # hparam_utils.py hparams
        config.answer_loss_cutoff = 36.4519
        config.cell_selection_preference = 0.903421
        config.huber_loss_delta = 222.088
        config.init_cell_selection_weights_to_zero = True
        config.select_one_column = True
        config.allow_empty_column_selection = True
        config.temperature = 0.763141

        model = TapasForQuestionAnswering(config=config)
    elif task == "TABFACT":
        model = TapasForSequenceClassification(config=config)
    elif task == "MLM":
        model = TapasForMaskedLM(config=config)
    elif task == "INTERMEDIATE_PRETRAINING":
        model = TapasModel(config=config)
    else:
        raise ValueError(f"Task {task} not supported.")

    print(f"Building PyTorch model from configuration: {config}")

    # Load weights from tf checkpoint
    load_tf_weights_in_tapas(model, config, tf_checkpoint_path)

    # Save pytorch-model (weights and configuration)
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)

    # Save tokenizer files
    print(f"Save tokenizer files to {pytorch_dump_path}")
    tokenizer = TapasTokenizer(vocab_file=tf_checkpoint_path[:-10] + "vocab.txt", model_max_length=512)
    tokenizer.save_pretrained(pytorch_dump_path)

    print("Used relative position embeddings:", model.config.reset_position_index_per_cell)
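# A sketch of a command-line entry point for the conversion function above, assuming an argparse interface whose
# flag names simply mirror the function parameters; the exact CLI of the original script may differ, so treat these
# flags as illustrative.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task",
        default="SQA",
        type=str,
        help="Task the checkpoint was trained on: SQA, WTQ, WIKISQL_SUPERVISED, TABFACT, MLM or INTERMEDIATE_PRETRAINING.",
    )
    parser.add_argument(
        "--reset_position_index_per_cell",
        action="store_true",
        help="Use relative position embeddings, i.e. reset the position index at every cell.",
    )
    parser.add_argument("--tf_checkpoint_path", required=True, type=str, help="Path to the TensorFlow checkpoint.")
    parser.add_argument(
        "--tapas_config_file", required=True, type=str, help="Path to the JSON config of the pre-trained TAPAS model."
    )
    parser.add_argument(
        "--pytorch_dump_path", required=True, type=str, help="Output directory for the PyTorch model and tokenizer."
    )
    args = parser.parse_args()

    convert_tf_checkpoint_to_pytorch(
        args.task,
        args.reset_position_index_per_cell,
        args.tf_checkpoint_path,
        args.tapas_config_file,
        args.pytorch_dump_path,
    )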