def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path):
    """Convert a TensorFlow XLNet checkpoint into PyTorch weight and config files.

    An ``XLNetLMHeadModel`` is built from the JSON configuration, populated
    via ``load_tf_weights_in_xlnet`` and then written to
    ``pytorch_dump_folder_path`` as two files: the model ``state_dict``
    (named ``WEIGHTS_NAME``) and the configuration as JSON (named
    ``CONFIG_NAME``).

    :param tf_checkpoint_path: TensorFlow checkpoint, passed unchanged to
        ``load_tf_weights_in_xlnet``
    :param bert_config_file: path of the JSON file read by
        ``XLNetConfig.from_json_file``
    :param pytorch_dump_folder_path: output folder; created if it does not
        exist yet
    :return: None
    """
    # Initialise PyTorch model
    config = XLNetConfig.from_json_file(bert_config_file)
    model = XLNetLMHeadModel(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)

    # Make sure the destination exists so the save below cannot fail on a
    # missing folder after the weights were already loaded. An empty path
    # means "current directory" for os.path.join, so there is nothing to create.
    if pytorch_dump_folder_path:
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)

    # Save pytorch-model
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)

    print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
    torch.save(model.state_dict(), pytorch_weights_dump_path)

    print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
def convert_pytorch_checkpoint_to_tf(model: XLNetLMHeadModel, ckpt_dir: str, model_name: str):
    """Export the weights of a PyTorch model as a TensorFlow checkpoint.

    Every entry of ``model.state_dict()`` is converted to a NumPy array,
    renamed with the substitutions in ``var_map`` (prefixed with ``xlnet/``),
    stored in a freshly created TensorFlow variable and finally saved with
    ``tf.train.Saver`` to ``<ckpt_dir>/<model_name with '-' -> '_'>.ckpt``.

    :param model: XLNetLMHeadModel PyTorch model instance to be converted
    :param ckpt_dir: TensorFlow model directory; created if it does not exist
    :param model_name: model name, used to build the checkpoint file name
    :return: None

    Currently supported HF models:
        XLNetLMHeadModel
    """
    # Parameters whose name contains one of these substrings are transposed
    # before being written.
    # NOTE(review): these patterns (and several in var_map) look BERT-style;
    # confirm they actually match XLNet parameter names.
    tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

    # Ordered (pattern, replacement) pairs applied one after another; order
    # matters because later patterns operate on the output of earlier ones
    # (e.g. '.' -> '/' must happen before the 'LayerNorm/...' rewrites).
    var_map = (
        ('layer.', 'layer_'),
        ('word_embeddings.weight', 'word_embeddings'),
        ('position_embeddings.weight', 'position_embeddings'),
        ('token_type_embeddings.weight', 'token_type_embeddings'),
        ('.', '/'),
        ('LayerNorm/weight', 'LayerNorm/gamma'),
        ('LayerNorm/bias', 'LayerNorm/beta'),
        ('weight', 'kernel'),
    )

    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(ckpt_dir, exist_ok=True)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        """Translate a PyTorch parameter name into its TensorFlow variable name."""
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return 'xlnet/{}'.format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        """Create and initialise a zero-filled TF variable shaped like ``tensor``."""
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            # detach().cpu() makes the conversion work for tensors that live
            # on an accelerator; Tensor.numpy() only accepts CPU tensors.
            torch_tensor = state_dict[var_name].detach().cpu().numpy()
            if any(x in var_name for x in tensors_to_transpose):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            # Read the value back to report whether it matches the source tensor.
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(
                tf_name, np.allclose(tf_weight, torch_tensor)))

        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(
            session,
            os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))