tokenizer = Tokenizer(dict_path, do_lower_case=True)  # build the tokenizer
model = build_transformer_model(
    config_path=config_path, checkpoint_path=checkpoint_path, with_mlm=True
)  # build the model and load the pretrained weights

sentences = []
init_sent = u'科学技术是第一生产力。'  # seed sentence, or None to start from all [MASK]
minlen, maxlen = 8, 32
steps = 10000
converged_steps = 1000
vocab_size = tokenizer._vocab_size

if init_sent is None:
    length = np.random.randint(minlen, maxlen + 1)
    tokens = ['[CLS]'] + ['[MASK]'] * length + ['[SEP]']
    token_ids = tokenizer.tokens_to_ids(tokens)
    segment_ids = [0] * len(token_ids)
else:
    token_ids, segment_ids = tokenizer.encode(init_sent)
    length = len(token_ids) - 2

for _ in tqdm(range(steps), desc='Sampling'):
    # Gibbs sampling step: randomly mask one token, then resample it from the MLM distribution.
    i = np.random.choice(length) + 1  # +1 skips the [CLS] token at position 0
    token_ids[i] = tokenizer._token_mask_id
    probas = model.predict(to_array([token_ids], [segment_ids]))[0, i]
    token = np.random.choice(vocab_size, p=probas)  # sample the new token from the MLM distribution
    token_ids[i] = token
    sentences.append(tokenizer.decode(token_ids))

print(u'Some of the random sampling results:')
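# Hedged sketch (not from the original): discard the first `converged_steps` burn-in
# samples and print a few of the remaining sentences.
for _ in range(10):
    print(np.random.choice(sentences[converged_steps:]))
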
# Tail of the label_en2zh mapping (original label name -> two-character label word):
'MOBA': '竞技',
'K歌': '唱歌',
'技术': '技术',
'减肥瘦身': '减肥',
'工作社交': '工作',
'团购': '团购',
'记账': '记账',
'女性': '女性',
'公务员': '公务',
'二手': '二手',
'美妆美业': '美妆',
'汽车咨询': '汽车',
'行程管理': '行程',
'免费WIFI': '免费',
'教辅': '教辅',
'成人': '两性'}
labels = [label_zh for label_en, label_zh in label_en2zh.items()]
labels_en = [label_en for label_en, label_zh in label_en2zh.items()]
label_ids = [tuple(tokenizer.tokens_to_ids(label)) for label in labels]
num_labels = len(label_en2zh)
maxlen = 256
batch_size = 16
num_per_val_file = 148
acc_list = []
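
# Hedged sketch (choose_label is a hypothetical helper, not part of the original script):
# in a PET-style setup each candidate label's token ids (label_ids) index into the MLM
# probabilities at the [MASK] positions; the label with the highest joint probability wins.
def choose_label(mlm_probas, mask_positions, candidates=label_ids):
    scores = []
    for ids in candidates:
        score = 1.0
        for pos, tid in zip(mask_positions, ids):
            score *= mlm_probas[pos, tid]  # probability of this label token at this mask slot
        scores.append(score)
    return int(np.argmax(scores))
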
def load_data(filename, set_type):  # load data
    D = []
    with open(filename, encoding='utf-8') as f:
        for i, l in enumerate(f):
            l = json.loads(l)
            label_en = l['label_des']
            if label_en not in labels_en:
                continue
            label_zh = label_en2zh[label_en]  # map the original label name to its Chinese label word
            if set_type == "train":
num_labeled = int(len(train_data) * train_frac)
# unlabeled_data = [(t, 2) for t, l in train_data[num_labeled:]]
train_data = train_data[:num_labeled]
print("1.num_labeled data used:", num_labeled, " ;train_data:",
      len(train_data))  # 168

# train_data = train_data + unlabeled_data

# build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
unused_length = 9
desc = [
    '[unused%s]' % i for i in range(1, unused_length)
]  # desc: ['[unused1]', '[unused2]', ..., '[unused8]']
desc_ids = [tokenizer.token_to_id(t) for t in desc]  # convert the prompt tokens to ids
label_list = ["并且", "所以", "但是"]
label_tokenid_list = [tuple(tokenizer.tokens_to_ids(x)) for x in label_list]
# pos_id = tokenizer.token_to_id(u'很')  # id of the positive label token (could also be e.g. '[unused9]'); default: u'很'
# neg_id = tokenizer.token_to_id(u'不')  # id of the negative label token (could also be e.g. '[unused10]'); default: u'不'
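
# Hedged sketch (build_prompt_input is a hypothetical helper, not part of the original
# script): a p-tuning style input prepends the learnable [unusedX] prompt ids and the
# [MASK] slots for the label word right after [CLS].
def build_prompt_input(text, label_len=2):
    token_ids, segment_ids = tokenizer.encode(text)
    mask_ids = [tokenizer._token_mask_id] * label_len  # one [MASK] slot per label character
    token_ids = token_ids[:1] + desc_ids + mask_ids + token_ids[1:]
    segment_ids = [0] * (len(desc_ids) + label_len) + segment_ids
    return token_ids, segment_ids
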


def random_masking(token_ids):
    """Randomly mask the input token ids.
    As in BERT, 15% of tokens are selected; unlike an auto-encoder, only these masked
    tokens are predicted, not the whole input. Of the selected tokens, 80% become
    [MASK], 10% keep the original token and 10% become a random token.
    """
    rands = np.random.random(len(token_ids))  # one uniform value in [0, 1) per token
    source, target = [], []
    for r, t in zip(rands, token_ids):
        if r < 0.15 * 0.8:  # 80% of the selected tokens: use [MASK]
            source.append(tokenizer._token_mask_id)
            target.append(t)
        elif r < 0.15 * 0.9:  # 10%: keep the original token
            source.append(t)
            target.append(t)
        elif r < 0.15:  # 10%: use a random token
            source.append(np.random.choice(tokenizer._vocab_size - 1) + 1)
            target.append(t)
        else:  # unselected tokens carry no MLM target
            source.append(t)
            target.append(0)
    return source, target
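
# Hedged usage sketch (the sample sentence is only illustrative): positions whose target
# is 0 contribute nothing to the MLM loss.
example_ids, _ = tokenizer.encode(u'科学技术是第一生产力。')
masked_ids, mlm_targets = random_masking(example_ids)

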
class AlbertNerModel(object):
    # model=None
    def __init__(self,
                 model_name: str,
                 path: str,
                 config_path: str,
                 checkpoint_path: str,
                 dict_path: str,
                 layers: int = 0,
                 unshared: bool = False):
        """
        Albert 初始化参数
        :param model_name: 模型名称,albert_base/albert_small/albert_tiny, 不推荐albertbase/albertsmall/alberttiny
        :param path: 权重路径
        :param config_path: 预训练模型配置文件
        :param checkpoint_path: 预训练模型文件
        :param dict_path: 预训练模型字典
        :param layers: 可选自定义层数,base最大12层,small最大6层,tiny最大4层
        :param unshared: 是否以Bert形式做层分解,默认为否
        """
        if tf.__version__ >= '2.0':
            print('暂不支持tensorflow 2.0 以上版本')
            raise
        self.weight_path = path
        self.__maxlen = 256
        self.__crf_lr_multiplier = 1000

        name = str(model_name).upper()
        if name in ('ALBERT_BASE', 'ALBERTBASE'):
            self.albert_layers = 12
        elif name in ('ALBERT_SMALL', 'ALBERTSMALL'):
            self.albert_layers = 6
        elif name in ('ALBERT_TINY', 'ALBERTTINY'):
            self.albert_layers = 4
        if layers > 0:
            self.albert_layers = layers
        self.pretrain_name = model_name
        self.config = config_path
        self.checkpoint = checkpoint_path
        self.dict = dict_path
        self.unshared = unshared

        self.tokenizer = Tokenizer(self.dict, do_lower_case=True)
        # label mapping
        labels = ['PER', 'LOC', 'ORG']
        id2label = dict(enumerate(labels))
        # label2id = {j: i for i, j in id2label.items()}
        self.__id2label = id2label
        self.__num_labels = len(labels) * 2 + 1  # B-X/I-X per label plus O
        assert self.config and self.checkpoint and self.dict
        # self.__crf= ConditionalRandomField(lr_multiplier=self.crf_lr_multiplier)
        self.__crf = None
        self._model = None

# region Setters for all configuration parameters, to make multi-model configuration and debugging easier; rebuild the model after changing them

    def set_layers(self, value):
        self.albert_layers = value

    def set_unshared(self, value):
        self.unshared = value

    def set_dict_path(self, path):
        self.dict = path
        self.tokenizer = Tokenizer(self.dict, do_lower_case=True)

    def set_checkpoint_path(self, path):
        self.checkpoint = path

    def set_config_path(self, path):
        self.config = path

    def set_weight_path(self, weight_path):
        self.weight_path = weight_path
# endregion

    @property
    def maxlen(self):
        return self.__maxlen

    @maxlen.setter
    def maxlen(self, value):
        self.__maxlen = value

    @property
    def crf_lr_multiplier(self):
        return self.__crf_lr_multiplier

    @crf_lr_multiplier.setter
    def crf_lr_multiplier(self, value):
        self.__crf_lr_multiplier = value

    @property
    def albert_model(self):
        return self._model

    @albert_model.setter
    def albert_model(self, model_path: str):
        from keras.models import load_model
        from keras.utils import CustomObjectScope
        # self._model = load_model(model_path,
        #                          custom_objects={'ConditionalRandomField': ConditionalRandomField,
        #                                          'sparse_loss': ConditionalRandomField.sparse_loss},
        #                          compile=False)  # either way of loading the custom objects works
        with CustomObjectScope({
                'ConditionalRandomField':
                ConditionalRandomField,
                'sparse_loss':
                ConditionalRandomField.sparse_loss
        }):
            self._model = load_model(model_path)
            # Important: on this machine and server the CRF layer is named as below; if the
            # name differs, change it to match the layer name in the model topology.
            self.__crf = self._model.get_layer('conditional_random_field_1')
            assert isinstance(self.__crf, ConditionalRandomField)

    @albert_model.deleter
    def albert_model(self):
        K.clear_session()
        del self._model

    def build_albert_model(self):
        del self.albert_model
        file_name = f'albert_{self.pretrain_name}_pretrain.h5'  # to make loading easier, the pretrained model was saved in advance as an .h5 file
        if os.path.exists(file_name):
            pretrain_model = load_model(file_name, compile=False)
        else:
            pretrain_model = build_transformer_model(
                config_path=self.config,
                checkpoint_path=self.checkpoint,
                model='albert_unshared' if self.unshared else 'albert',
                return_keras_model=True)

        if not self.unshared:
            output_layer = 'Transformer-FeedForward-Norm'
            output = pretrain_model.get_layer(output_layer).get_output_at(
                self.albert_layers - 1)
        else:
            output_layer = 'Transformer-%s-FeedForward-Norm' % (
                self.albert_layers - 1)
            output = pretrain_model.get_layer(output_layer).output
        output = Dense(self.__num_labels)(output)
        self.__crf = ConditionalRandomField(
            lr_multiplier=self.crf_lr_multiplier)
        output = self.__crf(output)
        model = Model(pretrain_model.input, output)
        model.load_weights(self.weight_path)
        self._model = model
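
    # Hedged addition (compile_for_training is not part of the original class): a typical
    # way to compile this CRF-headed model for training, using the CRF layer's own sparse
    # loss and accuracy from bert4keras.
    def compile_for_training(self, learning_rate=1e-5):
        from keras.optimizers import Adam
        self._model.compile(loss=self.__crf.sparse_loss,
                            optimizer=Adam(learning_rate),
                            metrics=[self.__crf.sparse_accuracy])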

    def viterbi_decode(self, nodes, trans, starts=[0], ends=[0]):
        """Viterbi算法求最优路径
        """
        num_labels = len(trans)
        non_starts = []
        non_ends = []
        if starts is not None:
            for i in range(num_labels):
                if i not in starts:
                    non_starts.append(i)
        if ends is not None:
            for i in range(num_labels):
                if i not in ends:
                    non_ends.append(i)
        # preprocessing: forbid invalid start/end labels
        nodes[0, non_starts] -= np.inf
        nodes[-1, non_ends] -= np.inf
        labels = np.arange(num_labels).reshape((1, -1))
        scores = nodes[0].reshape((-1, 1))
        # scores[1:] -= np.inf  # the first label must be 0
        paths = labels
        for l in range(1, len(nodes)):
            M = scores + trans + nodes[l].reshape((1, -1))
            idxs = M.argmax(0)
            scores = M.max(0).reshape((-1, 1))
            paths = np.concatenate([paths[:, idxs], labels], 0)
        return paths[:, scores[:, 0].argmax()]  # optimal path

    def recognize(self, text):
        """
        # 识别实体
        :param text:
        :return: entities list
        """
        tokens = self.tokenizer.tokenize(text)
        while len(tokens) > 512:  # truncate to the 512-token limit, keeping the final [SEP]
            tokens.pop(-2)
        try:
            mapping = self.tokenizer.rematch(text, tokens)
            token_ids = self.tokenizer.tokens_to_ids(tokens)
            segment_ids = [0] * len(token_ids)
            nodes = self._model.predict([[token_ids], [segment_ids]])[0]
            # print('nodes:',nodes)
            _trans = K.eval(self.__crf.trans)
            labels = self.viterbi_decode(nodes, trans=_trans)
            entities, starting = [], False
            for i, label in enumerate(labels):
                if label > 0:
                    if label % 2 == 1:
                        starting = True
                        entities.append([[i],
                                         self.__id2label[(label - 1) // 2]])
                    elif starting:
                        entities[-1][0].append(i)
                    else:
                        starting = False
                else:
                    starting = False

            return [(text[mapping[w[0]][0]:mapping[w[-1]][-1] + 1], l)
                    for w, l in entities]
        except Exception:
            import traceback
            traceback.print_exc()
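

# Hedged usage sketch (all file paths below are placeholders, not from the original):
if __name__ == '__main__':
    ner = AlbertNerModel(model_name='albert_tiny',
                         path='best_model.weights',
                         config_path='albert_config_tiny.json',
                         checkpoint_path='albert_model.ckpt',
                         dict_path='vocab.txt')
    ner.build_albert_model()
    # example sentence: "Li Ming studies at Fudan University in Shanghai."
    print(ner.recognize(u'李明在上海的复旦大学读书。'))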