Example #1
0
    def __init__(self, train_path, dev_path, max_len):
        self.tokenizer = BasicTokenizer()
        self.train_path = train_path
        self.dev_path = dev_path
        self.max_len = max_len

        self.train_seg_list, self.train_tgt_list, self.train_segment_list, self.train_type_list, self.train_category_list, self.train_a_seg_list, self.train_a_tree_list, self.train_b_seg_list, self.train_b_tree_list = self.load_data(
            train_path)
        self.dev_seg_list, self.dev_tgt_list, self.dev_segment_list, self.dev_type_list, self.dev_category_list, self.dev_a_seg_list, self.dev_a_tree_list, self.dev_b_seg_list, self.dev_b_tree_list = self.load_data(
            dev_path)

        self.train_num, self.dev_num = len(self.train_seg_list), len(
            self.dev_seg_list)
        print('train number is %d, dev number is %d' %
              (self.train_num, self.dev_num))

        num_train_segment, num_dev_segment = len(self.train_segment_list), len(
            self.dev_segment_list)
        num_train_type, num_dev_type = len(self.train_type_list), len(
            self.dev_type_list)
        assert num_train_segment == num_train_type == self.train_num
        assert num_dev_segment == num_dev_type == self.dev_num

        self.train_idx_list, self.dev_idx_list = [
            i for i in range(self.train_num)
        ], [j for j in range(self.dev_num)]
        self.shuffle_train_idx()

        self.train_current_idx = 0
        self.dev_current_idx = 0
Example #2
0
def work_news_char(line):
    "This function only works for news at char level"
    tokenizer = BasicTokenizer()
    line = line.strip()
    if line == "":
        return [[]]
    char_seq = tokenizer.tokenize(line)
    res = []
    sent = []
    for ch in char_seq:
        sent.append(ch)
        if len(sent) >= 20 and _is_split_point(ch):
            res.append(sent)
            sent = []
    if sent:
        if len(sent) <= 3 and len(res) > 0:
            res[-1].extend(sent)
        else:
            res.append(sent)
    return res
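
# The function above relies on a helper `_is_split_point` that is not shown in this
# example. A minimal sketch, assuming it simply treats common Chinese and ASCII
# sentence-final punctuation as split points (the constant name and the exact
# punctuation set are assumptions, not taken from the original source):
_SPLIT_CHARS = {'。', '！', '？', '；', '，', '.', '!', '?', ';', ','}

def _is_split_point(ch):
    # True when the character looks like a clause or sentence boundary.
    return ch in _SPLIT_CHARS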
Example #3
0
    def __init__(self, filename, vocab, batch_size, for_train):
        tokenizer = BasicTokenizer()
        with open(filename, encoding='utf8') as f:
            all_data = [[tokenizer.tokenize(x) for x in line.strip().split('|')]
                        for line in f]

        self.data = []
        for d in all_data:
            skip = not (len(d) == 4)
            for j, i in enumerate(d):
                if not for_train:
                    d[j] = i[:30]
                    if len(d[j]) == 0:
                        d[j] = [UNK]
                if len(i) == 0 or len(i) > 30:
                    skip = True
            if not (skip and for_train):
                self.data.append(d)

        self.batch_size = batch_size
        self.vocab = vocab
        self.train = for_train
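
        # Input format inferred from the code above (the concrete line below is made up
        # for illustration): every line carries four '|'-separated fields, e.g.
        #   field_1|field_2|field_3|field_4
        # During training a line is dropped when it does not have exactly four fields or
        # when any field tokenizes to zero or more than 30 tokens; outside training each
        # field is instead truncated to 30 tokens and an empty field becomes a single
        # UNK token.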
Example #4
0
class DataLoader:
    def __init__(self, train_path, dev_path, tree_vocab, max_len):
        self.tokenizer = BasicTokenizer()
        self.train_path = train_path
        self.dev_path = dev_path
        self.max_len = max_len
        self.tree_vocab = tree_vocab

        self.train_seg_list, self.train_tgt_list, self.train_segment_list, self.train_type_list, self.train_category_list, self.train_a_seg_list, self.train_a_tree_list, self.train_b_seg_list, self.train_b_tree_list = self.load_data(train_path)
        self.dev_seg_list, self.dev_tgt_list, self.dev_segment_list, self.dev_type_list, self.dev_category_list, self.dev_a_seg_list, self.dev_a_tree_list, self.dev_b_seg_list, self.dev_b_tree_list = self.load_data(dev_path)

        self.train_num, self.dev_num = len(self.train_seg_list), len(self.dev_seg_list)
        print('train number is %d, dev number is %d' % (self.train_num, self.dev_num))

        num_train_segment, num_dev_segment = len(self.train_segment_list), len(self.dev_segment_list)
        num_train_type, num_dev_type = len(self.train_type_list), len(self.dev_type_list)
        assert num_train_segment == num_train_type == self.train_num
        assert num_dev_segment == num_dev_type == self.dev_num

        self.train_idx_list, self.dev_idx_list = [i for i in range(self.train_num)], [j for j in range(self.dev_num)]
        self.shuffle_train_idx()

        self.train_current_idx = 0
        self.dev_current_idx = 0

    def segment(self, text):
        seg = [1 for _ in range(len(text))]
        idx = text.index("sep")
        seg[:idx] = [0 for _ in range(idx)]
        return [0]+seg+[1]  # [CLS]+seg+[SEP]

    def profile(self, text):
        seg = [3 for _ in range(len(text))]
        loc_idx = text.index("loc")
        gender_idx = text.index("gender")
        sep_idx = text.index("sep")
        seg[:loc_idx] = [0 for _ in range(loc_idx)]
        seg[loc_idx:gender_idx] = [1 for _ in range(gender_idx-loc_idx)]
        seg[gender_idx:sep_idx] = [2 for _ in range(sep_idx-gender_idx)]
        return [0]+seg+[3]  # [CLS]+seg+[SEP]
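
    # Illustration of segment()/profile() on a hypothetical tokenized input (the tokens
    # and profile values are made up; the markers 'loc', 'gender' and 'sep' are assumed
    # to appear in that order, as both methods require):
    #   text            = ['双', '子', '座', 'loc', '北', '京', 'gender', '男', 'sep', '你', '好']
    #   segment(text)  -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
    #                      profile tokens map to 0, response tokens to 1, plus [CLS]/[SEP] slots
    #   profile(text)  -> [0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 3]
    #                      0 = constellation, 1 = location, 2 = gender, 3 = response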

    def read_sentence(self, line):
        indices = self.tree_vocab.convertToIdx(line, Constants.UNK_WORD)
        return torch.LongTensor(indices)

    def read_trees(self, batch):
        trees = [self.read_tree(line) for line in batch]
        return trees

    def read_tree(self, line):
        parents = list(map(int, line.split()))
        trees = dict()
        root = None
        for i in range(1, len(parents) + 1):
            if i - 1 not in trees.keys() and parents[i - 1] != -1:
                idx = i
                prev = None
                while True:
                    parent = parents[idx - 1]
                    if parent == -1:
                        break
                    tree = Tree()
                    if prev is not None:
                        tree.add_child(prev)
                    trees[idx - 1] = tree
                    tree.idx = idx - 1
                    if parent - 1 in trees.keys():
                        trees[parent - 1].add_child(tree)
                        break
                    elif parent == 0:
                        root = tree
                        break
                    else:
                        prev = tree
                        idx = parent
        return root
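
    # Illustration of the parent-pointer format consumed by read_tree (the line is a
    # made-up example): parents are 1-based, 0 marks the root and -1 marks tokens to be
    # ignored. For the line "2 0 2", token 2 is the root and tokens 1 and 3 are its
    # children, so read_tree returns the root Tree with idx == 1 (0-based) and two child
    # Trees with idx 0 and idx 2.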
    
    def data_format(self, src_line):
        '''
        Convert a line in the raw tab-separated data format into the model's sample format.
        '''
        line_arr = src_line.strip('\n').split('\t')
        bert_input = line_arr[3].replace(": '", ": ").replace("',", ",").replace("'}", "}")
        bert_input += ' <sep> ' + line_arr[2]
        target = line_arr[5]
        category = line_arr[4]
        a_seg = line_arr[7]
        a_tree = line_arr[8]
        b_seg = line_arr[9]
        b_tree = line_arr[10]

        return bert_input, target, category, a_seg, a_tree, b_seg, b_tree
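
    # Column usage inferred from data_format above (read off the indexing only): the raw
    # line is tab-separated; column 3 holds the profile dict whose stray quote characters
    # are stripped, column 2 the text appended after ' <sep> ', column 5 the target label,
    # column 4 the category, and columns 7/8 and 9/10 the segmented text and parent-pointer
    # parse for sides a and b respectively (all indices 0-based).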


    def load_data(self, path):
        src_list = list()  # src_list contains segmented text
        tgt_list = list()  # tgt_list contains class number
        seg_list = list()  # seg_list contains 0,1 to indicate profile and response
        typ_list = list()  # typ_list contains 0,1,2,3 to indicate constellation, location, gender and response
        cat_list = list()
        a_seg_list = list()
        a_parse_list = list()
        b_seg_list = list()
        b_parse_list = list()
        with open(path, 'r', encoding='utf8') as f:
            lines = f.readlines()
            for l in lines[1:]:
                text, target, category, a_seg, a_tree, b_seg, b_tree = self.data_format(l)
                # content_list = l.strip('\n').split('\t')
                # text = content_list[0]
                target = int(target)
                category = int(category)
                a_seg = self.read_sentence(self.seq_cut(a_seg.split(' ')))
                a_tree = self.read_tree(a_tree)
                b_seg = self.read_sentence(self.seq_cut(b_seg.split(' ')))
                b_tree = self.read_tree(b_tree)
                seg_text = self.tokenizer.tokenize(text)
                post_text = self.seq_cut(seg_text)
                seg_tmp = self.segment(post_text)
                typ_tmp = self.profile(post_text)
                src_list.append(post_text)
                tgt_list.append(target)
                seg_list.append(seg_tmp)
                typ_list.append(typ_tmp)
                cat_list.append(category)

                a_seg_list.append(a_seg)
                a_parse_list.append(a_tree)
                b_seg_list.append(b_seg)
                b_parse_list.append(b_tree)

                assert len(seg_tmp) == len(typ_tmp) == len(post_text)+2

            assert len(src_list) == len(tgt_list) == len(seg_list) == len(typ_list) == len(cat_list)
            assert len(cat_list) == len(a_seg_list) == len(a_parse_list) == len(b_seg_list) == len(b_parse_list)

        return src_list, tgt_list, seg_list, typ_list, cat_list, a_seg_list, a_parse_list, b_seg_list, b_parse_list

    def shuffle_train_idx(self):
        random.shuffle(self.train_idx_list)

    def seq_cut(self, seq):
        if len(seq) > self.max_len:
            seq = seq[ : self.max_len]
        return seq

    def get_next_batch(self, batch_size, mode):
        batch_text_list, batch_label_list = list(), list()
        batch_seg_list, batch_type_list = list(), list()
        batch_category_list = list()
        batch_a_seg_list, batch_a_tree_list = list(), list()
        batch_b_seg_list, batch_b_tree_list = list(), list()
        if mode == 'train':
            if self.train_current_idx + batch_size < self.train_num - 1:
                for i in range(batch_size):
                    curr_idx = self.train_current_idx + i
                    batch_text_list.append(self.train_seg_list[self.train_idx_list[curr_idx]])
                    batch_label_list.append(self.train_tgt_list[self.train_idx_list[curr_idx]])
                    batch_seg_list.append(self.train_segment_list[self.train_idx_list[curr_idx]])
                    batch_type_list.append(self.train_type_list[self.train_idx_list[curr_idx]])
                    batch_category_list.append(self.train_category_list[self.train_idx_list[curr_idx]])
                    batch_a_seg_list.append(self.train_a_seg_list[self.train_idx_list[curr_idx]])
                    batch_a_tree_list.append(self.train_a_tree_list[self.train_idx_list[curr_idx]])
                    batch_b_seg_list.append(self.train_b_seg_list[self.train_idx_list[curr_idx]])
                    batch_b_tree_list.append(self.train_b_tree_list[self.train_idx_list[curr_idx]])
                self.train_current_idx += batch_size
            else:
                for i in range(batch_size):
                    curr_idx = self.train_current_idx + i
                    if curr_idx > self.train_num - 1:  # ran past the end: reshuffle and wrap back to index 0
                        self.shuffle_train_idx()
                        curr_idx = 0
                        self.train_current_idx = 0
                    batch_text_list.append(self.train_seg_list[self.train_idx_list[curr_idx]])
                    batch_label_list.append(self.train_tgt_list[self.train_idx_list[curr_idx]])
                    batch_seg_list.append(self.train_segment_list[self.train_idx_list[curr_idx]])
                    batch_type_list.append(self.train_type_list[self.train_idx_list[curr_idx]])
                    batch_category_list.append(self.train_category_list[self.train_idx_list[curr_idx]])
                    batch_a_seg_list.append(self.train_a_seg_list[self.train_idx_list[curr_idx]])
                    batch_a_tree_list.append(self.train_a_tree_list[self.train_idx_list[curr_idx]])
                    batch_b_seg_list.append(self.train_b_seg_list[self.train_idx_list[curr_idx]])
                    batch_b_tree_list.append(self.train_b_tree_list[self.train_idx_list[curr_idx]])
                self.train_current_idx = 0

        elif mode == 'dev':
            if self.dev_current_idx + batch_size < self.dev_num - 1:
                for i in range(batch_size):
                    curr_idx = self.dev_current_idx + i
                    batch_text_list.append(self.dev_seg_list[curr_idx])
                    batch_label_list.append(self.dev_tgt_list[curr_idx])
                    batch_seg_list.append(self.dev_segment_list[curr_idx])
                    batch_type_list.append(self.dev_type_list[curr_idx])
                    batch_category_list.append(self.dev_category_list[curr_idx])
                    batch_a_seg_list.append(self.dev_a_seg_list[curr_idx])
                    batch_a_tree_list.append(self.dev_a_tree_list[curr_idx])
                    batch_b_seg_list.append(self.dev_b_seg_list[curr_idx])
                    batch_b_tree_list.append(self.dev_b_tree_list[curr_idx])
                self.dev_current_idx += batch_size
            else:
                for i in range(batch_size):
                    curr_idx = self.dev_current_idx + i
                    if curr_idx > self.dev_num - 1:  # ran past the end of the dev set: wrap back to index 0
                        curr_idx = 0
                        self.dev_current_idx = 0
                    batch_text_list.append(self.dev_seg_list[curr_idx])
                    batch_label_list.append(self.dev_tgt_list[curr_idx])
                    batch_seg_list.append(self.dev_segment_list[curr_idx])
                    batch_type_list.append(self.dev_type_list[curr_idx])
                    batch_category_list.append(self.dev_category_list[curr_idx])
                    batch_a_seg_list.append(self.dev_a_seg_list[curr_idx])
                    batch_a_tree_list.append(self.dev_a_tree_list[curr_idx])
                    batch_b_seg_list.append(self.dev_b_seg_list[curr_idx])
                    batch_b_tree_list.append(self.dev_b_tree_list[curr_idx])
                self.dev_current_idx = 0
        else:
            raise ValueError("Unknown batch mode: expected 'train' or 'dev'.")

        return batch_text_list, batch_label_list, batch_seg_list, batch_type_list, batch_category_list, batch_a_seg_list, batch_a_tree_list, batch_b_seg_list, batch_b_tree_list
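
    # Usage sketch (hypothetical call; the paths mirror the ones used in the training
    # script below, but batch_size and max_len are made-up values):
    #   loader = DataLoader('data/KvPI_train.txt', 'data/KvPI_valid.txt', tree_vocab, max_len=128)
    #   (batch_text_list, batch_label_list, batch_seg_list, batch_type_list,
    #    batch_category_list, batch_a_seg_list, batch_a_tree_list,
    #    batch_b_seg_list, batch_b_tree_list) = loader.get_next_batch(64, mode='train')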
Example #5
0
    pred_output = model(batch_text_list, batch_seg_list, batch_type_list, batch_a_seg_list, batch_a_tree_list, batch_b_seg_list, batch_b_tree_list, fine_tune=True)
    logits = pred_output[0]
    loss = criterion(logits.view(-1, label_nums), batch_label_ids.view(-1))
    return logits, loss


if __name__ == '__main__':
    args = parse_config()
    ckpt_path = args.ckpt_path
    test_data = args.test_data
    out_path = args.out_path
    gpu_id = args.gpu_id

    model, tree_vocab = init_model(ckpt_path)
    model.cuda(gpu_id)
    tokenizer = BasicTokenizer()

    if args.do_train:
        train_path = 'data/KvPI_train.txt'
        dev_path = 'data/KvPI_valid.txt'
        data_loader = DataLoader(train_path, dev_path, tree_vocab, args.max_len)
        criterion = CrossEntropyLoss()
        label_nums = model.num_class
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        optimizer = optim.AdamW(model.parameters(), lr=3e-5)

        # param_optimizer = list(model.named_parameters())
        # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        # optimizer_grouped_parameters = [
        #     {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        #     {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}