def test_inference_no_head(self):
    """Compare the headless TapasModel outputs against stored reference values.

    Ideally this would use the weights of tapas_inter_masklm_base_reset, but
    since that is not straightforward with the TF 1 implementation, the
    weights of the WTQ base model
    (i.e. tapas_wtq_wikisql_sqa_inter_masklm_base_reset) are used instead.
    """
    checkpoint_name = "google/tapas-base-finetuned-wtq"
    tolerance = 0.0005

    model = TapasModel.from_pretrained(checkpoint_name)
    model = model.to(torch_device)

    tokenizer = self.default_tokenizer
    table, queries = prepare_tapas_single_inputs_for_inference()
    encoded = tokenizer(table=table, queries=queries, return_tensors="pt")
    # Move every encoded tensor onto the device the model lives on.
    inputs = {name: tensor.to(torch_device) for name, tensor in encoded.items()}
    outputs = model(**inputs)

    # Sequence output: check the leading 3x3 corner of the hidden states.
    expected_sequence_slice = torch.tensor(
        [
            [
                [-0.141581565, -0.599805772, 0.747186482],
                [-0.143664181, -0.602008104, 0.749218345],
                [-0.15169853, -0.603363097, 0.741370678],
            ]
        ],
        device=torch_device,
    )
    actual_sequence_slice = outputs.last_hidden_state[:, :3, :3]
    self.assertTrue(torch.allclose(actual_sequence_slice, expected_sequence_slice, atol=tolerance))

    # Pooled output: check the first three features.
    expected_pooled_slice = torch.tensor(
        [[0.987518311, -0.970520139, -0.994303405]],
        device=torch_device,
    )
    actual_pooled_slice = outputs.pooler_output[:, :3]
    self.assertTrue(torch.allclose(actual_pooled_slice, expected_pooled_slice, atol=tolerance))
def create_and_check_model(
    self,
    config,
    input_ids,
    input_mask,
    token_type_ids,
    sequence_labels,
    token_labels,
    labels,
    numeric_values,
    numeric_values_scale,
    float_answer,
    aggregation_labels,
):
    """Build a TapasModel from ``config`` and check the shapes of its outputs.

    The model is called three times with progressively fewer optional
    inputs; only the result of the last call (``input_ids`` alone) is
    inspected. The label and numeric arguments are accepted but not used
    here.
    """
    model = TapasModel(config=config)
    model.to(torch_device)
    model.eval()

    # Keyword sets for the successive forward passes, from most to fewest
    # optional inputs. The final (empty) entry yields the result checked below.
    optional_input_sets = (
        {"attention_mask": input_mask, "token_type_ids": token_type_ids},
        {"token_type_ids": token_type_ids},
        {},
    )
    for optional_inputs in optional_input_sets:
        result = model(input_ids, **optional_inputs)

    expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
    expected_pooled_shape = (self.batch_size, self.hidden_size)
    self.parent.assertEqual(result.last_hidden_state.shape, expected_hidden_shape)
    self.parent.assertEqual(result.pooler_output.shape, expected_pooled_shape)
def convert_tf_checkpoint_to_pytorch(task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path):
    """Convert a TensorFlow TAPAS checkpoint into a saved PyTorch model and tokenizer.

    Args:
        task: One of "SQA", "WTQ", "WIKISQL_SUPERVISED", "TABFACT", "MLM" or
            "INTERMEDIATE_PRETRAINING"; selects the model class and any
            task-specific configuration values.
        reset_position_index_per_cell: Stored on the configuration under the
            same name. To convert a checkpoint that uses absolute position
            embeddings, pass False.
        tf_checkpoint_path: Path handed to ``load_tf_weights_in_tapas``; also
            used to derive the location of ``vocab.txt``.
        tapas_config_file: JSON file the ``TapasConfig`` is read from.
        pytorch_dump_path: Destination passed to ``save_pretrained`` for both
            the model and the tokenizer.

    Raises:
        ValueError: If ``task`` is not one of the supported names.
    """
    # Task-specific configuration values, applied in the order listed.
    wtq_settings = {
        # run_task_main.py hparams
        "num_aggregation_labels": 4,
        "use_answer_as_supervision": True,
        # hparam_utils.py hparams
        "answer_loss_cutoff": 0.664694,
        "cell_selection_preference": 0.207951,
        "huber_loss_delta": 0.121194,
        "init_cell_selection_weights_to_zero": True,
        "select_one_column": True,
        "allow_empty_column_selection": False,
        "temperature": 0.0352513,
    }
    wikisql_supervised_settings = {
        # run_task_main.py hparams
        "num_aggregation_labels": 4,
        "use_answer_as_supervision": False,
        # hparam_utils.py hparams
        "answer_loss_cutoff": 36.4519,
        "cell_selection_preference": 0.903421,
        "huber_loss_delta": 222.088,
        "init_cell_selection_weights_to_zero": True,
        "select_one_column": True,
        "allow_empty_column_selection": True,
        "temperature": 0.763141,
    }
    # (task name, model class, configuration overrides)
    task_table = (
        ("SQA", TapasForQuestionAnswering, {}),
        ("WTQ", TapasForQuestionAnswering, wtq_settings),
        ("WIKISQL_SUPERVISED", TapasForQuestionAnswering, wikisql_supervised_settings),
        ("TABFACT", TapasForSequenceClassification, {}),
        ("MLM", TapasForMaskedLM, {}),
        ("INTERMEDIATE_PRETRAINING", TapasModel, {}),
    )

    # Initialise the configuration from the JSON file, then record whether
    # absolute or relative position embeddings are used.
    config = TapasConfig.from_json_file(tapas_config_file)
    config.reset_position_index_per_cell = reset_position_index_per_cell

    # Find the entry for the requested task; unknown tasks are rejected.
    for known_task, model_class, overrides in task_table:
        if task == known_task:
            break
    else:
        raise ValueError(f"Task {task} not supported.")

    # The overrides must be on the config before the model is instantiated.
    for option_name, option_value in overrides.items():
        setattr(config, option_name, option_value)
    model = model_class(config=config)

    print(f"Building PyTorch model from configuration: {config}")

    # Load weights from the TF checkpoint.
    load_tf_weights_in_tapas(model, config, tf_checkpoint_path)

    # Save the PyTorch model (weights and configuration).
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)

    # Save the tokenizer files.
    print(f"Save tokenizer files to {pytorch_dump_path}")
    # NOTE(review): the last 10 characters of the checkpoint path are replaced
    # by "vocab.txt"; presumably the path is expected to end in a 10-character
    # file name such as "model.ckpt" -- confirm against callers.
    vocab_path = tf_checkpoint_path[:-10] + "vocab.txt"
    tokenizer = TapasTokenizer(vocab_file=vocab_path, model_max_length=512)
    tokenizer.save_pretrained(pytorch_dump_path)

    print("Used relative position embeddings:", model.config.reset_position_index_per_cell)