Example #1
import logging
from collections import OrderedDict
from pathlib import Path

import torch
from transformers import BertConfig, BertForPreTraining

# Assumed to be defined in the surrounding project:
#  - load_tf_weights_in_bert: a variant of the transformers helper that also returns
#    the TF variables it could not map onto BertForPreTraining (the NQ head weights)
#  - NQBert: the project-specific model class the checkpoint is loaded into
#    (a sketch of such a class follows this example)

logger = logging.getLogger(__name__)


def main(tf_checkpoint_path, bert_config_file, save_dir_bert_part, save_dir):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    logger.info("Building PyTorch model from configuration: {}".format(
        str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    extra_list = load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save the Bert part
    Path(save_dir_bert_part).mkdir(exist_ok=True)
    model.save_pretrained(save_dir_bert_part)

    # Reload
    nq_model = NQBert.from_pretrained(save_dir_bert_part)

    # Check that the Bert parameters match
    tf_model_parameters = list(model.parameters())
    assert all(
        torch.equal(p, tf_model_parameters[i])
        for i, p in enumerate(nq_model.bert.parameters()))

    # Copy over parameters of final layers
    tf_model_final_layer_params = {x[0]: x[1] for x in extra_list}
    nq_model_final_layer_params = OrderedDict({
        'span_outputs.weight':
        torch.tensor(tf_model_final_layer_params['output_weights']),
        'span_outputs.bias':
        torch.tensor(tf_model_final_layer_params['output_bias']),
        'type_output.weight':
        torch.tensor(
            tf_model_final_layer_params['answer_type_output_weights']),
        'type_output.bias':
        torch.tensor(tf_model_final_layer_params['answer_type_output_bias']),
    })

    nq_model.load_state_dict(nq_model_final_layer_params, strict=False)

    # Check that the final-layer parameters match
    assert all(
        torch.equal(p, expected) for p, expected in zip(
            list(nq_model.parameters())[-4:],
            nq_model_final_layer_params.values()))

    # Save whole model
    Path(save_dir).mkdir(exist_ok=True)
    nq_model.save_pretrained(save_dir)
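
The NQBert class used above is project-specific and its definition is not shown. The state-dict keys it must expose (span_outputs.weight/bias and type_output.weight/bias) and the TF variable shapes constrain what it looks like. A minimal sketch of such a class, assuming a start/end span head and an answer-type head; the five answer-type classes are an assumption that must match the shape of answer_type_output_weights in the checkpoint:

from torch import nn
from transformers import BertModel, BertPreTrainedModel


class NQBert(BertPreTrainedModel):
    """Hypothetical sketch of the NQ model class used above, not the original definition."""

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Start/end span logits; weight shape [2, hidden_size] matches the TF output_weights
        self.span_outputs = nn.Linear(config.hidden_size, 2)
        # Answer-type logits; 5 classes is an assumption and must match
        # answer_type_output_weights in the TF checkpoint
        self.type_output = nn.Linear(config.hidden_size, 5)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        sequence_output = outputs[0]   # [batch, seq_len, hidden]
        pooled_output = outputs[1]     # [batch, hidden], pooled [CLS] representation
        start_logits, end_logits = self.span_outputs(sequence_output).split(1, dim=-1)
        type_logits = self.type_output(pooled_output)
        return start_logits.squeeze(-1), end_logits.squeeze(-1), type_logits
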
Example #2

import os
import logging

import tensorflow as tf
import torch
from transformers import BertConfig, BertForPreTraining

logger = logging.getLogger(__name__)


def convert_multibert_checkpoint_to_pytorch(tf_checkpoint_path, config_path, save_path):
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    config = BertConfig.from_pretrained(config_path)
    model = BertForPreTraining(config)

    layer_nums = []
    for full_name, shape in init_vars:
        array = tf.train.load_variable(tf_path, full_name)
        names.append(full_name)
        split_names = full_name.split("/")
        for name in split_names:
            if name.startswith("layer_"):
                layer_nums.append(int(name.split("_")[-1]))

        arrays.append(array)
    logger.info(f"Read a total of {len(arrays):,} layers")

    name_to_array = dict(zip(names, arrays))

    # Check that number of layers match
    assert config.num_hidden_layers == len(list(set(layer_nums)))

    state_dict = model.state_dict()

    # Need to do this explicitly as it is a buffer
    position_ids = state_dict["bert.embeddings.position_ids"]
    new_state_dict = {"bert.embeddings.position_ids": position_ids}

    # Encoder Layers
    for weight_name in names:
        pt_weight_name = weight_name.replace("kernel", "weight").replace("gamma", "weight").replace("beta", "bias")
        name_split = pt_weight_name.split("/")
        for name_idx, name in enumerate(name_split):
            if name.startswith("layer_"):
                name_split[name_idx] = name.replace("_", ".")

        if name_split[-1].endswith("embeddings"):
            name_split.append("weight")

        if name_split[0] == "cls":
            if name_split[-1] == "output_bias":
                name_split[-1] = "bias"
            if name_split[-1] == "output_weights":
                name_split[-1] = "weight"

        if name_split[-1] == "weight" and name_split[-2] == "dense":
            name_to_array[weight_name] = name_to_array[weight_name].T

        pt_weight_name = ".".join(name_split)

        new_state_dict[pt_weight_name] = torch.from_numpy(name_to_array[weight_name])

    new_state_dict["cls.predictions.decoder.weight"] = new_state_dict["bert.embeddings.word_embeddings.weight"].clone()
    new_state_dict["cls.predictions.decoder.bias"] = new_state_dict["cls.predictions.bias"].clone().T
    # Load State Dict
    model.load_state_dict(new_state_dict)

    # Save PreTrained
    logger.info(f"Saving pretrained model to {save_path}")
    model.save_pretrained(save_path)

    return model
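
A usage sketch for the converter; the checkpoint and config paths below are placeholders for a locally downloaded multilingual BERT checkpoint, and the output directory name is arbitrary:

# Hypothetical paths to a downloaded TF checkpoint and its config
model = convert_multibert_checkpoint_to_pytorch(
    tf_checkpoint_path="multi_cased_L-12_H-768_A-12/bert_model.ckpt",
    config_path="multi_cased_L-12_H-768_A-12/bert_config.json",
    save_path="multibert-pytorch",
)

# The saved directory loads like any other transformers checkpoint
from transformers import BertForPreTraining
reloaded = BertForPreTraining.from_pretrained("multibert-pytorch")
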
Example #3
    optimizer = AdamW(model.parameters(), lr=2e-5)
    model.train()
    train_losses = []
    for i in range(1, MAX_STEPS + 1):
        optimizer.zero_grad()
        # Sample a fresh batch of sentence pairs and encode them with MLM and NSP labels
        sent_pairs = create_sent_pairs(sents_list, batch_size=BATCH_SIZE)
        encoded = encode_sent_pairs(sent_pairs)
        res = model(
            encoded["input_ids"].to(device),
            token_type_ids=None,
            attention_mask=encoded["attention_mask"].to(device),
            labels=encoded["labels"].to(device),
            next_sentence_label=encoded["next_sentence_label"].to(device),
        )
        loss = res.loss
        if i % 100 == 0:
            loss_value = loss.item()
            logger.info("training step {}, loss {:.4f}".format(i, loss_value))
            train_losses.append((i, loss_value))
            df_train_loss = pd.DataFrame(train_losses,
                                         columns=["step", "train_loss"])
            df_train_loss.to_csv(output_base_dir / "train_loss.csv",
                                 index=False)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if i == 1 or i % 50000 == 0:
            save_dir = output_base_dir / "step{}".format(i)
            model.save_pretrained(save_directory=save_dir)
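
Example #3 is a fragment: create_sent_pairs, encode_sent_pairs, sents_list, BATCH_SIZE, MAX_STEPS, device, and output_base_dir come from the surrounding script. A minimal sketch of the two helpers under the assumption that sents_list is an ordered list of sentences, using a BERT tokenizer and a simplified masking scheme (every selected token becomes [MASK] instead of the 80/10/10 split of the original BERT recipe):

import random

import torch
from transformers import BertTokenizerFast

# Hypothetical: the tokenizer actually used by the example is not shown
tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")


def create_sent_pairs(sents_list, batch_size):
    """Sample (sentence_a, sentence_b, label) triples: 0 = real next sentence, 1 = random."""
    pairs = []
    for _ in range(batch_size):
        idx = random.randrange(len(sents_list) - 1)
        if random.random() < 0.5:
            pairs.append((sents_list[idx], sents_list[idx + 1], 0))  # IsNext
        else:
            rand = random.randrange(len(sents_list))
            pairs.append((sents_list[idx], sents_list[rand], 1))     # NotNext
    return pairs


def encode_sent_pairs(sent_pairs, max_length=128, mlm_probability=0.15):
    """Tokenize sentence pairs and build MLM labels plus next-sentence labels."""
    sents_a = [a for a, _, _ in sent_pairs]
    sents_b = [b for _, b, _ in sent_pairs]
    nsp_labels = torch.tensor([label for _, _, label in sent_pairs])

    encoded = tokenizer(sents_a, sents_b, padding=True, truncation=True,
                        max_length=max_length, return_tensors="pt")
    input_ids = encoded["input_ids"]
    labels = input_ids.clone()

    # Select ~15% of non-special tokens; unselected positions get label -100 (ignored by the loss)
    special = torch.tensor(
        [tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
         for ids in input_ids.tolist()], dtype=torch.bool)
    probs = torch.full(labels.shape, mlm_probability)
    probs.masked_fill_(special, 0.0)
    masked = torch.bernoulli(probs).bool()
    labels[~masked] = -100
    input_ids[masked] = tokenizer.mask_token_id

    encoded["labels"] = labels
    encoded["next_sentence_label"] = nsp_labels
    return encoded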