import logging
from collections import OrderedDict
from pathlib import Path

import torch
from transformers import BertConfig, BertForPreTraining

# NQBert and load_tf_weights_in_bert (a variant that returns the TF variables
# it could not match into BertForPreTraining) are assumed to be defined
# elsewhere in this repo.

logger = logging.getLogger(__name__)


def main(tf_checkpoint_path, bert_config_file, save_dir_bert_part, save_dir):
    # Initialise the PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    logger.info("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from the TF checkpoint; extra_list holds the NQ-specific
    # head variables that have no counterpart in BertForPreTraining
    extra_list = load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save the BERT part
    Path(save_dir_bert_part).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(save_dir_bert_part)

    # Reload it as the NQ model
    nq_model = NQBert.from_pretrained(save_dir_bert_part)

    # Check that the BERT parameters survived the round trip
    tf_model_parameters = list(model.parameters())
    assert all(
        torch.equal(p, tf_model_parameters[i])
        for i, p in enumerate(nq_model.bert.parameters())
    )

    # Copy over the parameters of the final layers
    tf_model_final_layer_params = dict(extra_list)
    nq_model_final_layer_params = OrderedDict({
        'span_outputs.weight': torch.tensor(tf_model_final_layer_params['output_weights']),
        'span_outputs.bias': torch.tensor(tf_model_final_layer_params['output_bias']),
        'type_output.weight': torch.tensor(
            tf_model_final_layer_params['answer_type_output_weights']),
        'type_output.bias': torch.tensor(
            tf_model_final_layer_params['answer_type_output_bias']),
    })
    nq_model.load_state_dict(nq_model_final_layer_params, strict=False)

    # Check that the final-layer parameters were copied correctly
    assert all(
        torch.equal(copied, expected)
        for copied, expected in zip(
            list(nq_model.parameters())[-4:],
            nq_model_final_layer_params.values())
    )

    # Save the whole model
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    nq_model.save_pretrained(save_dir)
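# A minimal sketch of the NQBert class the conversion above assumes: a
# BertPreTrainedModel carrying the two extra heads of Google's Natural
# Questions baseline (bert-joint). The attribute names match the state-dict
# keys used above; the head sizes (2 span logits, 5 answer types) follow the
# bert-joint checkpoint and are assumptions about the actual class.
from torch import nn
from transformers import BertModel, BertPreTrainedModel


class NQBert(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Start/end span logits over every token representation
        self.span_outputs = nn.Linear(config.hidden_size, 2)
        # Answer-type logits over the [CLS] representation
        # (NQ distinguishes NULL, YES, NO, SHORT and LONG answers)
        self.type_output = nn.Linear(config.hidden_size, 5)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        sequence_output = outputs[0]  # (batch, seq_len, hidden)
        span_logits = self.span_outputs(sequence_output)
        start_logits, end_logits = span_logits.split(1, dim=-1)
        type_logits = self.type_output(sequence_output[:, 0, :])
        return start_logits.squeeze(-1), end_logits.squeeze(-1), type_logits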
import os

import tensorflow as tf
import torch
from transformers import BertConfig, BertForPreTraining

logger = logging.getLogger(__name__)


def convert_multibert_checkpoint_to_pytorch(tf_checkpoint_path, config_path, save_path):
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load weights from the TF checkpoint
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    config = BertConfig.from_pretrained(config_path)
    model = BertForPreTraining(config)
    layer_nums = []
    for full_name, shape in init_vars:
        array = tf.train.load_variable(tf_path, full_name)
        names.append(full_name)
        split_names = full_name.split("/")
        for name in split_names:
            if name.startswith("layer_"):
                layer_nums.append(int(name.split("_")[-1]))
        arrays.append(array)
    logger.info(f"Read a total of {len(arrays):,} variables")
    name_to_array = dict(zip(names, arrays))

    # Check that the number of layers matches the config
    assert config.num_hidden_layers == len(set(layer_nums))

    state_dict = model.state_dict()
    # Carry the position ids over explicitly, as they are a buffer,
    # not a parameter
    position_ids = state_dict["bert.embeddings.position_ids"]
    new_state_dict = {"bert.embeddings.position_ids": position_ids}

    # Map each TF variable name to its PyTorch equivalent
    for weight_name in names:
        pt_weight_name = (weight_name.replace("kernel", "weight")
                          .replace("gamma", "weight")
                          .replace("beta", "bias"))
        name_split = pt_weight_name.split("/")
        for name_idx, name in enumerate(name_split):
            if name.startswith("layer_"):
                name_split[name_idx] = name.replace("_", ".")
        if name_split[-1].endswith("embeddings"):
            name_split.append("weight")
        if name_split[0] == "cls":
            if name_split[-1] == "output_bias":
                name_split[-1] = "bias"
            if name_split[-1] == "output_weights":
                name_split[-1] = "weight"
        # TF stores dense kernels transposed relative to torch.nn.Linear
        if name_split[-1] == "weight" and name_split[-2] == "dense":
            name_to_array[weight_name] = name_to_array[weight_name].T
        pt_weight_name = ".".join(name_split)
        new_state_dict[pt_weight_name] = torch.from_numpy(name_to_array[weight_name])

    # The MLM decoder is tied to the input embeddings
    new_state_dict["cls.predictions.decoder.weight"] = \
        new_state_dict["bert.embeddings.word_embeddings.weight"].clone()
    new_state_dict["cls.predictions.decoder.bias"] = \
        new_state_dict["cls.predictions.bias"].clone()

    # Load the state dict
    model.load_state_dict(new_state_dict)

    # Save as a pretrained model
    logger.info(f"Saving pretrained model to {save_path}")
    model.save_pretrained(save_path)
    return model
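# Example invocation; the checkpoint and config paths below are placeholders
# for wherever a MultiBERTs TF checkpoint has been downloaded, not paths from
# the original script.
convert_multibert_checkpoint_to_pytorch(
    tf_checkpoint_path="multiberts/seed_0/bert.ckpt",
    config_path="multiberts/seed_0/bert_config.json",
    save_path="multiberts-seed-0-pt",
)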
import pandas as pd
import torch
from torch.optim import AdamW

# model, device, sents_list, output_base_dir, MAX_STEPS, BATCH_SIZE and logger
# are assumed to be set up earlier in the script; create_sent_pairs and
# encode_sent_pairs are sketched after this loop.

optimizer = AdamW(model.parameters(), lr=2e-5)
model.train()
train_losses = []
for i in range(1, MAX_STEPS + 1):
    optimizer.zero_grad()

    # Sample and encode a fresh batch of sentence pairs for MLM + NSP
    sent_pairs = create_sent_pairs(sents_list, batch_size=BATCH_SIZE)
    encoded = encode_sent_pairs(sent_pairs)
    res = model(
        encoded["input_ids"].to(device),
        # Segment ids let the NSP head distinguish the two sentences
        token_type_ids=encoded["token_type_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        labels=encoded["labels"].to(device),
        next_sentence_label=encoded["next_sentence_label"].to(device),
    )
    loss = res.loss

    # Log the loss and rewrite the loss-curve CSV every 100 steps
    if i % 100 == 0:
        logger.info("training step {}, loss {}".format(i, loss.item()))
        train_losses.append((i, loss.item()))
        df_train_loss = pd.DataFrame(train_losses, columns=["step", "train_loss"])
        df_train_loss.to_csv(output_base_dir / "train_loss.csv", index=False)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Save a full checkpoint at the first step and then every 50k steps
    if i == 1 or i % 50000 == 0:
        save_dir = output_base_dir / "step{}".format(i)
        model.save_pretrained(save_directory=save_dir)
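# A minimal sketch of the two helper functions the loop above assumes. The
# sampling scheme, 15% masking rate, sequence length and tokenizer checkpoint
# are assumptions; BERT's original 80/10/10 corruption scheme is simplified
# to always substituting [MASK].
import random

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")


def create_sent_pairs(sents_list, batch_size):
    """Sample (sentence_a, sentence_b, nsp_label) tuples; label 0 means
    sentence_b really follows sentence_a, 1 means it was drawn at random."""
    pairs = []
    for _ in range(batch_size):
        idx = random.randrange(len(sents_list) - 1)
        if random.random() < 0.5:
            pairs.append((sents_list[idx], sents_list[idx + 1], 0))
        else:
            pairs.append((sents_list[idx], random.choice(sents_list), 1))
    return pairs


def encode_sent_pairs(sent_pairs, mlm_probability=0.15):
    """Tokenise the pairs and build MLM labels by masking random tokens."""
    sents_a, sents_b, nsp_labels = zip(*sent_pairs)
    enc = tokenizer(list(sents_a), list(sents_b), padding=True,
                    truncation=True, max_length=128, return_tensors="pt")
    labels = enc["input_ids"].clone()
    # Never mask special tokens ([CLS], [SEP], [PAD])
    special = torch.tensor([
        tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
        for ids in labels.tolist()
    ], dtype=torch.bool)
    probs = torch.full(labels.shape, mlm_probability)
    probs.masked_fill_(special, 0.0)
    masked = torch.bernoulli(probs).bool()
    labels[~masked] = -100  # loss is only computed on masked positions
    enc["input_ids"][masked] = tokenizer.mask_token_id
    enc["labels"] = labels
    enc["next_sentence_label"] = torch.tensor(nsp_labels)
    return enc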