import random

import numpy as np
import torch

# `sequence_padding`, `search`, `select_padding`, `Example` and the label maps
# (BAIDU_RELATION, DRUG_RELATION, BAIDU_ENTITY, BAIDU_BIES) are project-level
# helpers imported from the repo's utils/config modules. Each `collate` below
# is a closure defined inside its dataset-reader class, so `self` refers to
# the enclosing reader.


def collate(examples):
    p_ids, examples = zip(*examples)
    p_ids = torch.tensor(p_ids, dtype=torch.long)
    batch_token_ids, batch_segment_ids = [], []
    batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
    for example in examples:
        spoes = example.spoes
        # strip [CLS]/[SEP] so positions align with the raw tokens
        token_ids = self.tokenizer.encode(example.bert_tokens)[1:-1]
        segment_ids = len(token_ids) * [0]
        if self.is_train:
            if spoes:
                # subject labels: one start/end indicator pair per position
                token_type_ids = np.zeros(len(token_ids), dtype=np.int64)
                subject_labels = np.zeros((len(token_ids), 2), dtype=np.float32)
                for s in spoes:
                    subject_labels[s[0], 0] = 1
                    subject_labels[s[1], 1] = 1
                # randomly pick one subject for this example
                subject_ids = random.choice(list(spoes.keys()))
                # object labels conditioned on the chosen subject
                object_labels = np.zeros(
                    (len(token_ids), len(self.spo_config), 2), dtype=np.float32)
                for o in spoes.get(subject_ids, []):
                    object_labels[o[0], o[2], 0] = 1
                    object_labels[o[1], o[2], 1] = 1
                batch_token_ids.append(token_ids)
                batch_token_type_ids.append(token_type_ids)
                batch_segment_ids.append(segment_ids)
                batch_subject_labels.append(subject_labels)
                batch_subject_ids.append(subject_ids)
                batch_object_labels.append(object_labels)
        else:
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
    batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
    batch_segment_ids = sequence_padding(batch_segment_ids, is_float=False)
    if not self.is_train:
        return p_ids, batch_token_ids, batch_segment_ids
    else:
        batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
        batch_subject_ids = torch.tensor(batch_subject_ids)
        batch_subject_labels = sequence_padding(batch_subject_labels,
                                                padding=np.zeros(2), is_float=True)
        batch_object_labels = sequence_padding(
            batch_object_labels,
            padding=np.zeros((len(self.spo_config), 2)), is_float=True)
        return batch_token_ids, batch_segment_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
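
# ---------------------------------------------------------------------------
# `sequence_padding` above comes from the repo's utils and is not shown in
# this section. The sketch below is a hypothetical stand-in inferred from the
# call sites (pad to the batch max length, optional per-step `padding`
# element, long vs. float output); the real helper may differ.
def sequence_padding_sketch(seqs, padding=0, is_float=False):
    """Pad variable-length sequences to the batch max length and stack them."""
    max_len = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        s = np.asarray(s)
        n_pad = max_len - len(s)
        if n_pad > 0:
            # scalar padding fills element-wise; array padding is one full step
            pad_elem = np.asarray(padding) if np.ndim(padding) else np.full(s.shape[1:], padding)
            s = np.concatenate([s, np.stack([pad_elem] * n_pad)], axis=0)
        padded.append(s)
    return torch.tensor(np.stack(padded),
                        dtype=torch.float32 if is_float else torch.long)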

def collate(examples):
    p_ids, examples = zip(*examples)
    p_ids = torch.tensor(p_ids, dtype=torch.long)
    batch_char_ids, batch_word_ids = [], []
    batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
    for example in examples:
        # TODO: max_len handling
        char_ids = [self.char2idx.get(char, 1) for char in example.context]
        # repeat each word id once per character so char and word ids align
        word_ids = [self.word2idx.get(word, 0) for word in example.text_word for _ in word]
        assert len(char_ids) == len(word_ids), (
            f'char/word length mismatch for context: {example.context!r}')
        char_ids = char_ids[:self.max_len]
        word_ids = word_ids[:self.max_len]
        if self.is_train:
            subject_labels = np.zeros((len(char_ids), 2), dtype=np.float32)
            token_type_ids = np.zeros(len(char_ids), dtype=np.int64)
            object_labels = np.zeros((len(char_ids), len(DRUG_RELATION), 2), dtype=np.float32)
            for s in example.ent_list:
                subject_labels[s[0], 0] = 1
                subject_labels[s[1] - 1, 1] = 1  # ent_list end indices appear exclusive
            if example.sub_entity_list:
                subject_ids = example.sub_entity_list
                token_type_ids[subject_ids[0]:subject_ids[1] + 1] = 1
                # object labels conditioned on the gold subject
                for o in example.spoes.get(subject_ids, []):
                    object_labels[o[0], o[2], 0] = 1
                    object_labels[o[1], o[2], 1] = 1
                # append inside the guard so `subject_ids` is always defined
                batch_char_ids.append(char_ids)
                batch_word_ids.append(word_ids)
                batch_token_type_ids.append(token_type_ids)
                batch_subject_labels.append(subject_labels)
                batch_subject_ids.append(subject_ids)
                batch_object_labels.append(object_labels)
        else:
            batch_char_ids.append(char_ids)
            batch_word_ids.append(word_ids)
    batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
    batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
    if not self.is_train:
        return p_ids, batch_char_ids, batch_word_ids
    else:
        batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
        batch_subject_ids = torch.tensor(batch_subject_ids)
        batch_subject_labels = sequence_padding(batch_subject_labels,
                                                padding=np.zeros(2), is_float=True)
        batch_object_labels = sequence_padding(batch_object_labels,
                                               padding=np.zeros((len(DRUG_RELATION), 2)),
                                               is_float=True)
        return batch_char_ids, batch_word_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
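
# Hedged usage sketch: each `collate` in this section is built inside its
# dataset reader and handed to a torch DataLoader; the dataset yields
# (p_id, example) pairs. Names below are placeholders, not repo API:
#
#   loader = torch.utils.data.DataLoader(
#       train_dataset, batch_size=32, shuffle=True, collate_fn=collate)
#   for batch in loader:
#       char_ids, word_ids, token_type_ids, subject_ids, subject_labels, object_labels = batch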

def text2id(args, char2idx, raw_text: list):
    p_ids = []
    batch_char_ids = []
    examples = []
    for index, text_ in enumerate(raw_text):
        p_ids.append(index)
        char_ids = [char2idx.get(char, 1) for char in text_]
        batch_char_ids.append(char_ids)
        examples.append(Example(p_id=index, context=text_))
    p_ids = torch.tensor(p_ids, dtype=torch.long)
    text_ids = sequence_padding(batch_char_ids, is_float=False)
    return examples, (p_ids, text_ids)
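
# Hedged usage sketch for `text2id` at inference time (the model call and its
# keyword argument are hypothetical, not confirmed by this section):
#
#   examples, (p_ids, text_ids) = text2id(args, char2idx, ['text one', 'text two'])
#   with torch.no_grad():
#       predictions = model(char_ids=text_ids)  # hypothetical forward signature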

def collate(examples):
    p_ids, examples = zip(*examples)
    p_ids = torch.tensor(p_ids, dtype=torch.long)
    batch_token_ids, batch_segment_ids = [], []
    batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
    for example in examples:
        # TODO: max_len handling
        token_ids, segment_ids = self.tokenizer.encode(
            example.context, max_length=self.max_len)
        example.bert_tokens = self.tokenizer.tokenize(example.context)
        example.token_ids = token_ids
        if self.is_train:
            spoes = {}
            for s, p, o in example.gold_answer:
                s = self.tokenizer.encode(s)[0][1:-1]  # drop [CLS]/[SEP]
                p = BAIDU_RELATION[p]
                o = self.tokenizer.encode(o)[0][1:-1]
                s_idx = search(s, token_ids)
                o_idx = search(o, token_ids)
                if s_idx != -1 and o_idx != -1:
                    s = (s_idx, s_idx + len(s) - 1)
                    o = (o_idx, o_idx + len(o) - 1, p)
                    if s not in spoes:
                        spoes[s] = []
                    spoes[s].append(o)
            if spoes:
                # subject labels
                token_type_ids = np.zeros(len(token_ids), dtype=np.int64)
                subject_labels = np.zeros((len(token_ids), 2), dtype=np.float32)
                for s in spoes:
                    subject_labels[s[0], 0] = 1
                    subject_labels[s[1], 1] = 1
                # randomly sample one subject span (a start, then a compatible end)
                start, end = np.array(list(spoes.keys())).T
                start = np.random.choice(start)
                end = np.random.choice(end[end >= start])
                token_type_ids[start:end + 1] = 1
                subject_ids = (start, end)
                # object labels conditioned on the sampled subject
                object_labels = np.zeros(
                    (len(token_ids), len(BAIDU_RELATION), 2), dtype=np.float32)
                for o in spoes.get(subject_ids, []):
                    object_labels[o[0], o[2], 0] = 1
                    object_labels[o[1], o[2], 1] = 1
                batch_token_ids.append(token_ids)
                batch_token_type_ids.append(token_type_ids)
                batch_segment_ids.append(segment_ids)
                batch_subject_labels.append(subject_labels)
                batch_subject_ids.append(subject_ids)
                batch_object_labels.append(object_labels)
        else:
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
    batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
    batch_segment_ids = sequence_padding(batch_segment_ids, is_float=False)
    if not self.is_train:
        return p_ids, batch_token_ids, batch_segment_ids
    else:
        batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
        batch_subject_ids = torch.tensor(batch_subject_ids)
        batch_subject_labels = sequence_padding(batch_subject_labels,
                                                padding=np.zeros(2), is_float=True)
        batch_object_labels = sequence_padding(
            batch_object_labels,
            padding=np.zeros((len(BAIDU_RELATION), 2)), is_float=True)
        return batch_token_ids, batch_segment_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
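
# `search` above is another repo utility. A minimal sketch consistent with its
# usage here (first index where `pattern` occurs as a contiguous sublist of
# `sequence`, or -1 when absent); the repo's version may differ in detail:
def search_sketch(pattern, sequence):
    """Return the first start index of `pattern` inside `sequence`, else -1."""
    n = len(pattern)
    for i in range(len(sequence) - n + 1):
        if sequence[i:i + n] == pattern:
            return i
    return -1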

def collate(examples):
    p_ids, examples = zip(*examples)
    p_ids = torch.tensor(p_ids, dtype=torch.long)
    batch_char_ids, batch_word_ids = [], []
    batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
    for example in examples:
        # TODO: max_len handling
        char_ids = [self.char2idx.get(char, 1) for char in example.context]
        # repeat each word id once per character so char and word ids align
        word_ids = [self.word2idx.get(word, 0) for word in example.text_word for _ in word]
        assert len(char_ids) == len(word_ids), (
            f'char/word length mismatch for context: {example.context!r}')
        char_ids = char_ids[:self.max_len]
        word_ids = word_ids[:self.max_len]
        if self.is_train:
            spoes = {}
            for s, p, o in example.gold_answer:
                s = [self.char2idx.get(s_, 1) for s_ in s]
                p = BAIDU_RELATION[p]
                o = [self.char2idx.get(o_, 1) for o_ in o]
                s_idx = search(s, char_ids)
                o_idx = search(o, char_ids)
                if s_idx != -1 and o_idx != -1:
                    s = (s_idx, s_idx + len(s) - 1)
                    o = (o_idx, o_idx + len(o) - 1, p)
                    if s not in spoes:
                        spoes[s] = []
                    spoes[s].append(o)
            if spoes:
                # subject labels
                token_type_ids = np.zeros(len(char_ids), dtype=np.int64)
                subject_labels = np.zeros((len(char_ids), 2), dtype=np.float32)
                for s in spoes:
                    subject_labels[s[0], 0] = 1
                    subject_labels[s[1], 1] = 1
                # randomly sample one subject span (a start, then a compatible end)
                start, end = np.array(list(spoes.keys())).T
                start = np.random.choice(start)
                end = np.random.choice(end[end >= start])
                token_type_ids[start:end + 1] = 1
                subject_ids = (start, end)
                # object labels conditioned on the sampled subject
                object_labels = np.zeros(
                    (len(char_ids), len(BAIDU_RELATION), 2), dtype=np.float32)
                for o in spoes.get(subject_ids, []):
                    object_labels[o[0], o[2], 0] = 1
                    object_labels[o[1], o[2], 1] = 1
                batch_char_ids.append(char_ids)
                batch_word_ids.append(word_ids)
                batch_token_type_ids.append(token_type_ids)
                batch_subject_labels.append(subject_labels)
                batch_subject_ids.append(subject_ids)
                batch_object_labels.append(object_labels)
        else:
            batch_char_ids.append(char_ids)
            batch_word_ids.append(word_ids)
    batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
    batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
    if not self.is_train:
        return p_ids, batch_char_ids, batch_word_ids
    else:
        batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
        batch_subject_ids = torch.tensor(batch_subject_ids)
        batch_subject_labels = sequence_padding(batch_subject_labels,
                                                padding=np.zeros(2), is_float=True)
        batch_object_labels = sequence_padding(
            batch_object_labels,
            padding=np.zeros((len(BAIDU_RELATION), 2)), is_float=True)
        return batch_char_ids, batch_word_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels

def collate(examples):
    p_ids, examples = zip(*examples)
    p_ids = torch.tensor(p_ids, dtype=torch.long)
    batch_char_ids, batch_word_ids = [], []
    batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
    for example in examples:
        # TODO: max_len handling
        char_ids = [self.char2idx.get(char, 1) for char in example.context]
        # repeat each word id once per character so char and word ids align
        word_ids = [self.word2idx.get(word, 0) for word in example.text_word for _ in word]
        assert len(char_ids) == len(word_ids), (
            f'char/word length mismatch for context: {example.context!r}')
        char_ids = char_ids[:self.max_len]
        word_ids = word_ids[:self.max_len]
        if self.is_train:
            spoes = {}
            for s, p, o in example.gold_answer:
                s = [self.char2idx.get(s_, 1) for s_ in s]
                # keep `p` as the raw predicate string: BAIDU_BIES is keyed by
                # 'B-<predicate>' / 'I-<predicate>' below
                o = [self.char2idx.get(o_, 1) for o_ in o]
                s_idx = search(s, char_ids)
                o_idx = search(o, char_ids)
                if s_idx != -1 and o_idx != -1:
                    s = (s_idx, s_idx + len(s) - 1)
                    o = (o_idx, o_idx + len(o) - 1, p)
                    if s not in spoes:
                        spoes[s] = []
                    spoes[s].append(o)
            if spoes:
                # subject labels as a B/I tag sequence
                token_type_ids = np.zeros(len(char_ids), dtype=np.int64)
                subject_labels = np.zeros(len(char_ids), dtype=np.int64)
                for s in spoes:
                    subject_labels[s[0]] = BAIDU_ENTITY['B']
                    for index in range(s[0] + 1, s[1] + 1):
                        subject_labels[index] = BAIDU_ENTITY['I']
                # randomly pick one subject
                subject_ids = random.choice(list(spoes.keys()))
                token_type_ids[subject_ids[0]:subject_ids[1] + 1] = 1
                # object labels as a predicate-specific B/I tag sequence
                object_labels = np.zeros(len(char_ids), dtype=np.int64)
                for o in spoes.get(subject_ids, []):
                    object_labels[o[0]] = BAIDU_BIES['B' + '-' + o[2]]
                    for index in range(o[0] + 1, o[1] + 1):
                        object_labels[index] = BAIDU_BIES['I' + '-' + o[2]]
                batch_char_ids.append(char_ids)
                batch_word_ids.append(word_ids)
                batch_token_type_ids.append(token_type_ids)
                batch_subject_labels.append(subject_labels)
                batch_subject_ids.append(subject_ids)
                batch_object_labels.append(object_labels)
        else:
            batch_char_ids.append(char_ids)
            batch_word_ids.append(word_ids)
    batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
    batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
    if not self.is_train:
        return p_ids, batch_char_ids, batch_word_ids
    else:
        batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
        batch_subject_ids = torch.tensor(batch_subject_ids)
        batch_subject_labels = sequence_padding(batch_subject_labels, is_float=False)
        batch_object_labels = sequence_padding(batch_object_labels, is_float=False)
        return batch_char_ids, batch_word_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
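
# The lookups above imply BAIDU_BIES maps 'B-<predicate>' and 'I-<predicate>'
# strings to tag ids. A hypothetical construction (the real map lives in the
# repo's config; `relations` is a placeholder list of predicate names):
def build_bies_map_sketch(relations):
    """Build a per-predicate B/I tag map: 'O' plus B-/I- tags for each relation."""
    label2id = {'O': 0}
    for rel in relations:
        for prefix in ('B', 'I'):
            label2id[prefix + '-' + rel] = len(label2id)
    return label2id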

def collate(examples):
    p_ids, examples = zip(*examples)
    p_ids = torch.tensor(p_ids, dtype=torch.long)
    batch_char_ids, batch_word_ids = [], []
    batch_ent_labels, batch_rel_labels = [], []
    for example in examples:
        # TODO: max_len handling
        char_ids = [self.char2idx.get(char, 1) for char in example.context]
        # The char and word sequences of a sentence have different lengths; to
        # align them, every character of a word is mapped to that word's idx,
        # which guarantees char_ids and word_ids have the same length. Inside
        # the model, each character then shares its word's embedding.
        word_ids = [self.word2idx.get(word, 0) for word in example.text_word for _ in word]
        assert len(char_ids) == len(word_ids), (
            f'char/word length mismatch for context: {example.context!r}')
        char_ids = char_ids[:self.max_len]
        word_ids = word_ids[:self.max_len]
        example.raw_context = example.context[:self.max_len]
        if self.is_train:
            rel_labels = []
            bio = ['O'] * len(char_ids)
            for s, p, o in example.gold_answer:
                s = [self.char2idx.get(s_, 1) for s_ in s]
                p = BAIDU_RELATION[p]
                o = [self.char2idx.get(o_, 1) for o_ in o]
                s_idx = search(s, char_ids)
                o_idx = search(o, char_ids)
                if s_idx != -1 and o_idx != -1:
                    bio[s_idx] = 'B'
                    bio[s_idx + 1:s_idx + len(s)] = 'I' * (len(s) - 1)
                    bio[o_idx] = 'B'
                    bio[o_idx + 1:o_idx + len(o)] = 'I' * (len(o) - 1)
                    s = (s_idx, s_idx + len(s) - 1)
                    o = (o_idx, o_idx + len(o) - 1, p)
                    # relations are recorded between the two entity end positions
                    rel_labels.append((s[1], o[1], o[2]))
            if rel_labels:
                ent_labels = np.zeros(len(char_ids), dtype=np.int64)
                for index, label_ in enumerate(bio):
                    ent_labels[index] = BAIDU_ENTITY[label_]
                batch_char_ids.append(char_ids)
                batch_word_ids.append(word_ids)
                batch_ent_labels.append(ent_labels)
                batch_rel_labels.append(rel_labels)
        else:
            batch_char_ids.append(char_ids)
            batch_word_ids.append(word_ids)
    batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
    batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
    if not self.is_train:
        return p_ids, batch_char_ids, batch_word_ids
    else:
        batch_ent_labels = sequence_padding(batch_ent_labels, is_float=False)
        batch_rel_labels = select_padding(batch_char_ids, batch_rel_labels,
                                          is_float=True,
                                          class_num=len(BAIDU_RELATION))
        return batch_char_ids, batch_word_ids, batch_ent_labels, batch_rel_labels
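
# `select_padding` above is a repo utility not shown in this section. From the
# call site, it plausibly densifies the per-example
# (subject_end, object_end, predicate) triples into a
# (batch, seq_len, seq_len, class_num) indicator tensor; this sketch is an
# assumption, not the repo's implementation:
def select_padding_sketch(batch_token_ids, batch_triples, is_float=True, class_num=1):
    """Mark rel[b, s_end, o_end, p] = 1 for every gold triple in the batch."""
    batch_size, seq_len = batch_token_ids.shape
    dtype = torch.float32 if is_float else torch.long
    rel = torch.zeros(batch_size, seq_len, seq_len, class_num, dtype=dtype)
    for b, triples in enumerate(batch_triples):
        for s_end, o_end, p in triples:
            rel[b, s_end, o_end, p] = 1
    return rel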