import tensorflow as tf
from tensorflow import keras

# NOTE: BertModelLayer, bert_prefix, _checkpoint_exists, _is_tfhub_model,
# map_to_tfhub_albert_variable_name and the `loader` module are assumed to be
# provided by the surrounding bert-for-tf2 package.


def load_stock_weights(bert: BertModelLayer, ckpt_path):
    """
    Use this method to load the weights from a pre-trained BERT checkpoint
    into a bert layer.

    :param bert: a BertModelLayer instance within a built keras model.
    :param ckpt_path: checkpoint path, e.g. `uncased_L-12_H-768_A-12/bert_model.ckpt`
                      or `albert_base_zh/albert_model.ckpt`
    :return: list of weights with mismatched shapes. This can be used to extend
             the segment/token_type embeddings.
    """
    assert isinstance(bert, BertModelLayer), "Expecting a BertModelLayer instance as first argument"
    assert _checkpoint_exists(ckpt_path), "Checkpoint does not exist: {}".format(ckpt_path)

    ckpt_reader = tf.train.load_checkpoint(ckpt_path)
    stock_weights = set(ckpt_reader.get_variable_to_dtype_map().keys())

    prefix = bert_prefix(bert)

    loaded_weights = set()
    skip_count = 0
    weight_value_tuples = []
    skipped_weight_value_tuples = []

    bert_params = bert.weights
    param_values = keras.backend.batch_get_value(bert.weights)
    for ndx, (param_value, param) in enumerate(zip(param_values, bert_params)):
        # map the keras weight name to the corresponding checkpoint variable name
        stock_name = map_to_tfhub_albert_variable_name(param.name, prefix)

        if ckpt_reader.has_tensor(stock_name):
            ckpt_value = ckpt_reader.get_tensor(stock_name)

            if param_value.shape != ckpt_value.shape:
                print("loader: Skipping weight:[{}] as the weight shape:[{}] is not compatible "
                      "with the checkpoint:[{}] shape:{}".format(param.name, param.shape,
                                                                 stock_name, ckpt_value.shape))
                skipped_weight_value_tuples.append((param, ckpt_value))
                continue

            weight_value_tuples.append((param, ckpt_value))
            loaded_weights.add(stock_name)
        else:
            print("loader: No value for:[{}], i.e.:[{}] in:[{}]".format(param.name, stock_name, ckpt_path))
            skip_count += 1

    # assign all matching weights in a single batched call
    keras.backend.batch_set_value(weight_value_tuples)

    print("Done loading {} BERT weights from: {} into {} (prefix:{}). "
          "Count of weights not found in the checkpoint was: [{}]. "
          "Count of weights with mismatched shape: [{}]".format(
              len(weight_value_tuples), ckpt_path, bert, prefix,
              skip_count, len(skipped_weight_value_tuples)))
    print("Unused weights from checkpoint:",
          "\n\t" + "\n\t".join(sorted(stock_weights.difference(loaded_weights))))

    return skipped_weight_value_tuples  # (bert_weight, value_from_ckpt)
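
# Usage sketch (not part of the original module): building a keras model around
# a BertModelLayer and loading a Google BERT checkpoint into it. The helpers
# params_from_pretrained_ckpt() and BertModelLayer.from_params() are assumed
# from the bert-for-tf2 package, and the model_dir path is hypothetical.
def _example_load_stock_weights():
    import bert  # assumed: the bert-for-tf2 package

    model_dir = ".models/uncased_L-12_H-768_A-12"  # hypothetical local path
    bert_params = bert.params_from_pretrained_ckpt(model_dir)
    l_bert = bert.BertModelLayer.from_params(bert_params, name="bert")

    max_seq_len = 128
    model = keras.models.Sequential([
        keras.layers.InputLayer(input_shape=(max_seq_len,), dtype="int32"),
        l_bert,
    ])
    model.build(input_shape=(None, max_seq_len))  # weights must exist before loading

    # returns the (param, ckpt_value) pairs whose shapes did not match
    mismatched = load_stock_weights(l_bert, model_dir + "/bert_model.ckpt")
    return model, mismatched
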
def load_albert_weights(bert: BertModelLayer, tfhub_model_path, tags=None):
    """
    Use this method to load the weights from a pre-trained ALBERT model into
    a bert layer. Accepts either a TF-Hub SavedModel directory or a
    brightmart/albert_zh style checkpoint (in which case it falls back to
    `loader.load_stock_weights()`).

    :param bert: a BertModelLayer instance within a built keras model.
    :param tfhub_model_path: path to a TF-Hub ALBERT SavedModel directory,
                             or a checkpoint path, e.g. `albert_base_zh/albert_model.ckpt`
    :param tags: tags to pass to `tf.compat.v2.saved_model.load()`.
    :return: list of weights with mismatched shapes. This can be used to extend
             the segment/token_type embeddings.
    """
    if tags is None:
        tags = []

    if not _is_tfhub_model(tfhub_model_path):
        print("Loading brightmart/albert_zh weights...")
        return loader.load_stock_weights(bert, tfhub_model_path)

    assert isinstance(bert, BertModelLayer), "Expecting a BertModelLayer instance as first argument"

    prefix = loader.bert_prefix(bert)

    # read all variables of the saved model eagerly, in a throw-away graph/session
    with tf.Graph().as_default():
        sm = tf.compat.v2.saved_model.load(tfhub_model_path, tags=tags)
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            stock_values = {v.name.split(":")[0]: v.read_value() for v in sm.variables}
            stock_values = sess.run(stock_values)
            # print("\n".join([str((n, v.shape)) for n, v in stock_values.items()]))

    loaded_weights = set()
    skip_count = 0
    weight_value_tuples = []
    skipped_weight_value_tuples = []

    bert_params = bert.weights
    param_values = keras.backend.batch_get_value(bert.weights)
    for ndx, (param_value, param) in enumerate(zip(param_values, bert_params)):
        stock_name = map_to_tfhub_albert_variable_name(param.name, prefix)

        if stock_name in stock_values:
            ckpt_value = stock_values[stock_name]

            if param_value.shape != ckpt_value.shape:
                print("loader: Skipping weight:[{}] as the weight shape:[{}] is not compatible "
                      "with the checkpoint:[{}] shape:{}".format(param.name, param.shape,
                                                                 stock_name, ckpt_value.shape))
                skipped_weight_value_tuples.append((param, ckpt_value))
                continue

            weight_value_tuples.append((param, ckpt_value))
            loaded_weights.add(stock_name)
        else:
            print("loader: No value for:[{}], i.e.:[{}] in:[{}]".format(param.name, stock_name, tfhub_model_path))
            skip_count += 1

    keras.backend.batch_set_value(weight_value_tuples)

    print("Done loading {} BERT weights from: {} into {} (prefix:{}). "
          "Count of weights not found in the checkpoint was: [{}]. "
          "Count of weights with mismatched shape: [{}]".format(
              len(weight_value_tuples), tfhub_model_path, bert, prefix,
              skip_count, len(skipped_weight_value_tuples)))
    print("Unused weights from saved model:",
          "\n\t" + "\n\t".join(sorted(set(stock_values.keys()).difference(loaded_weights))))

    return skipped_weight_value_tuples  # (bert_weight, value_from_ckpt)
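
# Usage sketch (not part of the original module): loading a TF-Hub ALBERT
# SavedModel. fetch_tfhub_albert_model() and albert_params() are assumed from
# the bert-for-tf2 package; the model name and cache directory are hypothetical.
def _example_load_albert_weights():
    import bert  # assumed: the bert-for-tf2 package

    model_name = "albert_base"
    albert_dir = bert.fetch_tfhub_albert_model(model_name, ".models")  # download/cache
    albert_params = bert.albert_params(model_name)
    l_bert = bert.BertModelLayer.from_params(albert_params, name="albert")

    max_seq_len = 128
    model = keras.models.Sequential([
        keras.layers.InputLayer(input_shape=(max_seq_len,), dtype="int32"),
        l_bert,
    ])
    model.build(input_shape=(None, max_seq_len))  # weights must exist before loading

    skipped = load_albert_weights(l_bert, albert_dir)
    return model, skipped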