def setUp(self):
    self.strings = ["ab", "c", "def", "ghij"]
    self.weights = [4.0, 2.0, 7.0, 1.5]
    self.size_ratio = 2
    self.dataset = ListDataset(
        self.strings, np.array([len(s) for s in self.strings])
    )
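A minimal sketch of what this fixture wraps, assuming fairseq's `ListDataset` is importable from `fairseq.data`; the sample strings and sizes below are illustrative only.

# Hedged sketch: ListDataset stores the raw items and a parallel sizes array.
import numpy as np
from fairseq.data import ListDataset

strings = ["ab", "c", "def", "ghij"]
ds = ListDataset(strings, np.array([len(s) for s in strings]))
assert ds[2] == "def"                   # items are returned unchanged
assert list(ds.sizes) == [2, 1, 3, 4]   # sizes are kept alongside the items
assert len(ds) == 4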
def prepare_tokens(self, tokens: torch.Tensor):
    sizes = [len(seq) for seq in tokens]
    src_tokens = ListDataset(tokens, sizes=sizes)
    src_tokens = RightPadDataset(src_tokens, pad_idx=self.source_dictionary.pad())

    word_masks_w_bos = WordEndMaskDataset(
        src_tokens, self.dictionary, self.is_word_initial, bos_value=1, eos_value=0
    )

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": src_tokens,
            "nsrc_tokens": NumelDataset(src_tokens),
            "word_mask_w_bos": RightPadDataset(word_masks_w_bos, pad_idx=0),
        },
        "ntokens": NumelDataset(src_tokens, reduce=True),
        "nwords": NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial),
        "nsentences": NumSamplesDataset(),
    }
    dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes])
    return dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    if constraints is not None:
        raise NotImplementedError(
            "Constrained decoding with the multilingual_translation task is not supported"
        )
    src_data = ListDataset(src_tokens, src_lengths)
    dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
    src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
    if self.args.lang_tok_replacing_bos_eos:
        dataset = self.data_manager.alter_dataset_langtok(
            dataset,
            src_eos=self.source_dictionary.eos(),
            src_lang=self.args.source_lang,
            tgt_eos=self.target_dictionary.eos(),
            tgt_lang=self.args.target_lang,
            src_langtok_spec=src_langtok_spec,
            tgt_langtok_spec=tgt_langtok_spec,
        )
    else:
        dataset.src = self.data_manager.src_dataset_tranform_func(
            self.args.source_lang,
            self.args.target_lang,
            dataset=dataset.src,
            spec=src_langtok_spec,
        )
    return dataset
def _get_epoch_batch_itr(ref, bsz, skip_remainder_batch):
    dsz = len(ref)
    indices = range(dsz)
    starts = indices[::bsz]
    batch_sampler = [indices[s:s + bsz] for s in starts]
    dataset = ListDataset(ref)
    itr = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        skip_remainder_batch=skip_remainder_batch,
    )
    return itr.next_epoch_itr()
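A hedged usage sketch for the helper above, assuming fairseq's `ListDataset` collater (which simply returns the list of samples) and `iterators.EpochBatchIterator`; the sample sizes are illustrative.

# Hedged sketch: ten dummy samples, batch size 4, remainder batch dropped.
ref = list(range(10))
itr = _get_epoch_batch_itr(ref, bsz=4, skip_remainder_batch=True)
batches = list(itr)
# With skip_remainder_batch=True the final (here partial) batch is discarded,
# so only the two full batches of 4 remain.
assert len(batches) == 2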
def _set_up_train_dataset(self, split_range) -> torch.utils.data.Dataset:
    new_date_ranges = _date_list_from_arg(self.args.new_data_date_range)
    logger.info(f'Setting up training data: {split_range}, {new_date_ranges}')
    new_hive_data = HiveDataset(
        table=self.args.table,
        namespace=self.args.namespace,
        limit=self.args.query_limit,
        date_ranges=new_date_ranges,
        filter_fn=lambda x: _should_include(x[0], split_range),
    )

    desired_total_data_size = self.args.old_to_new_ratio * len(new_hive_data)
    desired_old_data_size = (
        1 - (1 / self.args.old_to_new_ratio)
    ) * desired_total_data_size

    old_date_ranges = _date_list_from_arg(self.args.train_date_range)
    old_hive_data = HiveDataset(
        table=self.args.table,
        namespace=self.args.namespace,
        limit=min(self.args.query_limit, desired_old_data_size),
        date_ranges=old_date_ranges,
        filter_fn=lambda x: _should_include(x[0], split_range),
    )
    old_hive_data = old_hive_data[:int(desired_old_data_size)]

    all_data = new_hive_data.data + list(old_hive_data)
    conversations = ConversationDataset(
        dataset=ListDataset(dataset=_shuffle(all_data)),
        dictionary=self.dictionary,
        split_range=split_range,
    )
    logger.info(
        f"Created train dataset of size: {len(conversations)} conversations"
    )
    return conversations
def build_dataset_for_inference(self, src_tokens, src_lengths):
    src_data = ListDataset(src_tokens, src_lengths)
    dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
    src_langtok_spec, tgt_langtok_spec = self.args.langtoks['main']
    if self.args.lang_tok_replacing_bos_eos:
        dataset = self.data_manager.alter_dataset_langtok(
            dataset,
            src_eos=self.source_dictionary.eos(),
            src_lang=self.args.source_lang,
            tgt_eos=self.target_dictionary.eos(),
            tgt_lang=self.args.target_lang,
            src_langtok_spec=src_langtok_spec,
            tgt_langtok_spec=tgt_langtok_spec,
        )
    else:
        dataset.src = self.data_manager.src_dataset_tranform_func(
            self.args.source_lang,
            self.args.target_lang,
            dataset=dataset.src,
            spec=src_langtok_spec,
        )
    return dataset
def load_dataset(
    self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs
):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    print("Split type --> " + str(split))

    def binarize(s, append_bos=False):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=True,
            add_if_not_exist=False,
        ).long()
        if append_bos and self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    src_tokens = [[] for i in range(self.args.num_classes)]
    src_lengths = [[] for i in range(self.args.num_classes)]
    labels = []

    with open(data_path) as h:
        for line in h:
            example = json.loads(line.strip())
            if 'answerKey' in example:
                label = ord(example['answerKey']) - ord('A')
                labels.append(label)
            question = example['question']['stem']
            if self.args.num_classes != len(example['question']['choices']):
                print(
                    "Class size = " + str(self.args.num_classes)
                    + ". Length of sample size = "
                    + str(len(example['question']['choices']))
                )
            assert len(example['question']['choices']) == self.args.num_classes
            # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
            question = 'Q: ' + question
            question_toks = binarize(question, append_bos=True)
            for i, choice in enumerate(example['question']['choices']):
                src = 'A: ' + choice['text']
                src_bin = torch.cat([question_toks, binarize(src)])
                src_tokens[i].append(src_bin)
                src_lengths[i].append(len(src_bin))

    assert all(
        len(src_tokens[0]) == len(src_tokens[i])
        for i in range(self.args.num_classes)
    )
    assert len(src_tokens[0]) == len(src_lengths[0])
    assert len(labels) == 0 or len(labels) == len(src_tokens[0])

    for i in range(self.args.num_classes):
        src_lengths[i] = np.array(src_lengths[i])
        src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
        src_lengths[i] = ListDataset(src_lengths[i])

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for i in range(self.args.num_classes):
        dataset.update({
            'net_input{}'.format(i + 1): {
                'src_tokens': RightPadDataset(
                    src_tokens[i],
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': src_lengths[i],
            }
        })

    if len(labels) > 0:
        dataset.update({'target': RawLabelDataset(labels)})

    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            # shuffle
            sort_order=[np.random.permutation(len(dataset))],
        )

    print('| Loaded {} with {} samples'.format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(
    self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs
):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def binarize(s: str, append_eos: bool = False):
        if self.tokenizer is not None:
            s = self.tokenizer.encode(s)
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=append_eos,
            add_if_not_exist=False,
        ).long()
        if self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    query_tokens = []
    query_masks = []
    query_lengths = []
    candidate_tokens = []
    candidate_masks = []
    candidate_lengths = []
    labels = []

    for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
        prefix = sentence[:pronoun_span.start].text
        suffix = sentence[pronoun_span.end:].text_with_ws

        # spaCy spans include trailing spaces, but we need to know about
        # leading spaces for the GPT-2 BPE
        leading_space = (
            ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
        )
        trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''

        # get noun phrases, excluding pronouns and anything overlapping with the query
        cand_spans = wsc_utils.filter_noun_chunks(
            wsc_utils.extended_noun_chunks(sentence),
            exclude_pronouns=True,
            exclude_query=query,
            exact_match=False,
        )

        def binarize_with_mask(txt):
            toks = binarize(
                prefix + leading_space + txt + trailing_space + suffix,
                append_eos=True,
            )
            mask = torch.zeros_like(toks, dtype=torch.uint8)
            mask_start = len(binarize(prefix))
            mask_size = len(binarize(leading_space + txt))
            mask[mask_start:mask_start + mask_size] = 1
            return toks, mask

        if query is not None:
            query_toks, query_mask = binarize_with_mask(query)
            query_len = len(query_toks)
        else:
            query_toks, query_mask, query_len = None, None, 0

        query_tokens.append(query_toks)
        query_masks.append(query_mask)
        query_lengths.append(query_len)

        cand_toks, cand_masks = [], []
        for cand_span in cand_spans:
            toks, mask = binarize_with_mask(cand_span.text)
            cand_toks.append(toks)
            cand_masks.append(mask)

        # collate candidates
        cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
        cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
        assert cand_toks.size() == cand_masks.size()

        candidate_tokens.append(cand_toks)
        candidate_masks.append(cand_masks)
        candidate_lengths.append(cand_toks.size(1))

        labels.append(label)

    query_lengths = np.array(query_lengths)
    query_tokens = ListDataset(query_tokens, query_lengths)
    query_masks = ListDataset(query_masks, query_lengths)

    candidate_lengths = np.array(candidate_lengths)
    candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
    candidate_masks = ListDataset(candidate_masks, candidate_lengths)

    labels = ListDataset(labels, [1] * len(labels))

    dataset = {
        'id': IdDataset(),
        'query_tokens': query_tokens,
        'query_masks': query_masks,
        'candidate_tokens': candidate_tokens,
        'candidate_masks': candidate_masks,
        'labels': labels,
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(query_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[query_lengths],
    )

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(query_tokens))

    dataset = SortDataset(
        nested_dataset,
        # shuffle
        sort_order=[shuffle],
    )

    if return_only:
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(
    self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs
):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def getIns(bped, bpeTokens, tokens, L, R):
        resL = 0
        tkL = " ".join(tokens[:L])
        bped_tkL = self.bpe.encode(tkL)
        if bped.find(bped_tkL) == 0:
            resL = len(bped_tkL.split())
        else:
            tkL += " "
            bped_tkL = self.bpe.encode(tkL)
            if bped.find(bped_tkL) == 0:
                resL = len(bped_tkL.split())
        resR = 0
        tkR = " ".join(tokens[R:])
        bped_tkR = self.bpe.encode(tkR)
        if bped.rfind(bped_tkR) + len(bped_tkR) == len(bped):
            resR = len(bpeTokens) - len(bped_tkR.split())
        else:
            tkR = " " + tkR
            bped_tkR = self.bpe.encode(tkR)
            if bped.rfind(bped_tkR) + len(bped_tkR) == len(bped):
                resR = len(bpeTokens) - len(bped_tkR.split())
        return resL, resR

    def getExample(a, bias):
        s = " ".join(a["token"])
        ss = self.bpe.encode(s)
        sst = ss.split()
        headL = a['h']['pos'][0]
        headR = a['h']['pos'][1]
        hiL, hiR = getIns(ss, sst, a["token"], headL, headR)
        tailL = a['t']['pos'][0]
        tailR = a['t']['pos'][1]
        tiL, tiR = getIns(ss, sst, a["token"], tailL, tailR)
        E1b = '1'
        E1e = '2'
        E2b = '3'
        E2e = '4'
        ins = [(hiL, E1b), (hiR, E1e), (tiL, E2b), (tiR, E2e)]
        ins = sorted(ins)
        pE1 = 0
        pE2 = 0
        pE1_ = 0
        pE2_ = 0
        for i in range(0, 4):
            sst.insert(ins[i][0] + i, ins[i][1])
            if ins[i][1] == E1b:
                pE1 = ins[i][0] + i
            elif ins[i][1] == E2b:
                pE2 = ins[i][0] + i
            elif ins[i][1] == E1e:
                pE1_ = ins[i][0] + i
            else:
                pE2_ = ins[i][0] + i
        if pE1_ - pE1 == 1 or pE2_ - pE2 == 1:
            return "???", -1, -1
        else:
            return " ".join(sst), pE1 + bias, pE2 + bias

    def get_example_bert(item):
        if 'text' in item:
            sentence = item['text']
            is_token = False
        else:
            sentence = item['token']
            is_token = True
        pos_head = item['h']['pos']
        pos_tail = item['t']['pos']

        pos_min = pos_head
        pos_max = pos_tail
        if pos_head[0] > pos_tail[0]:
            pos_min = pos_tail
            pos_max = pos_head
            rev = True
        else:
            rev = False

        if not is_token:
            sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
            ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
            sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
            ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
            sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
        else:
            sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]]))
            ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]]))
            sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]]))
            ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]]))
            sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:]))

        ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]']
        ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]']

        re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]']
        pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1)
        pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0)
        # pos1 = min(self.max_length - 1, pos1)
        # pos2 = min(self.max_length - 1, pos2)

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
        avai_len = len(indexed_tokens)

        # Position
        # pos1 = torch.tensor([[pos1]]).long()
        # pos2 = torch.tensor([[pos2]]).long()

        # indexed_tokens = indexed_tokens[:self.max_length]
        indexed_tokens = torch.tensor(indexed_tokens).long()
        return indexed_tokens, pos1, pos2

    def binarize(s, append_bos=False):
        # if self.bpe is not None:
        #     s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=True,
            add_if_not_exist=False,
        ).long()
        if append_bos and self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    rel2id_path = os.path.join(self.args.data, "rel2id.json")
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))
    if not os.path.exists(rel2id_path):
        raise FileNotFoundError('Cannot find rel2id: {}'.format(rel2id_path))
    rel2id = json.load(open(rel2id_path, "r"))

    labels = []
    src_tokens = []
    src_lengths = []
    src_idx = []

    with open(data_path) as h:
        for line in h:
            example = json.loads(line.strip())
            if 'relation' in example:
                label = rel2id[example['relation']]
                labels.append(label)
            # bped = self.bpe.encode(" ".join(example["token"]))
            if getattr(self.args, 'bert', False):
                src_bin, pE1, pE2 = get_example_bert(example)
            else:
                bped, pE1, pE2 = getExample(example, 1)
                if pE1 == -1:
                    continue
                src_bin = binarize(bped, append_bos=True)
            src_tokens.append(src_bin)
            src_lengths.append(len(src_bin))
            # pE1 = 0
            # pE2 = 0
            src_idx.append([
                [pE1 for i in range(0, self.args.encoder_embed_dim)],
                [pE2 for i in range(0, self.args.encoder_embed_dim)],
            ])

    src_lengths = np.array(src_lengths)
    src_tokens = ListDataset(src_tokens, src_lengths)
    src_lengths = ListDataset(src_lengths)

    print("src_len", len(src_lengths))
    print("src_tokens", len(src_tokens))

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': src_lengths,
        },
        'index': RawLabelDataset(src_idx),
        'target': RawLabelDataset(labels),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    dataset = NestedDictionaryDataset(
        dataset,
        sizes=src_tokens.sizes,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        dataset = SortDataset(
            dataset,
            # shuffle
            sort_order=[np.random.permutation(len(dataset))],
        )

    print('| Loaded {} with {} samples'.format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def get_pad_dataset_fn(tokens, length, pad_idx):
    return PadDataset(
        ListDataset(tokens, length),
        pad_idx=pad_idx,
        left_pad=False,
    )
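A hedged usage sketch for this helper, assuming fairseq's `PadDataset`/`ListDataset` are importable from `fairseq.data`; the token values and `pad_idx=1` below are arbitrary placeholders.

# Hedged sketch: wrap two variable-length sequences and right-pad them into one batch.
import torch
from fairseq.data import ListDataset, PadDataset

tokens = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
lengths = [3, 2]
padded = get_pad_dataset_fn(tokens, lengths, pad_idx=1)
batch = padded.collater([padded[0], padded[1]])
# left_pad=False, so padding goes on the right: [[5, 6, 7], [8, 9, 1]]
print(batch)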
def load_dataset(
    self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    if data_path is None:
        data_path = os.path.join(self.args.data, split + ".jsonl")
    if not os.path.exists(data_path):
        raise FileNotFoundError("Cannot find data: {}".format(data_path))

    query_tokens = []
    query_masks = []
    query_lengths = []
    candidate_tokens = []
    candidate_masks = []
    candidate_lengths = []
    labels = []

    for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
        prefix = sentence[:pronoun_span.start].text
        suffix = sentence[pronoun_span.end:].text_with_ws

        # spaCy spans include trailing spaces, but we need to know about
        # leading spaces for the GPT-2 BPE
        leading_space = (
            " " if sentence[:pronoun_span.start].text_with_ws.endswith(" ") else ""
        )
        trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""

        # get noun phrases, excluding pronouns and anything overlapping with the query
        cand_spans = wsc_utils.filter_noun_chunks(
            wsc_utils.extended_noun_chunks(sentence),
            exclude_pronouns=True,
            exclude_query=query,
            exact_match=False,
        )

        if query is not None:
            query_toks, query_mask = self.binarize_with_mask(
                query, prefix, suffix, leading_space, trailing_space
            )
            query_len = len(query_toks)
        else:
            query_toks, query_mask, query_len = None, None, 0

        query_tokens.append(query_toks)
        query_masks.append(query_mask)
        query_lengths.append(query_len)

        cand_toks, cand_masks = [], []
        for cand_span in cand_spans:
            toks, mask = self.binarize_with_mask(
                cand_span.text,
                prefix,
                suffix,
                leading_space,
                trailing_space,
            )
            cand_toks.append(toks)
            cand_masks.append(mask)

        # collate candidates
        cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
        cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
        assert cand_toks.size() == cand_masks.size()

        candidate_tokens.append(cand_toks)
        candidate_masks.append(cand_masks)
        candidate_lengths.append(cand_toks.size(1))

        labels.append(label)

    query_lengths = np.array(query_lengths)
    query_tokens = ListDataset(query_tokens, query_lengths)
    query_masks = ListDataset(query_masks, query_lengths)

    candidate_lengths = np.array(candidate_lengths)
    candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
    candidate_masks = ListDataset(candidate_masks, candidate_lengths)

    labels = ListDataset(labels, [1] * len(labels))

    dataset = {
        "id": IdDataset(),
        "query_tokens": query_tokens,
        "query_masks": query_masks,
        "candidate_tokens": candidate_tokens,
        "candidate_masks": candidate_masks,
        "labels": labels,
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(query_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[query_lengths],
    )

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(query_tokens))

    dataset = SortDataset(
        nested_dataset,
        # shuffle
        sort_order=[shuffle],
    )

    if return_only:
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
def lang_pair_dataset(lengths: Sequence[int]) -> LanguagePairDataset:
    tokens = [[i] * l for i, l in enumerate(lengths)]
    return LanguagePairDataset(ListDataset(tokens), lengths, mock_dict())
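A hedged usage sketch for this test helper, assuming `mock_dict()` (defined elsewhere in the surrounding test file) returns a fairseq `Dictionary`; the lengths below are illustrative.

# Hedged sketch: three source-only sentences of lengths 2, 3, and 5.
ds = lang_pair_dataset([2, 3, 5])
assert len(ds) == 3
sample = ds[1]                       # dict with 'id', 'source', and a None 'target'
assert len(sample["source"]) == 3    # second sentence repeats the token 1 three times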
def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    print('Loading dataset')

    data_path = os.path.join(self.args.data)
    dataset_inst = data_utils.load_indexed_dataset(
        os.path.join(data_path, 'insts', split),
        self.instruction_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    dataset_state = data_utils.load_indexed_dataset(
        os.path.join(data_path, 'states', split),
        self.state_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )

    if dataset_inst is None or dataset_state is None:
        raise FileNotFoundError('Dataset not found: {}'.format(split))

    dataset_inst = SeqOfSeqDataset(dataset_inst, self.instruction_dictionary)
    dataset_state = SeqOfSeqDataset(dataset_state, self.state_dictionary)
    dataset_pos = IRPositionDataset(os.path.join(data_path, 'pos', split))
    dataset = IRDataset(dataset_inst, dataset_state, dataset_pos)

    block_size = self.args.function_length

    dataset = IRPadDataset(
        dataset,
        inst_pad_idx=self.instruction_dictionary.pad(),
        state_pad_idx=self.state_dictionary.pad(),
        inst_mask_idx=self.inst_mask_idx,
        state_mask_idx=self.state_mask_idx,
        inst_cls_idx=self.instruction_dictionary.bos(),
        state_cls_idx=self.state_dictionary.bos(),
        smallbert_insts_per_input=self.args.smallbert_insts_per_group,
        smallbert_states_per_input=self.args.smallbert_insts_per_group,
        max_length=block_size,
        inst_pad_length=32,
        state_pad_length=16,
        pair=True,
    )

    labels_str = list(
        map(json.loads, open(os.path.join(data_path, 'label', split + '.txt')))
    )
    labels = torch.tensor([
        x - 1 if isinstance(x, int) else int(x.strip()) - 1 for x in labels_str
    ])
    # function_indices = [torch.tensor(json.loads(x)) for x in open(os.path.join(data_path, 'funcs', split + '.txt'))]
    # dataset = IRMultiFunctionDataset(dataset, function_indices, self.args.max_functions_per_program)

    print('| loaded {} batches from: {} and {}'.format(
        len(dataset),
        os.path.join(data_path, 'insts', split),
        os.path.join(data_path, 'states', split),
    ))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.labels[split] = SortDataset(RawLabelDataset(labels), sort_order=[shuffle])

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src': dataset,
                },
                'target': RawLabelDataset(labels),
                'indices': RawLabelDataset(torch.arange(len(dataset))),
                'subset': ListDataset([split for _ in range(len(dataset))]),
            },
            sizes=[dataset.sizes],
        ),
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""
    ### create the BPE encoder object
    bpe_encoder = MultiprocessingEncoder(self.args.encoder_json, self.args.vocab_bpe)
    bpe_encoder.initializer()

    ### call the CoQA preprocessing
    examples, features = get_CoQA_features(
        self.args,
        bpe_encoder,
        self.args.init_token,
        self.args.separator_token,
        self.dictionary.pad(),
        split=split,
    )

    self.examples[split] = examples
    self.features[split] = features

    qas_idx = []
    src_tokens = []
    src_lengths = []
    padding_mask = []
    start_pos = []
    end_pos = []
    is_unk = []
    is_yes = []
    is_no = []
    number = []
    option = []

    for feature in features:
        src = torch.IntTensor(feature.input_tokens).long()
        p_mask = torch.IntTensor(feature.p_mask).long()
        src_tokens.append(src)
        src_lengths.append(len(src))
        padding_mask.append(p_mask)

        qas_idx.append(feature.qas_id)
        start_pos.append(feature.start_position)
        end_pos.append(feature.end_position)
        is_unk.append(feature.is_unk)
        is_yes.append(feature.is_yes)
        is_no.append(feature.is_no)
        number.append(feature.number)
        option.append(feature.option)

    src_tokens = ListDataset(src_tokens, src_lengths)
    src_lengths = ListDataset(src_lengths)

    dataset = {
        "id": IdDataset(),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
        "qas_id": RawLabelDataset(qas_idx),
        "net_input": {
            "src_tokens": RightPadDataset(src_tokens, pad_idx=self.dictionary.pad()),
            "src_lengths": src_lengths,
            "start_position": RawLabelDataset(start_pos),
            "p_mask": RightPadDataset(padding_mask, pad_idx=self.dictionary.pad()),
        },
        "start_position": RawLabelDataset(start_pos),
        "end_position": RawLabelDataset(end_pos),
        "is_unk": RawLabelDataset(is_unk),
        "is_yes": RawLabelDataset(is_yes),
        "is_no": RawLabelDataset(is_no),
        "number": RawLabelDataset(number),
        "option": RawLabelDataset(option),
    }

    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_tokens.sizes])],
    )

    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            sort_order=[np.random.permutation(len(dataset))],
        )

    print("| Loaded {} with {} samples".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(
    self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs
):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def binarize(s, append_bos=False):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=True,
            add_if_not_exist=False,
        ).long()
        if append_bos and self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    # self.data_path_table = {
    #     'train_input': os.path.join(self.args.data, 'Training Data', 'subtaskA_data_all.csv'),
    #     'train_answer': os.path.join(self.args.data, 'Training Data', 'subtaskA_answers_all.csv'),
    #     'valid_input': os.path.join(self.args.data, 'Trial Data', 'taskA_trial_data.csv'),
    #     'valid_answer': os.path.join(self.args.data, 'Trial Data', 'taskA_trial_answer.csv'),
    # }
    # self.data_path_table = {
    #     'train_input': os.path.join(self.args.data, 'trainval', 'subtaskA_data_all.csv'),
    #     'train_answer': os.path.join(self.args.data, 'trainval', 'subtaskA_answers_all.csv'),
    #     'valid_input': os.path.join(self.args.data, 'Dev Data', 'subtaskA_dev_data.csv'),
    #     'valid_answer': os.path.join(self.args.data, 'Dev Data', 'subtaskA_gold_answers.csv'),
    # }
    self.data_path_table = {
        'train_input': os.path.join(self.args.data, 'trainvaldev', 'subtaskA_data_all_plusplus.csv'),
        'train_answer': os.path.join(self.args.data, 'trainvaldev', 'subtaskA_answers_all.csv'),
        'valid_input': os.path.join(self.args.data, 'Dev Data', 'subtaskA_dev_data_plusplus.csv'),
        'valid_answer': os.path.join(self.args.data, 'Dev Data', 'subtaskA_gold_answers.csv'),
    }
    # self.data_path_table = {
    #     'train_input': os.path.join(self.args.data, 'subtaskA_data_all.csv'),
    #     'train_answer': os.path.join(self.args.data, 'subtaskA_answers_all.csv'),
    #     'valid_input': os.path.join(self.args.data, 'taskA_trial_data.csv'),
    #     'valid_answer': os.path.join(self.args.data, 'taskA_trial_answer.csv'),
    # }

    data_path_input = self.data_path_table[split + '_input']
    data_path_answer = self.data_path_table[split + '_answer']

    if not os.path.exists(data_path_input):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path_input))
    if not os.path.exists(data_path_answer):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path_answer))

    src_tokens = [[] for i in range(self.args.num_classes)]
    src_lengths = [[] for i in range(self.args.num_classes)]
    src_ids = []
    labels = []
    label_ids = []

    with open(data_path_input) as f:
        reader = csv.reader(f)
        for row in islice(reader, 1, None):
            src_ids.append(row[0])
            for i in range(self.args.num_classes):
                src = row[i + 1]
                evidence = row[i + 3]
                if src.isupper():
                    src = src.capitalize()
                src = src + ' Context: ' + evidence
                src_bin = binarize(src, append_bos=True)
                src_tokens[i].append(src_bin)
                src_lengths[i].append(len(src_bin))

    assert all(
        len(src_tokens[0]) == len(src_tokens[i])
        for i in range(self.args.num_classes)
    )
    assert len(src_tokens[0]) == len(src_lengths[0])

    with open(data_path_answer) as f:
        reader = csv.reader(f)
        for row in reader:
            label_ids.append(row[0])
            label = 1 - int(row[1])
            labels.append(label)

    assert len(labels) == 0 or len(labels) == len(src_tokens[0])
    assert all(src_ids[i] == label_ids[i] for i in range(len(src_ids)))

    for i in range(self.args.num_classes):
        src_lengths[i] = np.array(src_lengths[i])
        src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
        src_lengths[i] = ListDataset(src_lengths[i])

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for i in range(self.args.num_classes):
        dataset.update({
            'net_input{}'.format(i + 1): {
                'src_tokens': RightPadDataset(
                    src_tokens[i],
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': src_lengths[i],
            }
        })

    if len(labels) > 0:
        dataset.update({'target': RawLabelDataset(labels)})

    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            # shuffle
            sort_order=[np.random.permutation(len(dataset))],
        )

    print('| Loaded {} with {} samples'.format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]