コード例 #1
0
def _our_map_to_stock_variable_name(name, prefix="bert_mlm/bert"):
    real_name = name.split(":")[0]
    if real_name.startswith(f"{prefix}/cls/predictions"):
        # These are the MLM prediction variables that are in the BERT checkpoint but not in
        # the tf2 bert implementation.
        return real_name[len(prefix) + 1 :]
    return tf2_bert_loader.map_to_stock_variable_name(name, prefix)
コード例 #2
0
    def load_stock_weights(self, bert: BertModelLayer, ckpt_file):
        """Copy weights from a stock BERT checkpoint into ``bert``.

        Each weight of ``bert`` is looked up in the checkpoint under the name
        produced by ``map_to_stock_variable_name`` with the
        ``transformer/bert`` prefix. The collected values are applied with a
        single ``set_weights`` call and a summary line is printed.

        Args:
            bert: The ``BertModelLayer`` instance to populate.
            ckpt_file: Checkpoint path passed to the TensorFlow checkpoint
                reader.

        Raises:
            AssertionError: If ``bert`` is not a ``BertModelLayer`` or the
                checkpoint does not exist.
            ValueError: If a weight of ``bert`` has no tensor in the
                checkpoint.
        """
        assert isinstance(bert, BertModelLayer), \
            "Expecting a BertModelLayer instance as first argument"
        assert tf.compat.v1.train.checkpoint_exists(ckpt_file), \
            "Checkpoint does not exist: {}".format(ckpt_file)

        reader = tf.train.load_checkpoint(ckpt_file)
        bert_prefix = 'transformer/bert'

        # Values are gathered in the same order as ``bert.weights`` so that
        # they line up positionally for ``set_weights`` below.
        loaded = []
        for layer_weight in bert.weights:
            stock_name = map_to_stock_variable_name(layer_weight.name,
                                                    bert_prefix)
            if not reader.has_tensor(stock_name):
                raise ValueError("No value for:[{}], i.e.:[{}] in:[{}]".format(
                    layer_weight.name, stock_name, ckpt_file))
            loaded.append(reader.get_tensor(stock_name))

        bert.set_weights(loaded)

        summary = "Done loading {} BERT weights from: {} into {} (prefix:{})".format(
            len(loaded), ckpt_file, bert, bert_prefix)
        print(summary)
コード例 #3
0
    def test___compare_weights(self):
        """Check that stock and Keras variables map onto each other.

        Builds the model, collects ``name -> shape`` tables for the stock
        checkpoint and for the model's trainable variables, then compares
        them in both directions (bert -> keras, keras -> bert). For each
        direction it prints and asserts the number of matched names,
        unmatched names and shape mismatches.
        """
        max_seq_len = 18
        model, _, _ = self.create_bert_model(max_seq_len)
        model.build(input_shape=[(None, max_seq_len), (None, max_seq_len)])

        stock_vars = {
            name: list(shape)
            for name, shape in tf.train.list_variables(self.bert_ckpt_file)
        }
        keras_vars = {
            var.name.split(":")[0]: var.shape.as_list()
            for var in model.trainable_variables
        }

        def classify(source_shapes, target_shapes, translate):
            # Sort every source name into matched / unmatched / shape-error,
            # based on the translated name's presence and shape in the target.
            matched, unmatched, mismatched = set(), set(), set()
            for source_name in source_shapes:
                target_name = translate(source_name)
                if target_name not in target_shapes:
                    unmatched.add(source_name)
                elif target_shapes[target_name] == source_shapes[source_name]:
                    matched.add(source_name)
                else:
                    mismatched.add(source_name)
            return matched, unmatched, mismatched

        def report(title, matched, unmatched, mismatched):
            # Print the per-direction summary followed by the unmatched names.
            print(title)
            print("     matched count:", len(matched))
            print("   unmatched count:", len(unmatched))
            print(" shape error count:", len(mismatched))
            print("unmatched:\n", "\n ".join(unmatched))

        # Direction 1: stock checkpoint names translated to Keras names.
        matched_vars, unmatched_vars, shape_errors = classify(
            stock_vars, keras_vars, map_from_stock_variale_name)
        report("bert -> keras:", matched_vars, unmatched_vars, shape_errors)
        self.assertEqual(197, len(matched_vars))
        self.assertEqual(9, len(unmatched_vars))
        self.assertEqual(0, len(shape_errors))

        # Direction 2: Keras names translated to stock checkpoint names.
        matched_vars, unmatched_vars, shape_errors = classify(
            keras_vars, stock_vars, map_to_stock_variable_name)
        report("keras -> bert:", matched_vars, unmatched_vars, shape_errors)
        self.assertEqual(197, len(matched_vars))
        self.assertEqual(0, len(unmatched_vars))
        self.assertEqual(0, len(shape_errors))