def get_generator(sentence_generator: Generator[SentenceBatch, None, None], is_pretrain: bool):
    # NOTE: max_len, uses_attn_mask, is_causal, task_nodes and all_tasks are
    # closed over from the enclosing training function.
    for i, batch in enumerate(sentence_generator):
        batch_size, seq_len = batch.tokens.shape
        x = [batch.tokens, batch.segments, generate_pos_ids(batch_size, max_len)]
        y = []
        if uses_attn_mask:
            x.append(create_attention_mask(batch.padding_mask, is_causal))
        for task_name in task_nodes.keys():
            if is_pretrain:
                cond = all_tasks[task_name].weight_scheduler.active_in_pretrain
            else:
                cond = all_tasks[task_name].weight_scheduler.active_in_finetune
            if cond:
                if task_name in batch.sentence_classification:
                    task_data_batch = batch.sentence_classification[task_name]
                else:
                    task_data_batch = batch.token_classification[task_name]
                x.append(task_data_batch.target)
                if all_tasks[task_name].is_token_level:
                    x.append(task_data_batch.target_mask)
                else:
                    # sentence-level tasks get flat indices into the (batch_size * seq_len) axis
                    x.append((task_data_batch.target_mask + np.arange(batch_size) * seq_len).astype(np.int32))
                x.append(np.repeat(
                    np.array([all_tasks[task_name].weight_scheduler.get(is_pretrain, i)]),
                    batch_size, 0))
                y.append(np.repeat(np.array([0.0]), batch_size, 0))
        yield x, y
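# Hedged illustration (not from the original source): the flat-index trick used
# above for sentence-level tasks. Per-row token positions are shifted into
# indices over the flattened (batch_size * seq_len) axis, so a single gather on
# the reshaped sequence output picks one vector per sentence.
import numpy as np

batch_size, seq_len, d = 3, 4, 2
seq_out = np.arange(batch_size * seq_len * d).reshape(batch_size, seq_len, d)
target_mask = np.array([0, 2, 3])  # hypothetical per-row token positions
flat_idx = (target_mask + np.arange(batch_size) * seq_len).astype(np.int32)
picked = seq_out.reshape(-1, d)[flat_idx]  # one vector per sentence
assert picked.shape == (batch_size, d)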
def test_same_result(self):
    base_location = './google_bert/downloads/multilingual_L-12_H-768_A-12/'
    bert_config = BertConfig.from_json_file(base_location + 'bert_config.json')
    init_checkpoint = base_location + 'bert_model.ckpt'

    def model_fn_builder(bert_config, init_checkpoint):
        """Returns `model_fn` closure for TPUEstimator."""

        def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
            """The `model_fn` for TPUEstimator."""
            unique_ids = features["unique_ids"]
            input_ids = features["input_ids"]
            input_mask = features["input_mask"]
            input_type_ids = features["input_type_ids"]
            model = BertModel(config=bert_config, is_training=False, input_ids=input_ids,
                              input_mask=input_mask, token_type_ids=input_type_ids,
                              use_one_hot_embeddings=False)
            if mode != tf.estimator.ModeKeys.PREDICT:
                raise ValueError("Only PREDICT modes are supported: %s" % (mode))
            tvars = tf.trainable_variables()
            scaffold_fn = None
            (assignment_map, _) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            predictions = {
                "unique_id": unique_ids,
                "seq_out": model.get_sequence_output(),
            }
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode, predictions=predictions,
                                                          scaffold_fn=scaffold_fn)
            return output_spec

        return model_fn

    batch_size = 8
    seq_len = 5
    xmb = np.random.randint(106, bert_config.vocab_size - 106, (batch_size, seq_len))
    xmb2 = np.random.randint(0, 2, (batch_size, seq_len), dtype=np.int32)
    xmb3 = np.random.randint(0, 2, (batch_size, seq_len), dtype=np.int32)

    def input_fn(params):
        d = tf.data.Dataset.from_tensor_slices({
            "unique_ids": tf.constant([0, 1, 2], shape=[batch_size], dtype=tf.int32),
            "input_ids": tf.constant(xmb, shape=[batch_size, seq_len], dtype=tf.int32),
            "input_mask": tf.constant(xmb2, shape=[batch_size, seq_len], dtype=tf.int32),
            "input_type_ids": tf.constant(xmb3, shape=[batch_size, seq_len], dtype=tf.int32),
        })
        d = d.batch(batch_size=batch_size, drop_remainder=False)
        return d

    model_fn = model_fn_builder(bert_config=bert_config, init_checkpoint=init_checkpoint)
    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        master=None,
        tpu_config=tf.contrib.tpu.TPUConfig(num_shards=8,
                                            per_host_input_for_training=is_per_host))
    estimator = tf.contrib.tpu.TPUEstimator(use_tpu=False, model_fn=model_fn, config=run_config,
                                            predict_batch_size=batch_size)
    tf_result = [r for r in estimator.predict(input_fn)]

    import tensorflow.keras.backend as K
    K.set_learning_phase(0)
    my_model = load_google_bert(base_location, max_len=seq_len)
    from data.dataset import create_attention_mask, generate_pos_ids
    pos = generate_pos_ids(batch_size, seq_len)
    k_mask = create_attention_mask(xmb2, False, None, None, True)
    bert_encoder = BERTTextEncoder(base_location + 'vocab.txt')
    # shift raw BERT ids into the standardized id space used by the Keras port
    for b in range(len(xmb)):
        xmb[b] = np.array(bert_encoder.standardize_ids(xmb[b].tolist()))
    k_output = my_model.predict([xmb, xmb3, pos, k_mask])
    max_max = 0
    for i in range(batch_size):
        if k_mask[i].mean() != 0:  # TODO (when mask == full zero, keras_res != tf_res)
            new_max = np.abs(k_output[i] - tf_result[i]['seq_out']).max()
            if new_max > max_max:
                max_max = new_max
    assert max_max < 5e-5, max_max  # TODO reduce the error (I think it's because of the LayerNorm)
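# A minimal sketch, assuming create_attention_mask builds a broadcastable
# multiplicative mask of shape (batch, 1, seq, seq) from the per-token padding
# mask; the real signature and semantics in data.dataset may differ, this is
# only an illustration of the idea the test relies on.
import numpy as np

def padding_attention_mask(pad_mask: np.ndarray) -> np.ndarray:
    # pad_mask: (batch, seq) with 1 for real tokens, 0 for padding
    batch, seq = pad_mask.shape
    mask = np.ones((batch, 1, seq, seq), dtype=np.float32)
    mask *= pad_mask[:, None, None, :]  # keys at padded positions are masked out
    return mask

demo = padding_attention_mask(np.array([[1, 1, 0]]))
assert demo.shape == (1, 1, 3, 3) and demo[0, 0, 0, 2] == 0.0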
def re_data_to_input_output(model_data, bert_encoder, max_len):
    # Build model inputs and targets.
    input_ids = []
    input_masks = []
    input_entity_masks = []
    output_re_ids = []
    output_ne_ids = []
    for item in tqdm(model_data):
        text = item['text']
        entities = item['entities']
        relations = item['relations']
        tokens, char_to_word_offset = bert_encoder.tokenize(text)
        # Map character-level entity spans to token-level spans.
        maped_entities = []
        for start_pos, end_pos, _, entity_type in entities:
            map_start_pos = start_pos
            while char_to_word_offset[map_start_pos] is None:
                map_start_pos += 1
            map_start_pos = char_to_word_offset[map_start_pos]
            map_end_pos = end_pos - 1
            while char_to_word_offset[map_end_pos] is None:
                map_end_pos += 1
            map_end_pos = char_to_word_offset[map_end_pos] + 1
            maped_entities.append((map_start_pos, map_end_pos, entity_type))
        relation_map = {(start_entity_idx, end_entity_idx): relation_type
                        for start_entity_idx, end_entity_idx, relation_type in relations}
        # BIOES tags for the NER head.
        netype_list = ['O' for _ in range(len(tokens))]
        for start_pos, end_pos, entity_type in maped_entities:
            if end_pos - start_pos == 1:
                netype_list[start_pos] = "S-" + entity_type
            else:
                netype_list[start_pos] = "B-" + entity_type
                netype_list[end_pos - 1] = "E-" + entity_type
                for tmp_pos in range(start_pos + 1, end_pos - 1):
                    netype_list[tmp_pos] = "I-" + entity_type
        # One training row per ordered entity pair.
        for start_entity_idx in range(len(entities)):
            for end_entity_idx in range(start_entity_idx + 1, len(entities)):
                remask = [0 for _ in range(len(tokens))]
                for tmp_pos in range(maped_entities[start_entity_idx][0], maped_entities[start_entity_idx][1]):
                    remask[tmp_pos] = 1
                for tmp_pos in range(maped_entities[end_entity_idx][0], maped_entities[end_entity_idx][1]):
                    remask[tmp_pos] = 1
                if (start_entity_idx, end_entity_idx) in relation_map:
                    retype = relation_map[(start_entity_idx, end_entity_idx)]
                else:
                    retype = 'no_relation'
                input_tokens = ['[CLS]'] + tokens + ['[SEP]']
                input_id = bert_encoder.standardize_ids(bert_encoder.convert_tokens_to_ids(input_tokens))
                input_mask = [1] * len(input_tokens)
                input_entity_mask = [1] + remask + [0]
                output_re_id = retype2id[retype]
                output_netype_list = ['null'] + netype_list + ['null']
                output_ne_id = [netype2id[tag] for tag in output_netype_list]
                # Zero-pad everything to max_len.
                input_id += [0] * (max_len - len(input_id))
                input_mask += [0] * (max_len - len(input_mask))
                input_entity_mask += [0] * (max_len - len(input_entity_mask))
                output_ne_id += [0] * (max_len - len(output_ne_id))
                input_ids.append(input_id)
                input_masks.append(input_mask)
                input_entity_masks.append(input_entity_mask)
                output_re_ids.append(output_re_id)
                output_ne_ids.append(output_ne_id)
    input_ids = np.array(input_ids, dtype=np.int32)
    input_masks = np.array(input_masks, dtype=np.int32)
    input_entity_masks = np.array(input_entity_masks, dtype=np.int32)
    output_re_ids = np.array(output_re_ids, dtype=np.int32)
    output_ne_ids = np.array(output_ne_ids, dtype=np.int32)
    pos = generate_pos_ids(len(input_ids), max_len)
    input_type_ids = np.zeros((len(input_ids), max_len), dtype=np.int32)
    return [input_ids, input_type_ids, pos, input_masks, input_entity_masks], [output_re_ids, output_ne_ids]
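# A small, self-contained illustration (not from the original source) of the
# BIOES tagging convention produced above: single-token entities get S-, longer
# spans get B- ... I- ... E-, and everything else stays O.
tokens = ['Bill', 'Gates', 'founded', 'Microsoft']
spans = [(0, 2, 'PER'), (3, 4, 'ORG')]  # token-level (start, end, type)
tags = ['O'] * len(tokens)
for start, end, etype in spans:
    if end - start == 1:
        tags[start] = 'S-' + etype
    else:
        tags[start] = 'B-' + etype
        tags[end - 1] = 'E-' + etype
        for p in range(start + 1, end - 1):
            tags[p] = 'I-' + etype
assert tags == ['B-PER', 'E-PER', 'O', 'S-ORG']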
def eval_model(model_data, model, max_len):
    # Keep the gold answers for scoring.
    mapped_ner_true = []
    mapped_re_true = []
    # Stage 1: build the entity-prediction inputs.
    input_ids = []
    input_masks = []
    input_entity_masks = []
    for item in model_data:
        text = item['text']
        entities = item['entities']
        relations = item['relations']
        tokens, char_to_word_offset = bert_encoder.tokenize(text)
        maped_entities = []
        for start_pos, end_pos, _, entity_type in entities:
            map_start_pos = start_pos
            while char_to_word_offset[map_start_pos] is None:
                map_start_pos += 1
            map_start_pos = char_to_word_offset[map_start_pos]
            map_end_pos = end_pos - 1
            while char_to_word_offset[map_end_pos] is None:
                map_end_pos += 1
            map_end_pos = char_to_word_offset[map_end_pos] + 1
            maped_entities.append((map_start_pos, map_end_pos, entity_type))
        mapped_ner_true.append(maped_entities)
        mapped_re_true.append([
            tuple(list(maped_entities[start_entity_idx]) +
                  list(maped_entities[end_entity_idx]) + [relation_type])
            for start_entity_idx, end_entity_idx, relation_type in relations
        ])
        input_tokens = ['[CLS]'] + tokens + ['[SEP]']
        input_id = bert_encoder.standardize_ids(bert_encoder.convert_tokens_to_ids(input_tokens))
        input_mask = [1] * len(input_tokens)
        input_entity_mask = [0] * len(input_id)
        input_id += [0] * (max_len - len(input_id))
        input_mask += [0] * (max_len - len(input_mask))
        input_entity_mask += [0] * (max_len - len(input_entity_mask))
        input_ids.append(input_id)
        input_masks.append(input_mask)
        input_entity_masks.append(input_entity_mask)
    input_ids = np.array(input_ids, dtype=np.int32)
    input_type_ids = np.zeros((len(model_data), max_len), dtype=np.int32)
    input_masks = np.array(input_masks, dtype=np.int32)
    input_entity_masks = np.array(input_entity_masks, dtype=np.int32)
    pos = generate_pos_ids(len(model_data), max_len)
    x = [input_ids, input_type_ids, pos, input_masks, input_entity_masks]
    # Predict entities.
    y_predict = model.predict(x, batch_size=128, verbose=1)
    ne_predict = []
    for row in np.argmax(y_predict[1], axis=-1):
        BIOES_list = []
        # Drop the leading [CLS] position.
        for tag_id in row[1:]:
            if tag_id == 0:
                BIOES_list.append("O")
            else:
                BIOES_list.append(neid2type[tag_id])
        ne_predict.append([(start, end + 1, entity_type)
                           for entity_type, start, end in get_entities(BIOES_list)])
    # Score NER.
    ne_f1 = eval_prf(ne_predict, mapped_ner_true)
    # Stage 2: build the relation-prediction inputs from the predicted entities.
    re_predict_map = {}
    input_ids = []
    input_masks = []
    input_entity_masks = []
    for data_idx, item in enumerate(model_data):
        text = item['text']
        tokens, char_to_word_offset = bert_encoder.tokenize(text)
        maped_entities = ne_predict[data_idx]
        for start_entity_idx in range(len(maped_entities)):
            for end_entity_idx in range(start_entity_idx + 1, len(maped_entities)):
                input_tokens = ['[CLS]'] + tokens + ['[SEP]']
                input_id = bert_encoder.standardize_ids(bert_encoder.convert_tokens_to_ids(input_tokens))
                input_mask = [1] * len(input_tokens)
                # Mark [CLS] plus both entity spans (shifted by 1 for [CLS]).
                input_entity_mask = [0] * len(input_id)
                input_entity_mask[0] = 1
                for tmp_pos in range(maped_entities[start_entity_idx][0], maped_entities[start_entity_idx][1]):
                    input_entity_mask[tmp_pos + 1] = 1
                for tmp_pos in range(maped_entities[end_entity_idx][0], maped_entities[end_entity_idx][1]):
                    input_entity_mask[tmp_pos + 1] = 1
                input_id += [0] * (max_len - len(input_id))
                input_mask += [0] * (max_len - len(input_mask))
                input_entity_mask += [0] * (max_len - len(input_entity_mask))
                input_ids.append(input_id)
                input_masks.append(input_mask)
                input_entity_masks.append(input_entity_mask)
                re_predict_map[(data_idx, len(input_ids) - 1)] = \
                    list(maped_entities[start_entity_idx]) + list(maped_entities[end_entity_idx])
    input_ids = np.array(input_ids, dtype=np.int32)
    input_type_ids = np.zeros((len(input_ids), max_len), dtype=np.int32)
    input_masks = np.array(input_masks, dtype=np.int32)
    input_entity_masks = np.array(input_entity_masks, dtype=np.int32)
    pos = generate_pos_ids(len(input_ids), max_len)
    x = [input_ids, input_type_ids, pos, input_masks, input_entity_masks]
    # Predict relations.
    y_predict = model.predict(x, batch_size=128, verbose=1)
    # Assemble per-example relation predictions; pre-allocate one list per
    # example so trailing examples with fewer than two predicted entities
    # still line up with mapped_re_true.
    re_predict = [[] for _ in range(len(model_data))]
    for (data_idx, relation_idx), value in re_predict_map.items():
        re = np.argmax(y_predict[0][relation_idx], axis=-1)
        if re >= 2:  # skip the 'null' and 'no_relation' labels
            re_predict[data_idx].append(tuple(value + [reid2type[re]]))
    re_f1 = eval_prf(re_predict, mapped_re_true)
    return ne_f1, re_f1
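# eval_prf is not defined in this section; a plausible micro-F1 over predicted
# vs. gold tuples might look like the sketch below (an assumption, not the
# original implementation).
def eval_prf(pred_lists, true_lists):
    tp = pred_total = true_total = 0
    for pred, true in zip(pred_lists, true_lists):
        pred_set, true_set = set(pred), set(true)
        tp += len(pred_set & true_set)
        pred_total += len(pred_set)
        true_total += len(true_set)
    precision = tp / pred_total if pred_total else 0.0
    recall = tp / true_total if true_total else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0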