Example #1
        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            batch_token_ids, batch_segment_ids = [], []
            batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
            for example in examples:
                spoes = example.spoes
                token_ids = self.tokenizer.encode(example.bert_tokens)[1:-1]
                segment_ids = len(token_ids) * [0]

                if self.is_train:

                    if spoes:
                        # subject labels (start/end of every subject span)
                        token_type_ids = np.zeros(len(token_ids),
                                                  dtype=np.int64)
                        subject_labels = np.zeros((len(token_ids), 2),
                                                  dtype=np.float32)
                        for s in spoes:
                            subject_labels[s[0], 0] = 1
                            subject_labels[s[1], 1] = 1
                        # randomly pick one subject
                        subject_ids = random.choice(list(spoes.keys()))
                        # object labels for the chosen subject
                        object_labels = np.zeros(
                            (len(token_ids), len(self.spo_config), 2),
                            dtype=np.float32)
                        for o in spoes.get(subject_ids, []):
                            object_labels[o[0], o[2], 0] = 1
                            object_labels[o[1], o[2], 1] = 1
                        batch_token_ids.append(token_ids)
                        batch_token_type_ids.append(token_type_ids)

                        batch_segment_ids.append(segment_ids)
                        batch_subject_labels.append(subject_labels)
                        batch_subject_ids.append(subject_ids)
                        batch_object_labels.append(object_labels)
                else:
                    batch_token_ids.append(token_ids)
                    batch_segment_ids.append(segment_ids)

            batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
            batch_segment_ids = sequence_padding(batch_segment_ids,
                                                 is_float=False)
            if not self.is_train:
                return p_ids, batch_token_ids, batch_segment_ids
            else:
                batch_token_type_ids = sequence_padding(batch_token_type_ids,
                                                        is_float=False)
                batch_subject_ids = torch.tensor(batch_subject_ids)
                batch_subject_labels = sequence_padding(batch_subject_labels,
                                                        padding=np.zeros(2),
                                                        is_float=True)
                batch_object_labels = sequence_padding(
                    batch_object_labels,
                    padding=np.zeros((len(self.spo_config), 2)),
                    is_float=True)
                return batch_token_ids, batch_segment_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
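
All of these collate functions rely on a sequence_padding helper whose definition is not shown here. The sketch below is a minimal reading of what the call sites imply (pad every sequence, or per-position label array, to the longest length in the batch and stack the result into one tensor); the signature is inferred from the calls and the real helper may differ in details such as the returned dtype.

import numpy as np
import torch

def sequence_padding(batch, padding=0, is_float=False):
    # Hedged sketch of the helper assumed by the collate functions in this
    # section: pad each sequence (or per-position label array) to the batch
    # max length with `padding`, then stack everything into one tensor.
    max_len = max(len(seq) for seq in batch)
    padded = []
    for seq in batch:
        seq = np.asarray(seq)
        if len(seq) < max_len:
            pad_block = np.stack([np.asarray(padding)] * (max_len - len(seq)))
            seq = np.concatenate([seq, pad_block], axis=0)
        padded.append(seq)
    out = np.stack(padded)
    return torch.tensor(out, dtype=torch.float32 if is_float else torch.long)
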
Example #2
        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            batch_char_ids, batch_word_ids = [], []
            batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
            for example in examples:
                # todo maxlen
                char_ids = [self.char2idx.get(char, 1) for char in example.context]
                word_ids = [self.word2idx.get(word, 0) for word in example.text_word for _ in word]
                if len(char_ids) != len(word_ids):
                    print(example.context)
                    print(char_ids)
                    print(len(char_ids))
                    print(example.text_word)
                    print(word_ids)
                    print(len(word_ids))
                assert len(char_ids) == len(word_ids)
                char_ids = char_ids[:self.max_len]
                word_ids = word_ids[:self.max_len]
                # example.context = example.context[:self.max_len]

                if self.is_train:

                    subject_labels = np.zeros((len(char_ids), 2), dtype=np.float32)
                    token_type_ids = np.zeros(len(char_ids), dtype=np.int64)
                    object_labels = np.zeros((len(char_ids), len(DRUG_RELATION), 2), dtype=np.float32)
                    for s in example.ent_list:
                        subject_labels[s[0], 0] = 1
                        subject_labels[s[1] - 1, 1] = 1
                    if example.sub_entity_list:
                        subject_ids = example.sub_entity_list
                        token_type_ids[subject_ids[0]:subject_ids[1] + 1] = 1
                        # object labels for the given subject
                        for o in example.spoes.get(subject_ids, []):
                            object_labels[o[0], o[2], 0] = 1
                            object_labels[o[1], o[2], 1] = 1
                        batch_char_ids.append(char_ids)
                        batch_word_ids.append(word_ids)
                        batch_token_type_ids.append(token_type_ids)
                        batch_subject_labels.append(subject_labels)
                        batch_subject_ids.append(subject_ids)
                        batch_object_labels.append(object_labels)
                else:
                    batch_char_ids.append(char_ids)
                    batch_word_ids.append(word_ids)

            batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
            batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
            if not self.is_train:
                return p_ids, batch_char_ids, batch_word_ids
            else:
                batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
                batch_subject_ids = torch.tensor(batch_subject_ids)
                batch_subject_labels = sequence_padding(batch_subject_labels, padding=np.zeros(2), is_float=True)
                batch_object_labels = sequence_padding(batch_object_labels, padding=np.zeros((len(DRUG_RELATION), 2)),
                                                       is_float=True)
                return batch_char_ids, batch_word_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
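
The word_ids comprehension above (also used in Examples #5-#7) keeps the word-level sequence aligned with the character-level sequence by repeating each word's id once per character. A toy illustration with a hypothetical vocabulary:

text_word = ["感冒灵", "颗粒"]        # segmented words (made-up input)
word2idx = {"感冒灵": 7, "颗粒": 3}   # hypothetical word vocabulary, 0 = unknown
word_ids = [word2idx.get(word, 0) for word in text_word for _ in word]
print(word_ids)  # [7, 7, 7, 3, 3] -- one word id per character, same length as the char sequence
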
Example #3
def text2id(args, char2idx, raw_text: list):
    p_ids = []
    batch_char_ids = []
    examples = []
    for index, text_ in enumerate(raw_text):
        p_ids.append(index)
        char_ids = [char2idx.get(char, 1) for char in text_]
        batch_char_ids.append(char_ids)
        examples.append(Example(
            p_id=index,
            context=text_,
        ))
    p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
    text_ids = sequence_padding(batch_char_ids, is_float=False)
    return examples, (p_ids, text_ids)
Example #4
        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            batch_token_ids, batch_segment_ids = [], []
            batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
            for example in examples:
                # todo maxlen
                token_ids, segment_ids = self.tokenizer.encode(
                    example.context, max_length=self.max_len)
                example.bert_tokens = self.tokenizer.tokenize(example.context)
                example.token_ids = token_ids
                if self.is_train:
                    spoes = {}
                    for s, p, o in example.gold_answer:
                        s = self.tokenizer.encode(s)[0][1:-1]
                        p = BAIDU_RELATION[p]
                        o = self.tokenizer.encode(o)[0][1:-1]
                        s_idx = search(s, token_ids)
                        o_idx = search(o, token_ids)
                        if s_idx != -1 and o_idx != -1:
                            s = (s_idx, s_idx + len(s) - 1)
                            o = (o_idx, o_idx + len(o) - 1, p)
                            if s not in spoes:
                                spoes[s] = []
                            spoes[s].append(o)

                    if spoes:
                        # subject labels (start/end of every subject span)
                        token_type_ids = np.zeros(len(token_ids),
                                                  dtype=np.int64)
                        subject_labels = np.zeros((len(token_ids), 2),
                                                  dtype=np.float32)
                        for s in spoes:
                            subject_labels[s[0], 0] = 1
                            subject_labels[s[1], 1] = 1
                        # randomly pick one subject
                        start, end = np.array(list(spoes.keys())).T
                        start = np.random.choice(start)
                        end = np.random.choice(end[end >= start])
                        token_type_ids[start:end + 1] = 1
                        subject_ids = (start, end)
                        # object labels for the chosen subject
                        object_labels = np.zeros(
                            (len(token_ids), len(BAIDU_RELATION), 2),
                            dtype=np.float32)
                        for o in spoes.get(subject_ids, []):
                            object_labels[o[0], o[2], 0] = 1
                            object_labels[o[1], o[2], 1] = 1
                        batch_token_ids.append(token_ids)
                        batch_token_type_ids.append(token_type_ids)

                        batch_segment_ids.append(segment_ids)
                        batch_subject_labels.append(subject_labels)
                        batch_subject_ids.append(subject_ids)
                        batch_object_labels.append(object_labels)
                else:
                    batch_token_ids.append(token_ids)
                    batch_segment_ids.append(segment_ids)

            batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
            batch_segment_ids = sequence_padding(batch_segment_ids,
                                                 is_float=False)
            if not self.is_train:
                return p_ids, batch_token_ids, batch_segment_ids
            else:
                batch_token_type_ids = sequence_padding(batch_token_type_ids,
                                                        is_float=False)
                batch_subject_ids = torch.tensor(batch_subject_ids)
                batch_subject_labels = sequence_padding(batch_subject_labels,
                                                        padding=np.zeros(2),
                                                        is_float=True)
                batch_object_labels = sequence_padding(
                    batch_object_labels,
                    padding=np.zeros((len(BAIDU_RELATION), 2)),
                    is_float=True)
                return batch_token_ids, batch_segment_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
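
Examples #4 and #5 locate the subject and object spans inside the token/char id sequence with a search helper that is not shown. A minimal sketch, assuming it returns the start index of the first occurrence of pattern inside sequence and -1 when there is no match:

def search(pattern, sequence):
    # Hedged sketch of the span-lookup helper assumed above: return the start
    # index of the first occurrence of `pattern` in `sequence`, else -1.
    n = len(pattern)
    for i in range(len(sequence) - n + 1):
        if sequence[i:i + n] == pattern:
            return i
    return -1
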
Example #5
        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            batch_char_ids, batch_word_ids = [], []
            batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
            for example in examples:
                # todo maxlen
                char_ids = [
                    self.char2idx.get(char, 1) for char in example.context
                ]
                word_ids = [
                    self.word2idx.get(word, 0) for word in example.text_word
                    for _ in word
                ]
                if len(char_ids) != len(word_ids):
                    print(example.context)
                    print(char_ids)
                    print(len(char_ids))
                    print(example.text_word)
                    print(word_ids)
                    print(len(word_ids))
                assert len(char_ids) == len(word_ids)
                char_ids = char_ids[:self.max_len]
                word_ids = word_ids[:self.max_len]
                # example.context = example.context[:self.max_len]

                if self.is_train:
                    spoes = {}
                    for s, p, o in example.gold_answer:
                        s = [self.char2idx.get(s_, 1) for s_ in s]
                        p = BAIDU_RELATION[p]
                        o = [self.char2idx.get(o_, 1) for o_ in o]
                        s_idx = search(s, char_ids)
                        o_idx = search(o, char_ids)
                        if s_idx != -1 and o_idx != -1:
                            s = (s_idx, s_idx + len(s) - 1)
                            o = (o_idx, o_idx + len(o) - 1, p)
                            if s not in spoes:
                                spoes[s] = []
                            spoes[s].append(o)

                    if spoes:
                        # subject labels (start/end of every subject span)
                        token_type_ids = np.zeros(len(char_ids), dtype=np.int64)
                        subject_labels = np.zeros((len(char_ids), 2),
                                                  dtype=np.float32)
                        for s in spoes:
                            subject_labels[s[0], 0] = 1
                            subject_labels[s[1], 1] = 1
                        # randomly pick one subject
                        start, end = np.array(list(spoes.keys())).T
                        start = np.random.choice(start)
                        end = np.random.choice(end[end >= start])
                        token_type_ids[start:end + 1] = 1
                        subject_ids = (start, end)
                        # object labels for the chosen subject
                        object_labels = np.zeros(
                            (len(char_ids), len(BAIDU_RELATION), 2),
                            dtype=np.float32)
                        for o in spoes.get(subject_ids, []):
                            object_labels[o[0], o[2], 0] = 1
                            object_labels[o[1], o[2], 1] = 1
                        batch_char_ids.append(char_ids)
                        batch_word_ids.append(word_ids)
                        batch_token_type_ids.append(token_type_ids)
                        batch_subject_labels.append(subject_labels)
                        batch_subject_ids.append(subject_ids)
                        batch_object_labels.append(object_labels)
                else:
                    batch_char_ids.append(char_ids)
                    batch_word_ids.append(word_ids)

            batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
            batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
            if not self.is_train:
                return p_ids, batch_char_ids, batch_word_ids
            else:
                batch_token_type_ids = sequence_padding(batch_token_type_ids,
                                                        is_float=False)
                batch_subject_ids = torch.tensor(batch_subject_ids)
                batch_subject_labels = sequence_padding(batch_subject_labels,
                                                        padding=np.zeros(2),
                                                        is_float=True)
                batch_object_labels = sequence_padding(
                    batch_object_labels,
                    padding=np.zeros((len(BAIDU_RELATION), 2)),
                    is_float=True)
                return batch_char_ids, batch_word_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
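
To make the pointer-style labels of Examples #1, #4 and #5 concrete: every subject span marks its start in column 0 and its end in column 1 of subject_labels, and the objects of the one sampled subject mark their start/end in the plane of their predicate id in object_labels. A small self-contained walk-through with made-up positions:

import numpy as np

# Hypothetical case: a 6-token text with one triple whose subject spans tokens
# 0..1, whose object spans tokens 3..5, and whose predicate id is 2 (of 4).
seq_len, num_relations = 6, 4
spoes = {(0, 1): [(3, 5, 2)]}

subject_labels = np.zeros((seq_len, 2), dtype=np.float32)
for s in spoes:
    subject_labels[s[0], 0] = 1   # subject start
    subject_labels[s[1], 1] = 1   # subject end

subject_ids = (0, 1)              # the (only) subject to condition on
object_labels = np.zeros((seq_len, num_relations, 2), dtype=np.float32)
for o in spoes.get(subject_ids, []):
    object_labels[o[0], o[2], 0] = 1   # object start under predicate o[2]
    object_labels[o[1], o[2], 1] = 1   # object end under predicate o[2]

print(subject_labels[:, 0])  # [1. 0. 0. 0. 0. 0.]
print(subject_labels[:, 1])  # [0. 1. 0. 0. 0. 0.]
print(object_labels[3, 2])   # [1. 0.] -> object starts at token 3 for predicate 2
print(object_labels[5, 2])   # [0. 1.] -> object ends at token 5 for predicate 2
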
Example #6
        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            batch_char_ids, batch_word_ids = [], []
            batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
            for example in examples:
                # todo maxlen
                char_ids = [
                    self.char2idx.get(char, 1) for char in example.context
                ]
                word_ids = [
                    self.word2idx.get(word, 0) for word in example.text_word
                    for _ in word
                ]
                if len(char_ids) != len(word_ids):
                    print(example.context)
                    print(char_ids)
                    print(len(char_ids))
                    print(example.text_word)
                    print(word_ids)
                    print(len(word_ids))
                assert len(char_ids) == len(word_ids)
                char_ids = char_ids[:self.max_len]
                word_ids = word_ids[:self.max_len]
                # example.context = example.context[:self.max_len]

                if self.is_train:
                    spoes = {}
                    for s, p, o in example.gold_answer:
                        s = [self.char2idx.get(s_, 1) for s_ in s]
                        # p = BAIDU_RELATION[p]
                        o = [self.char2idx.get(o_, 1) for o_ in o]
                        s_idx = search(s, char_ids)
                        o_idx = search(o, char_ids)
                        if s_idx != -1 and o_idx != -1:
                            s = (s_idx, s_idx + len(s) - 1)
                            o = (o_idx, o_idx + len(o) - 1, p)
                            if s not in spoes:
                                spoes[s] = []
                            spoes[s].append(o)

                    if spoes:
                        # subject labels (BIO tags over the char sequence)
                        token_type_ids = np.zeros(len(char_ids), dtype=np.int64)
                        subject_labels = np.zeros(len(char_ids), dtype=np.int64)
                        for s in spoes:
                            subject_labels[s[0]] = BAIDU_ENTITY['B']
                            for index in range(s[0] + 1, s[1] + 1):
                                subject_labels[index] = BAIDU_ENTITY['I']
                        # randomly pick one subject
                        subject_ids = random.choice(list(spoes.keys()))

                        token_type_ids[subject_ids[0]:subject_ids[1] + 1] = 1
                        # object labels for the chosen subject (B-/I- tags combined with the predicate)
                        object_labels = np.zeros(len(char_ids), dtype=np.int64)
                        for o in spoes.get(subject_ids, []):
                            object_labels[o[0]] = BAIDU_BIES['B' + '-' + o[2]]
                            for index in range(o[0] + 1, o[1] + 1):
                                object_labels[index] = BAIDU_BIES['I' + '-' +
                                                                  o[2]]
                        batch_char_ids.append(char_ids)
                        batch_word_ids.append(word_ids)
                        batch_token_type_ids.append(token_type_ids)
                        batch_subject_labels.append(subject_labels)
                        batch_subject_ids.append(subject_ids)
                        batch_object_labels.append(object_labels)
                else:
                    batch_char_ids.append(char_ids)
                    batch_word_ids.append(word_ids)

            batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
            batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
            if not self.is_train:
                return p_ids, batch_char_ids, batch_word_ids
            else:
                batch_token_type_ids = sequence_padding(batch_token_type_ids,
                                                        is_float=False)
                batch_subject_ids = torch.tensor(batch_subject_ids)
                batch_subject_labels = sequence_padding(batch_subject_labels,
                                                        is_float=False)
                batch_object_labels = sequence_padding(batch_object_labels,
                                                       is_float=False)
                return batch_char_ids, batch_word_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
Example #7
        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            batch_char_ids, batch_word_ids = [], []
            batch_ent_labels, batch_rel_labels = [], []
            for example in examples:
                # print("example: ", example)
                # todo maxlen
                char_ids = [
                    self.char2idx.get(char, 1) for char in example.context
                ]
                # The char sequence and the word sequence of a sentence differ in length. To align them,
                # every character of a word is mapped to the index of the word it belongs to, so char_ids
                # and word_ids end up with the same length.
                # Inside the model, every character of a word therefore shares that word's embedding.
                word_ids = [
                    self.word2idx.get(word, 0) for word in example.text_word
                    for _ in word
                ]
                # word_ids = [self.word2idx.get(word, 0) for word in example.text_word]
                if len(char_ids) != len(word_ids):
                    print(example.context)
                    print(char_ids)
                    print(len(char_ids))
                    print(example.text_word)
                    print(word_ids)
                    print(len(word_ids))
                assert len(char_ids) == len(word_ids)
                char_ids = char_ids[:self.max_len]
                word_ids = word_ids[:self.max_len]
                example.raw_context = example.context[:self.max_len]

                if self.is_train:
                    rel_labels = []
                    bio = ['O'] * len(char_ids)
                    for s, p, o in example.gold_answer:
                        s = [self.char2idx.get(s_, 1) for s_ in s]
                        p = BAIDU_RELATION[p]
                        o = [self.char2idx.get(o_, 1) for o_ in o]
                        s_idx = search(s, char_ids)
                        o_idx = search(o, char_ids)
                        if s_idx != -1 and o_idx != -1:
                            bio[s_idx] = 'B'
                            bio[s_idx + 1:s_idx + len(s)] = 'I' * (len(s) - 1)
                            bio[o_idx] = 'B'
                            bio[o_idx + 1:o_idx + len(o)] = 'I' * (len(o) - 1)
                            s = (s_idx, s_idx + len(s) - 1)
                            o = (o_idx, o_idx + len(o) - 1, p)
                            rel_labels.append((s[1], o[1], o[2]))

                    if rel_labels:
                        ent_labels = np.zeros(len(char_ids), dtype=np.int64)
                        for index, label_ in enumerate(bio):
                            ent_labels[index] = BAIDU_ENTITY[label_]
                        batch_char_ids.append(char_ids)
                        batch_word_ids.append(word_ids)
                        batch_ent_labels.append(ent_labels)
                        batch_rel_labels.append(rel_labels)
                else:
                    batch_char_ids.append(char_ids)
                    batch_word_ids.append(word_ids)

            batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
            batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
            if not self.is_train:
                # print("p_ids: ", p_ids)
                # print("batch_char_ids: ", batch_char_ids)
                # print("batch_word_ids: ", batch_word_ids)
                return p_ids, batch_char_ids, batch_word_ids
            else:
                batch_ent_labels = sequence_padding(batch_ent_labels,
                                                    is_float=False)
                batch_rel_labels = select_padding(
                    batch_char_ids,
                    batch_rel_labels,
                    is_float=True,
                    class_num=len(BAIDU_RELATION))
                # print("batch_char_ids: shape=", batch_char_ids.shape, "\n", batch_char_ids)
                # print("batch_word_ids: shape=", batch_word_ids.shape, "\n", batch_word_ids)
                # print("batch_ent_labels: shape=",batch_ent_labels.shape, '\n', batch_ent_labels)
                # print("batch_rel_labels: shape=",batch_rel_labels.shape, "\n", batch_rel_labels)
                return batch_char_ids, batch_word_ids, batch_ent_labels, batch_rel_labels
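
Example #7 additionally depends on a select_padding helper that is not shown. Since each entry of batch_rel_labels is a (subject_end, object_end, predicate) triple and class_num=len(BAIDU_RELATION), one plausible multi-head-selection-style reading is sketched below; this is an assumption about the helper, not its actual implementation:

import torch

def select_padding(batch_token_ids, batch_rel_labels, is_float=True, class_num=None):
    # Hedged sketch: build a (batch, seq_len, seq_len, class_num) target tensor
    # and set 1 at [subject_end, object_end, predicate] for every gold triple.
    batch_size, seq_len = batch_token_ids.shape
    dtype = torch.float32 if is_float else torch.long
    labels = torch.zeros((batch_size, seq_len, seq_len, class_num), dtype=dtype)
    for b, triples in enumerate(batch_rel_labels):
        for s_end, o_end, p in triples:
            labels[b, s_end, o_end, p] = 1
    return labels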