def get_dataloader(data_dir, data_file, num_choices, tokenizer,
                   max_seq_length, batch_size, dg):
    """Read a SWAG-style file and wrap its features in a sequential DataLoader.

    Examples are read from ``os.path.join(data_dir, data_file)`` with
    ``max_pad_length=num_choices + 2`` and converted to features with
    ``tokenizer`` / ``max_seq_length``.

    Returns:
        A ``(dataloader, example_count)`` tuple. Each batch yields tensors in
        the order: input_ids, input_mask, segment_ids, label, doc_len,
        ques_len, option_len. Samples are served in their original order.
    """
    source_path = os.path.join(data_dir, data_file)
    examples = read_swag_examples(
        source_path, max_pad_length=num_choices + 2, dg=dg)
    features = convert_examples_to_features(
        examples, tokenizer, max_seq_length)

    # Every per-feature field is materialised as a LongTensor.
    field_names = ('input_ids', 'input_mask', 'segment_ids',
                   'doc_len', 'ques_len', 'option_len')
    tensors = {
        name: torch.LongTensor(select_field(features, name))
        for name in field_names
    }
    label_tensor = torch.LongTensor([feature.label for feature in features])

    # The label sits in fourth position; consumers unpack batches by position.
    dataset = TensorDataset(
        tensors['input_ids'],
        tensors['input_mask'],
        tensors['segment_ids'],
        label_tensor,
        tensors['doc_len'],
        tensors['ques_len'],
        tensors['option_len'],
    )
    loader = DataLoader(
        dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
    return loader, len(examples)
def build_dataset_eval(config):
    """Build the evaluation data for the sequence model and the DCMN model.

    The test text, abbreviation table, step masks, ``cudics`` and targets are
    read from fixed paths under ``./data``. ``config`` supplies ``tokenizer``,
    ``num_choices`` and ``max_seq_length``.

    Returns:
        A ``(seq_contents, dcmn_contents)`` tuple:

        * ``seq_contents`` - one ``(src, tar, cudic, key_choices)`` tuple per
          test sentence, where ``src`` is wrapped as ``'[CLS] ... [SEP]'``.
        * ``dcmn_contents`` - one ``(input_ids, input_mask, segment_ids,
          doc_len, ques_len, option_len, label)`` tuple per DCMN example.

    Raises:
        AssertionError: if a tokenized sentence and its mask differ in length.
        IndexError: if the pickled targets / cudics hold fewer entries than
            there are test sentences.
    """
    abbrs_path = './data/abbrs-all-uncased.pkl'
    txt_path = './data/test(2030).txt'

    # NOTE: pickle can execute arbitrary code while loading; every .pkl file
    # read in this function must come from a trusted source.
    with open(abbrs_path, 'rb') as f:
        abbrs = pickle.load(f)

    # The third returned sequence is not needed for evaluation.
    src_txt, tar_1_txt, _ = get_test_src_tar_txt(txt_path, config.tokenizer)

    with open('./data/test_mask_step2_2030.pkl', 'rb') as f:
        mask_step1 = pickle.load(f)

    dcmn_kwargs = dict(
        max_pad_length=config.num_choices + 2,
        max_dcmn_seq_length=config.max_seq_length,
        tokenizer=config.tokenizer,
    )

    # Pass 1: only the key answers derived from the ground truth are kept.
    k_as = []
    for src, tar in zip(src_txt, tar_1_txt):
        src = word_tokenize(src, config.tokenizer)
        tar = word_tokenize(tar, config.tokenizer)
        _, _, _, key_ans, _, _ = get_dcmn_data_from_gt(
            src, tar, abbrs, **dcmn_kwargs)
        k_as.append(key_ans)

    # Pass 2: build the actual DCMN inputs from the pre-computed masks.
    seq_srcs = []
    dcmn_srcs = []
    dcmn_labels = []
    key_choices = []
    for i, (sts, masks, k_a) in enumerate(zip(src_txt, mask_step1, k_as)):
        sts = word_tokenize(sts, config.tokenizer)
        assert len(sts) == len(masks), (
            'sentence %d: %d tokens but %d mask entries'
            % (i, len(sts), len(masks)))
        sentences, labels, _src, k_cs = get_dcmn_data_from_step1(
            sts, masks, k_a, abbrs, **dcmn_kwargs)
        dcmn_srcs.extend(sentences)
        dcmn_labels.extend(labels)
        # Diagnostic only: report sentences whose number of DCMN samples does
        # not match the number of '[UNK]' markers in the sequence source.
        if len(sentences) != _src.count('[UNK]'):
            print(i, sts)
        seq_srcs.append(_src)
        key_choices.append(k_cs)

    seq_srcs = ['[CLS] ' + s + ' [SEP]' for s in seq_srcs]

    # Context managers close the handles the original left open.
    with open('./data/test_cudics.pkl', 'rb') as f:
        cudics = pickle.load(f)
    with open('./data/test_tars.pkl', 'rb') as f:
        seq_tars = pickle.load(f)

    # Transpose the DCMN samples into columns: context, question, choices.
    q_id = list(range(1, len(dcmn_labels) + 1))
    article = [u[0] for u in dcmn_srcs]
    question = [u[1] for u in dcmn_srcs]
    cts = [[u[col] for u in dcmn_srcs]
           for col in range(2, config.num_choices + 2)]

    examples = [
        SwagExample(
            swag_id=swag_id,
            context_sentence=context,
            start_ending=start_ending,
            endings=endings,
            label=label,
        )
        for context, start_ending, *endings, label, swag_id
        in zip(article, question, *cts, dcmn_labels, q_id)
    ]
    features = convert_examples_to_features(
        examples, config.tokenizer, config.max_seq_length)

    input_ids = select_field(features, 'input_ids')
    input_mask = select_field(features, 'input_mask')
    segment_ids = select_field(features, 'segment_ids')
    doc_len = select_field(features, 'doc_len')
    ques_len = select_field(features, 'ques_len')
    option_len = select_field(features, 'option_len')
    labels = [f.label for f in features]

    dcmn_contents = [
        (ids, input_mask[i], segment_ids[i], doc_len[i], ques_len[i],
         option_len[i], labels[i])
        for i, ids in enumerate(input_ids)
    ]
    # Indexing (not zip) is deliberate: a too-short pickled list must fail
    # loudly instead of silently truncating the evaluation set.
    seq_contents = [
        (src, seq_tars[i], cudics[i], key_choices[i])
        for i, src in enumerate(seq_srcs)
    ]
    return seq_contents, dcmn_contents
def build_dataset(config):
    """Build the training data for the sequence model and the DCMN model.

    The training text is read from ``config.data_dir / config.train_file`` and
    the abbreviation table from ``./data/abbrs-all-cased.pkl``. ``config``
    also supplies ``tokenizer``, ``num_choices`` and ``max_seq_length``.

    Returns:
        A ``(seq_contents, dcmn_contents)`` tuple:

        * ``seq_contents`` - one ``(src, tar, key_choices)`` tuple per
          training sentence, where ``src`` is wrapped as ``'[CLS] ... [SEP]'``.
        * ``dcmn_contents`` - one ``(input_ids, input_mask, segment_ids,
          doc_len, ques_len, option_len, label)`` tuple per DCMN example.
    """
    train_path = os.path.join(config.data_dir, config.train_file)
    with open('./data/abbrs-all-cased.pkl', 'rb') as abbrs_file:
        abbrs = pickle.load(abbrs_file)
    raw_sources, raw_targets, _ = get_train_src_tar_txt(train_path)

    seq_srcs, seq_tars, key_choices = [], [], []
    dcmn_srcs, dcmn_labels = [], []
    for index, (raw_src, raw_tar) in enumerate(zip(raw_sources, raw_targets)):
        src_tokens = word_tokenize(raw_src, config.tokenizer)
        tar_tokens = word_tokenize(raw_tar, config.tokenizer)
        sentences, labels, seq_src, _, choices, seq_tar = \
            get_dcmn_data_from_gt(
                src_tokens,
                tar_tokens,
                abbrs,
                max_pad_length=config.num_choices + 2,
                max_dcmn_seq_length=config.max_seq_length,
                tokenizer=config.tokenizer,
            )
        # Diagnostic only: report sentences whose number of DCMN samples does
        # not match the number of '[UNK]' markers in the sequence source.
        if len(sentences) != seq_src.count('[UNK]'):
            print(index, src_tokens, len(sentences))
        dcmn_srcs.extend(sentences)
        dcmn_labels.extend(labels)
        seq_srcs.append(seq_src)
        seq_tars.append(seq_tar)
        key_choices.append(choices)

    seq_srcs = ['[CLS] ' + text + ' [SEP]' for text in seq_srcs]

    # Transpose the DCMN samples into columns: context, question, choices.
    question_ids = list(range(1, len(dcmn_labels) + 1))
    contexts = [sample[0] for sample in dcmn_srcs]
    start_endings = [sample[1] for sample in dcmn_srcs]
    choice_columns = [
        [sample[col] for sample in dcmn_srcs]
        for col in range(2, config.num_choices + 2)
    ]

    examples = []
    rows = zip(contexts, start_endings, *choice_columns,
               dcmn_labels, question_ids)
    for context, start_ending, *endings, label, question_id in rows:
        examples.append(
            SwagExample(
                swag_id=question_id,
                context_sentence=context,
                start_ending=start_ending,
                endings=endings,
                label=label,
            )
        )
    features = convert_examples_to_features(
        examples, config.tokenizer, config.max_seq_length)

    # Collect the per-feature columns in output order; the label goes last.
    field_names = ('input_ids', 'input_mask', 'segment_ids',
                   'doc_len', 'ques_len', 'option_len')
    columns = [select_field(features, name) for name in field_names]
    columns.append([feature.label for feature in features])

    dcmn_contents = [
        tuple(column[row] for column in columns)
        for row in range(len(columns[0]))
    ]
    # All three lists gain exactly one entry per sentence in the loop above,
    # so zipping them cannot drop anything.
    seq_contents = list(zip(seq_srcs, seq_tars, key_choices))
    return seq_contents, dcmn_contents