def _our_map_to_stock_variable_name(name, prefix="bert_mlm/bert"): real_name = name.split(":")[0] if real_name.startswith(f"{prefix}/cls/predictions"): # These are the MLM prediction variables that are in the BERT checkpoint but not in # the tf2 bert implementation. return real_name[len(prefix) + 1 :] return tf2_bert_loader.map_to_stock_variable_name(name, prefix)
def load_stock_weights(self, bert: BertModelLayer, ckpt_file, bert_prefix="transformer/bert"):
    """Load every weight of ``bert`` from a stock BERT checkpoint.

    Each weight name of ``bert`` is translated with
    ``map_to_stock_variable_name(weight.name, bert_prefix)``. The matching
    tensor is read from ``ckpt_file``, and all values are assigned in a single
    ``bert.set_weights`` call.

    Args:
        bert: layer whose weights are populated; must be a ``BertModelLayer``.
        ckpt_file: TF checkpoint to read from.
        bert_prefix: variable-scope prefix passed to the name mapper. Defaults
            to ``"transformer/bert"``, the value this helper always used before
            it became a parameter.

    Raises:
        AssertionError: if ``bert`` is not a ``BertModelLayer`` or the
            checkpoint does not exist (these checks are ``assert`` statements
            and are skipped under ``python -O``).
        ValueError: if some weight of ``bert`` has no tensor in the checkpoint.
    """
    assert isinstance(
        bert, BertModelLayer
    ), "Expecting a BertModelLayer instance as first argument"
    assert tf.compat.v1.train.checkpoint_exists(
        ckpt_file), "Checkpoint does not exist: {}".format(ckpt_file)

    ckpt_reader = tf.train.load_checkpoint(ckpt_file)

    weights = []
    for weight in bert.weights:
        stock_name = map_to_stock_variable_name(weight.name, bert_prefix)
        if not ckpt_reader.has_tensor(stock_name):
            raise ValueError("No value for:[{}], i.e.:[{}] in:[{}]".format(
                weight.name, stock_name, ckpt_file))
        weights.append(ckpt_reader.get_tensor(stock_name))

    # set_weights expects values in the same order as bert.weights, which is
    # the order they were collected in above.
    bert.set_weights(weights)
    print(
        "Done loading {} BERT weights from: {} into {} (prefix:{})".format(
            len(weights), ckpt_file, bert, bert_prefix))
def test___compare_weights(self):
    """Check that stock-checkpoint and Keras variable names map onto each other.

    Builds the Keras model, then maps every stock checkpoint variable to its
    Keras name, and every trainable Keras variable back to its stock name.
    In each direction it counts names that match with equal shapes, names
    with no counterpart, and names whose shapes differ.
    """
    max_seq_len = 18
    # Use max_seq_len (not a literal) so the model and its build shape agree.
    model, _bert, _inputs = self.create_bert_model(max_seq_len)
    model.build(input_shape=[(None, max_seq_len), (None, max_seq_len)])

    stock_vars = {
        name: list(shape)
        for name, shape in tf.train.list_variables(self.bert_ckpt_file)
    }
    keras_vars = {
        var.name.split(":")[0]: var.shape.as_list()
        for var in model.trainable_variables
    }

    def _compare(src_vars, dst_vars, map_name, title):
        """Map each src name into dst_vars, print a summary, return the name sets."""
        matched, unmatched, shape_errors = set(), set(), set()
        for src_name, src_shape in src_vars.items():
            dst_name = map_name(src_name)
            if dst_name not in dst_vars:
                unmatched.add(src_name)
            elif dst_vars[dst_name] == src_shape:
                matched.add(src_name)
            else:
                shape_errors.add(src_name)

        print(title)
        print("  matched count:", len(matched))
        print("  unmatched count:", len(unmatched))
        print("  shape error count:", len(shape_errors))
        # Sorted so the report is stable between runs.
        print("unmatched:\n", "\n ".join(sorted(unmatched)))
        return matched, unmatched, shape_errors

    matched, unmatched, shape_errors = _compare(
        stock_vars, keras_vars, map_from_stock_variale_name, "bert -> keras:")
    self.assertEqual(197, len(matched))
    # NOTE(review): 9 checkpoint variables have no Keras counterpart;
    # presumably checkpoint-only variables - confirm.
    self.assertEqual(9, len(unmatched))
    self.assertEqual(0, len(shape_errors))

    matched, unmatched, shape_errors = _compare(
        keras_vars, stock_vars, map_to_stock_variable_name, "keras -> bert:")
    self.assertEqual(197, len(matched))
    self.assertEqual(0, len(unmatched))
    self.assertEqual(0, len(shape_errors))