def create_and_check_model(
    self,
    config,
    input_ids,
    input_mask,
    token_type_ids,
    sequence_labels,
    token_labels,
    labels,
    numeric_values,
    numeric_values_scale,
    float_answer,
    aggregation_labels,
):
    """Run the base ``TapasModel`` forward pass and verify its output shapes.

    The model is built from ``config``, moved to ``torch_device`` and put in
    eval mode. It is then called three times with progressively fewer
    optional inputs (mask + token types, token types only, ids only). The
    shape assertions are made on the outputs of the last call.

    The label-like arguments (``sequence_labels``, ``token_labels``,
    ``labels``, ``numeric_values``, ``numeric_values_scale``,
    ``float_answer``, ``aggregation_labels``) are accepted but not used by
    this check.
    """
    tapas = TapasModel(config=config)
    tapas.to(torch_device)
    tapas.eval()

    # Same three calls, in the same order, as separate forward passes:
    # each one drops another optional input.
    optional_input_sets = (
        {"attention_mask": input_mask, "token_type_ids": token_type_ids},
        {"token_type_ids": token_type_ids},
        {},
    )
    for optional_inputs in optional_input_sets:
        outputs = tapas(input_ids, **optional_inputs)

    expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
    expected_pooled_shape = (self.batch_size, self.hidden_size)
    self.parent.assertEqual(outputs.last_hidden_state.shape, expected_hidden_shape)
    self.parent.assertEqual(outputs.pooler_output.shape, expected_pooled_shape)
def convert_tf_checkpoint_to_pytorch(
    task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path
):
    """Convert a TAPAS TensorFlow checkpoint into a saved PyTorch model + tokenizer.

    Args:
        task: One of ``"SQA"``, ``"WTQ"``, ``"WIKISQL_SUPERVISED"``,
            ``"TABFACT"``, ``"MLM"`` or ``"INTERMEDIATE_PRETRAINING"``. Selects
            the model class and, for WTQ / WIKISQL_SUPERVISED, the extra
            hyperparameters written into the config.
        reset_position_index_per_cell: Value assigned to
            ``config.reset_position_index_per_cell``. Set it to ``False`` when
            converting a checkpoint that uses absolute position embeddings.
        tf_checkpoint_path: Path of the TensorFlow checkpoint. ``vocab.txt``
            is looked up in the same directory.
        tapas_config_file: JSON file the ``TapasConfig`` is loaded from.
        pytorch_dump_path: Directory the model and tokenizer are saved to.

    Raises:
        ValueError: If ``task`` is not one of the supported values.
    """
    # Local import so this function does not depend on the file's import block.
    import os

    # Initialize the configuration from the JSON file.
    config = TapasConfig.from_json_file(tapas_config_file)
    # Set the absolute/relative position embeddings parameter.
    config.reset_position_index_per_cell = reset_position_index_per_cell

    # Set the remaining TapasConfig parameters and pick the model class based on the task.
    if task == "SQA":
        model = TapasForQuestionAnswering(config=config)
    elif task == "WTQ":
        # run_task_main.py hparams
        config.num_aggregation_labels = 4
        config.use_answer_as_supervision = True
        # hparam_utils.py hparams
        config.answer_loss_cutoff = 0.664694
        config.cell_selection_preference = 0.207951
        config.huber_loss_delta = 0.121194
        config.init_cell_selection_weights_to_zero = True
        config.select_one_column = True
        config.allow_empty_column_selection = False
        config.temperature = 0.0352513
        model = TapasForQuestionAnswering(config=config)
    elif task == "WIKISQL_SUPERVISED":
        # run_task_main.py hparams
        config.num_aggregation_labels = 4
        config.use_answer_as_supervision = False
        # hparam_utils.py hparams
        config.answer_loss_cutoff = 36.4519
        config.cell_selection_preference = 0.903421
        config.huber_loss_delta = 222.088
        config.init_cell_selection_weights_to_zero = True
        config.select_one_column = True
        config.allow_empty_column_selection = True
        config.temperature = 0.763141
        model = TapasForQuestionAnswering(config=config)
    elif task == "TABFACT":
        model = TapasForSequenceClassification(config=config)
    elif task == "MLM":
        model = TapasForMaskedLM(config=config)
    elif task == "INTERMEDIATE_PRETRAINING":
        model = TapasModel(config=config)
    else:
        raise ValueError(f"Task {task} not supported.")

    print(f"Building PyTorch model from configuration: {config}")

    # Load weights from the TF checkpoint.
    load_tf_weights_in_tapas(model, config, tf_checkpoint_path)

    # Save the PyTorch model (weights and configuration).
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)

    # Save the tokenizer files. The vocabulary is taken from the checkpoint's
    # directory. Using dirname instead of slicing off a fixed 10 characters
    # keeps this correct for checkpoint names other than "model.ckpt"
    # (e.g. "model.ckpt-1000").
    print(f"Save tokenizer files to {pytorch_dump_path}")
    vocab_file = os.path.join(os.path.dirname(tf_checkpoint_path), "vocab.txt")
    tokenizer = TapasTokenizer(vocab_file=vocab_file, model_max_length=512)
    tokenizer.save_pretrained(pytorch_dump_path)

    print("Used relative position embeddings:", model.config.reset_position_index_per_cell)