Example #1
import os

import torch
from transformers import (
    CONFIG_NAME,
    WEIGHTS_NAME,
    XLNetConfig,
    XLNetLMHeadModel,
    load_tf_weights_in_xlnet,
)


def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path):
    # Initialise the PyTorch model from the XLNet JSON config
    config = XLNetConfig.from_json_file(bert_config_file)

    model = XLNetLMHeadModel(config)

    # Load weights from the TF checkpoint into the PyTorch model
    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)

    # Save the PyTorch model weights and configuration
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
    print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
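
For reference, a minimal usage sketch of the converter above; the checkpoint, config, and output paths are hypothetical placeholders, not files shipped with this snippet:

convert_xlnet_checkpoint_to_pytorch(
    tf_checkpoint_path="xlnet_cased_L-24_H-1024_A-16/xlnet_model.ckpt",  # placeholder TF checkpoint
    bert_config_file="xlnet_cased_L-24_H-1024_A-16/xlnet_config.json",   # placeholder config file
    pytorch_dump_folder_path="xlnet-pytorch-dump",                       # placeholder output directory
)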
Example #2
import os

import numpy as np
import tensorflow as tf  # requires the TF 1.x graph-mode API (tf.compat.v1 under TF 2.x)
from transformers import XLNetLMHeadModel


def convert_pytorch_checkpoint_to_tf(model: XLNetLMHeadModel, ckpt_dir: str,
                                     model_name: str):
    """
    Convert a PyTorch XLNet model into a TensorFlow checkpoint.

    :param model: XLNetLMHeadModel PyTorch model instance to be converted
    :param ckpt_dir: TensorFlow checkpoint directory
    :param model_name: model name used to derive the checkpoint file name
    :return: None

    Currently supported HF models:
        XLNetLMHeadModel
    """

    # Parameters whose PyTorch layout is the transpose of the TF layout
    tensors_to_transpose = ("dense.weight", "attention.self.query",
                            "attention.self.key", "attention.self.value")

    # Substring replacements that map PyTorch parameter names to TF variable names
    var_map = (('layer.', 'layer_'), ('word_embeddings.weight',
                                      'word_embeddings'),
               ('position_embeddings.weight', 'position_embeddings'),
               ('token_type_embeddings.weight', 'token_type_embeddings'),
               ('.', '/'), ('LayerNorm/weight', 'LayerNorm/gamma'),
               ('LayerNorm/bias', 'LayerNorm/beta'), ('weight', 'kernel'))

    # Make sure the output directory exists
    os.makedirs(ckpt_dir, exist_ok=True)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        # Apply each rename rule in order; order matters ('.' -> '/' must run
        # after the embedding renames, and 'weight' -> 'kernel' must run last)
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return 'xlnet/{}'.format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        # Create a zero-initialised TF variable matching the tensor's shape and
        # dtype; the actual weights are copied in afterwards via set_value()
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype,
                                 shape=tensor.shape,
                                 name=name,
                                 initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].numpy()
            if any(x in var_name for x in tensors_to_transpose):
                # TF stores these kernels transposed relative to PyTorch
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor,
                                   name=tf_name,
                                   session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            # Read the value back and verify it matches the PyTorch tensor
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(
                tf_name, np.allclose(tf_weight, torch_tensor)))

        # Write all trainable variables out as a single TF checkpoint
        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(
            session,
            os.path.join(ckpt_dir,
                         model_name.replace("-", "_") + ".ckpt"))
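
A minimal usage sketch for this exporter, assuming the huggingface transformers package and a TF 1.x runtime; the output directory and model name below are illustrative placeholders:

model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")  # pretrained LM head model from the HF hub
convert_pytorch_checkpoint_to_tf(model=model,
                                 ckpt_dir="tf_ckpt",             # placeholder output directory
                                 model_name="xlnet-base-cased")  # stem of the .ckpt file name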