def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)
    input2 = make_dataset('input2', self.source_dictionary)
    input3 = make_dataset('input3', self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    if input1 is None:
        src_tokens = input0
        assert False, 'this task expects input1, input2 and input3 alongside input0'
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
            input2 = PrependTokenDataset(input2, self.args.separator_token)
            input3 = PrependTokenDataset(input3, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1, input2, input3)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if self.args.add_prev_output_tokens:
        prev_tokens_dataset = RightPadDataset(
            RollDataset(src_tokens, 1),
            pad_idx=self.dictionary.pad(),
        )
        dataset['net_input'].update(
            prev_output_tokens=prev_tokens_dataset,
        )

    if not self.args.regression_target:
        label_dataset = make_dataset('label', self.label_dictionary)
        if label_dataset is not None:
            dataset.update(
                target=OffsetTokensDataset(
                    StripTokenDataset(
                        label_dataset,
                        id_to_strip=self.label_dictionary.eos(),
                    ),
                    offset=-self.label_dictionary.nspecial,
                )
            )
    else:
        label_path = "{0}.label".format(get_path('label', split))
        if os.path.exists(label_path):

            def parse_regression_target(i, line):
                values = line.split()
                assert len(values) == self.args.num_classes, \
                    f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                return [float(x) for x in values]

            dataset.update(
                target=RawLabelDataset([
                    parse_regression_target(i, line.strip())
                    for i, line in enumerate(open(label_path).readlines())
                ])
            )

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
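# Illustration only (not part of the task above): a minimal, self-contained sketch of
# what the classification-label pipeline does, assuming a hypothetical label-dictionary
# layout. StripTokenDataset drops the trailing eos id and OffsetTokensDataset shifts
# ids down by nspecial so targets land in [0, num_classes).
import torch

nspecial, eos = 4, 2                              # hypothetical dictionary layout
raw_label = torch.tensor([4, 2])                  # "<label_0> </s>" as dictionary ids
stripped = raw_label[raw_label != eos]            # StripTokenDataset(id_to_strip=eos)
target = stripped - nspecial                      # OffsetTokensDataset(offset=-nspecial)
assert target.tolist() == [0]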
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
        if self.args.mask_whole_words else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
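# Illustration only (not part of the task above): a rough, self-contained sketch of the
# BERT-style corruption that MaskTokensDataset.apply_mask performs. Roughly mask_prob of
# the positions are selected; of those, leave_unmasked_prob keep the original token,
# random_token_prob get a random token, and the rest become <mask>. The target holds the
# original token at selected positions and pad elsewhere. Ids and probabilities are toy values.
import numpy as np

rng = np.random.default_rng(0)
tokens = np.array([5, 6, 7, 8, 9, 10])
pad_idx, mask_idx, vocab_size = 1, 50, 100
mask_prob, leave_unmasked_prob, random_token_prob = 0.5, 0.1, 0.1

selected = rng.random(len(tokens)) < mask_prob
target = np.where(selected, tokens, pad_idx)

src = tokens.copy()
u = rng.random(len(tokens))
keep = selected & (u < leave_unmasked_prob)
rand = selected & (u >= leave_unmasked_prob) & (u < leave_unmasked_prob + random_token_prob)
src[selected & ~keep & ~rand] = mask_idx
src[rand] = rng.integers(0, vocab_size, size=int(rand.sum()))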
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    src_tokens = make_dataset('data', self.source_dictionary)
    assert src_tokens is not None, 'could not find dataset: {}'.format(
        get_path('data', split))
    src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    src_labels = make_dataset('label', self.target_dictionary)
    assert src_labels is not None, 'could not find dataset: {}'.format(
        get_path('label', split))
    src_labels = TruncateDataset(src_labels, self.args.max_positions)
    src_labels = OffsetTokensDataset(
        src_labels,
        offset=-self.target_dictionary.nspecial,
    )
    dataset.update(target=src_labels)

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    if input1 is None:
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if not self.args.regression_target:
        label_dataset = make_dataset('label', self.target_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.target_dictionary.eos(),
                ),
                offset=-self.target_dictionary.nspecial,
            ))
    else:
        label_path = f"{get_path('label', split)}.label"
        if os.path.exists(label_path):
            dataset.update(target=RawLabelDataset(
                [float(x.strip()) for x in open(label_path).readlines()]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print(f"| Loaded {split} with #samples: {len(dataset)}")

    self.datasets[split] = dataset
    return self.datasets[split]
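# Illustration only (not part of the task above): sketch of the sequence layout the
# PrependTokenDataset / ConcatSentencesDataset combination produces for a sentence pair,
# using hypothetical special ids (init_token=0 for <s>, separator_token=2 for </s>;
# each binarized input is assumed to already end with </s>).
import torch

init_token, separator_token = 0, 2
input0 = torch.tensor([11, 12, 2])           # "sentence one </s>"
input1 = torch.tensor([13, 14, 2])           # "sentence two </s>"
src = torch.cat([
    torch.tensor([init_token]), input0,       # PrependTokenDataset(input0, init_token)
    torch.tensor([separator_token]), input1,  # PrependTokenDataset(input1, separator_token)
])
# -> <s> sentence one </s> </s> sentence two </s>, i.e. a RoBERTa-style pair encoding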
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def binarize(s, append_bos=False):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=True,
            add_if_not_exist=False,
        ).long()
        if append_bos and self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    src_tokens = [[] for i in range(self.args.num_classes)]
    src_lengths = [[] for i in range(self.args.num_classes)]
    labels = []

    with open(data_path) as h:
        for line in h:
            example = json.loads(line.strip())
            if 'answerKey' in example:
                label = ord(example['answerKey']) - ord('A')
                labels.append(label)
            question = example['question']['stem']
            assert len(example['question']['choices']) == self.args.num_classes
            # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
            question = 'Q: ' + question
            question_toks = binarize(question, append_bos=True)
            for i, choice in enumerate(example['question']['choices']):
                src = 'A: ' + choice['text']
                src_bin = torch.cat([question_toks, binarize(src)])
                src_tokens[i].append(src_bin)
                src_lengths[i].append(len(src_bin))

    assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
    assert len(src_tokens[0]) == len(src_lengths[0])
    assert len(labels) == 0 or len(labels) == len(src_tokens[0])

    for i in range(self.args.num_classes):
        src_lengths[i] = np.array(src_lengths[i])
        src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
        src_lengths[i] = ListDataset(src_lengths[i])

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for i in range(self.args.num_classes):
        dataset.update({
            'net_input{}'.format(i + 1): {
                'src_tokens': RightPadDataset(
                    src_tokens[i],
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': src_lengths[i],
            }
        })

    if len(labels) > 0:
        dataset.update({'target': RawLabelDataset(labels)})

    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            # shuffle
            sort_order=[np.random.permutation(len(dataset))],
        )

    print('| Loaded {} with {} samples'.format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
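# Illustration only (not part of the task above): sketch of one JSON line the loop reads,
# limited to the fields the code actually accesses ('answerKey' may be absent for the test
# split); the content is invented for the example.
import json

line = json.dumps({
    "answerKey": "A",
    "question": {
        "stem": "Where would I not want a fox?",
        "choices": [{"text": "hen house"}, {"text": "england"}],
    },
})
example = json.loads(line)
label = ord(example["answerKey"]) - ord("A")   # -> 0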
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""
    inputs_path = Path(self.args.data) / "{split}".format(split=split)
    src_tokens = data_utils.load_indexed_dataset(
        str(inputs_path),
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert src_tokens is not None, "could not find dataset: {}".format(inputs_path)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    src_tokens = PrependTokenDataset(src_tokens, self.source_dictionary.bos())

    targets_path = Path(self.args.data) / "{}.nonterm".format(split)
    labelled_spans = data_utils.load_indexed_dataset(
        str(targets_path),
        self.label_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert labelled_spans is not None, "could not find labels: {}".format(targets_path)

    raise NotImplementedError  # everything below is unreachable until this is removed

    target_spans = LabelledSpanDataset(labelled_spans, return_spans=True)
    labels = LabelledSpanDataset(labelled_spans, return_spans=False)

    # all possible word spans in each sequence
    word_spans = WordSpanDataset(src_tokens, self.source_dictionary, self.is_word_initial)
    all_spans = ProductSpanDataset(word_spans)

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(src_tokens, pad_idx=self.source_dictionary.pad()),
            "nsrc_tokens": NumelDataset(src_tokens),
            "src_spans": RightPadDataset(all_spans, pad_idx=self.label_dictionary.pad()),
            "nsrc_spans": NumSpanDataset(all_spans),
        },
        "targets": RightPadDataset(labels, pad_idx=self.label_dictionary.pad()),
        "target_spans": RightPadDataset(target_spans, pad_idx=self.label_dictionary.pad()),
        "ntargets": NumelDataset(labels),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
        "nwords": NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial),
        "word_spans": RightPadDataset(word_spans, pad_idx=self.label_dictionary.pad()),
    }

    nested_dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes])

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(nested_dataset, sort_order=[shuffle])

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
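# Illustration only (not part of the task above): a rough sketch of the "all possible word
# spans" idea mentioned in the comment. Given which subword positions start a word, every
# (start, end) pair over word boundaries is a candidate span. How WordSpanDataset and
# ProductSpanDataset encode this exactly is not shown here; this is only the underlying idea.
is_word_initial = [1, 0, 1, 1, 0]            # subword-level flags for one sentence
starts = [i for i, w in enumerate(is_word_initial) if w]
bounds = starts + [len(is_word_initial)]
spans = [(bounds[i], bounds[j]) for i in range(len(starts))
         for j in range(i + 1, len(bounds))]
# -> [(0, 2), (0, 3), (0, 5), (2, 3), (2, 5), (3, 5)]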
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """

    def binarize(s, append_bos=False):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=True,
            add_if_not_exist=False,
        ).long()
        if append_bos and self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    # self.data_path_table = {
    #     'train_input': os.path.join(self.args.data, 'Training Data', 'subtaskB_data_all.csv'),
    #     'train_answer': os.path.join(self.args.data, 'Training Data', 'subtaskB_answers_all.csv'),
    #     'valid_input': os.path.join(self.args.data, 'Trial Data', 'taskB_trial_data.csv'),
    #     'valid_answer': os.path.join(self.args.data, 'Trial Data', 'taskB_trial_answer.csv'),
    # }
    self.data_path_table = {
        'train_input': os.path.join(self.args.data, 'trainval', 'subtaskB_data_all_new_plusplus.csv'),
        'train_answer': os.path.join(self.args.data, 'trainval', 'subtaskB_answers_all.csv'),
        'valid_input': os.path.join(self.args.data, 'Dev Data', 'subtaskB_dev_data_new_plusplus.csv'),
        'valid_answer': os.path.join(self.args.data, 'Dev Data', 'subtaskB_gold_answers.csv'),
    }

    data_path_input = self.data_path_table[split + '_input']
    data_path_answer = self.data_path_table[split + '_answer']

    if not os.path.exists(data_path_input):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path_input))
    if not os.path.exists(data_path_answer):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path_answer))

    src_tokens = [[] for i in range(self.args.num_classes)]
    src_lengths = [[] for i in range(self.args.num_classes)]
    src_ids = []
    labels = []
    label_ids = []

    with open(data_path_input) as f:
        reader = csv.reader(f)
        for row in islice(reader, 1, None):  # skip the header row
            src_ids.append(row[0])
            statement = row[1]
            trueSent = row[5]
            wiktionary = row[6]
            if statement.isupper():
                statement = statement.capitalize()
            # statement = 'The statement "' + statement + '" is absurd.'
            # statement = 'Reasonable statement: ' + trueSent + ' | The statement "' + statement + '" is absurd.'
            # statement = 'Context: ' + wiktionary + ' | The statement "' + statement + '" is absurd.'
            statement = ('Context: ' + wiktionary
                         + ' Reasonable statement: ' + trueSent
                         + ' | The statement "' + statement + '" is absurd.')

            statement_toks = binarize(statement, append_bos=True)
            for i in range(self.args.num_classes):
                src = row[i + 2]
                if src.isupper():
                    src = src.capitalize()
                src = 'Because ' + src
                # src = ' ' + src
                src_bin = torch.cat([statement_toks, binarize(src)])
                src_tokens[i].append(src_bin)
                src_lengths[i].append(len(src_bin))

    assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
    assert len(src_tokens[0]) == len(src_lengths[0])

    with open(data_path_answer) as f:
        reader = csv.reader(f)
        for row in reader:
            label_ids.append(row[0])
            label = ord(row[1]) - ord('A')
            labels.append(label)

    assert len(labels) == 0 or len(labels) == len(src_tokens[0])
    assert all(src_ids[i] == label_ids[i] for i in range(len(src_ids)))

    for i in range(self.args.num_classes):
        src_lengths[i] = np.array(src_lengths[i])
        src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
        src_lengths[i] = ListDataset(src_lengths[i])

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for i in range(self.args.num_classes):
        dataset.update({
            'net_input{}'.format(i + 1): {
                'src_tokens': RightPadDataset(
                    src_tokens[i],
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': src_lengths[i],
            }
        })

    if len(labels) > 0:
        dataset.update({'target': RawLabelDataset(labels)})

    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            # shuffle
            sort_order=[np.random.permutation(len(dataset))],
        )

    print('| Loaded {} with {} samples'.format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
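# Illustration only (not part of the task above): sketch of the columns the reader consumes
# from each input CSV row, assuming num_classes == 3; column meanings are inferred from the
# indexing in the loop and the example values are invented.
row = ["17", "He put an elephant into the fridge.",
       "Elephants are grey.", "An elephant is much bigger than a fridge.",
       "Fridges are cold.", "He put food into the fridge.",
       "fridge: an appliance for keeping food cold."]
example_id, statement = row[0], row[1]          # id, against-common-sense statement
choices = row[2:2 + 3]                          # candidate explanations (row[i + 2])
trueSent, wiktionary = row[5], row[6]           # reasonable statement, dictionary context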
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    query_tokens = []
    query_masks = []
    query_lengths = []
    candidate_tokens = []
    candidate_masks = []
    candidate_lengths = []

    itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test'))

    for sample in itr:
        sentence, pronoun_span, query, cand_text = sample
        prefix = sentence[:pronoun_span[0]].rstrip()
        suffix = sentence[pronoun_span[1]:]

        leading_space = ' ' if sentence[:pronoun_span[0]].endswith(' ') else ''
        trailing_space = ''

        if query is not None:
            query_toks, query_mask = self.binarize_with_mask(
                query, prefix, suffix, leading_space, trailing_space,
            )
            query_len = len(query_toks)
        else:
            query_toks, query_mask, query_len = None, None, 0

        query_tokens.append(query_toks)
        query_masks.append(query_mask)
        query_lengths.append(query_len)

        cand_toks, cand_mask = self.binarize_with_mask(
            cand_text, prefix, suffix, leading_space, trailing_space,
        )
        candidate_tokens.append(cand_toks)
        candidate_masks.append(cand_mask)
        candidate_lengths.append(cand_toks.size(0))

    query_lengths = np.array(query_lengths)

    def get_pad_dataset_fn(tokens, length, pad_idx):
        return PadDataset(
            ListDataset(tokens, length),
            pad_idx=pad_idx,
            left_pad=False,
        )

    query_tokens = get_pad_dataset_fn(query_tokens, query_lengths, self.vocab.pad())
    query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)

    candidate_lengths = np.array(candidate_lengths)
    candidate_tokens = get_pad_dataset_fn(candidate_tokens, candidate_lengths, self.vocab.pad())
    candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)

    dataset = {
        'id': IdDataset(),
        'query_tokens': query_tokens,
        'query_masks': query_masks,
        'candidate_tokens': candidate_tokens,
        'candidate_masks': candidate_masks,
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(query_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[query_lengths],
    )

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(query_tokens))

    dataset = SortDataset(
        nested_dataset,
        # shuffle
        sort_order=[shuffle],
    )

    if return_only:
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
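# Illustration only (not part of the task above): a self-contained sketch of what the
# PadDataset wrappers do at collation time, namely right-padding a ragged list of 1-D token
# tensors into a single [batch, max_len] tensor. The helper name is hypothetical.
import torch

def right_pad(seqs, pad_idx):
    max_len = max(s.size(0) for s in seqs)
    out = seqs[0].new_full((len(seqs), max_len), pad_idx)
    for i, s in enumerate(seqs):
        out[i, :s.size(0)] = s
    return out

batch = right_pad([torch.tensor([4, 5, 6]), torch.tensor([7, 8])], pad_idx=1)
# tensor([[4, 5, 6],
#         [7, 8, 1]])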
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    query_tokens = []
    query_masks = []
    query_lengths = []
    candidate_tokens = []
    candidate_masks = []
    candidate_lengths = []
    labels = []

    for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
        prefix = sentence[:pronoun_span.start].text
        suffix = sentence[pronoun_span.end:].text_with_ws

        # spaCy spans include trailing spaces, but we need to know about
        # leading spaces for the GPT-2 BPE
        leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
        trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''

        # get noun phrases, excluding pronouns and anything overlapping with the query
        cand_spans = wsc_utils.filter_noun_chunks(
            wsc_utils.extended_noun_chunks(sentence),
            exclude_pronouns=True,
            exclude_query=query,
            exact_match=False,
        )

        if query is not None:
            query_toks, query_mask = self.binarize_with_mask(
                query, prefix, suffix, leading_space, trailing_space)
            query_len = len(query_toks)
        else:
            query_toks, query_mask, query_len = None, None, 0

        query_tokens.append(query_toks)
        query_masks.append(query_mask)
        query_lengths.append(query_len)

        cand_toks, cand_masks = [], []
        for cand_span in cand_spans:
            toks, mask = self.binarize_with_mask(
                cand_span.text, prefix, suffix, leading_space, trailing_space,
            )
            cand_toks.append(toks)
            cand_masks.append(mask)

        # collate candidates
        cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
        cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
        assert cand_toks.size() == cand_masks.size()

        candidate_tokens.append(cand_toks)
        candidate_masks.append(cand_masks)
        candidate_lengths.append(cand_toks.size(1))

        labels.append(label)

    query_lengths = np.array(query_lengths)
    query_tokens = ListDataset(query_tokens, query_lengths)
    query_masks = ListDataset(query_masks, query_lengths)

    candidate_lengths = np.array(candidate_lengths)
    candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
    candidate_masks = ListDataset(candidate_masks, candidate_lengths)

    labels = ListDataset(labels, [1] * len(labels))

    dataset = {
        'id': IdDataset(),
        'query_tokens': query_tokens,
        'query_masks': query_masks,
        'candidate_tokens': candidate_tokens,
        'candidate_masks': candidate_masks,
        'labels': labels,
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(query_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[query_lengths],
    )

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(query_tokens))

    dataset = SortDataset(
        nested_dataset,
        # shuffle
        sort_order=[shuffle],
    )

    if return_only:
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        assert dataset is not None, "could not find dataset: {}".format(
            get_path(type, split))
        return dataset

    src_tokens = make_dataset("input0", self.source_dictionary)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    label_dataset = make_dataset("label", self.label_dictionary)

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            "src_lengths": NumelDataset(src_tokens, reduce=False),
        },
        "target": RightPadDataset(
            # use -1 as padding; it will be used to mask out padding when calculating loss
            ReplaceDataset(
                # replace eos and existing padding (used when some tokens should not be predicted) with -1
                OffsetTokensDataset(
                    # offset tokens to get the targets to the correct range (0, 1, 2, ...)
                    label_dataset,
                    offset=-self.label_dictionary.nspecial,
                ),
                replace_map={
                    self.label_dictionary.eos() - self.label_dictionary.nspecial: -1,
                    self.label_dictionary.pad() - self.label_dictionary.nspecial: -1,
                },
                offsets=np.zeros(len(label_dataset), dtype=int),
            ),
            pad_idx=-1,
        ),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
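# Illustration only (not part of the task above): sketch of the target transformation for one
# token-level label sequence, assuming a hypothetical label-dictionary layout. Ids are shifted
# down by nspecial, then the (former) eos and pad ids are mapped to -1 so the loss ignores
# them; -1 is also the padding value when batching.
import torch

nspecial, pad, eos = 4, 1, 2                  # hypothetical label-dictionary layout
labels = torch.tensor([4, 5, 4, 2])           # "<O> <B-LOC> <O> </s>" as ids
offset = labels - nspecial                    # OffsetTokensDataset
target = offset.clone()
for bad in (eos - nspecial, pad - nspecial):  # ReplaceDataset replace_map
    target[target == bad] = -1
# -> tensor([ 0,  1,  0, -1])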
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)  # + '.bpe'

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
        if self.args.mask_whole_words else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    # load counts, clipped at `thresh` and framed with 0 for the <s>/</s> positions
    thresh = 100
    with open(split_path + '.counts') as count_file:
        lines = [line.rstrip() for line in count_file]
        counts = [line.split(' ') for line in lines]
        for i, count in enumerate(counts):
            count = [int(el) for el in count]
            counts[i] = [el if el < thresh else thresh for el in count]
            counts[i] = torch.LongTensor(np.concatenate([[0], counts[i], [0]]))

    # load embeddings
    if not self.args.input_format == 'tokens':
        embs = torch.load(split_path + '.features')

    # zero out counts (and embeddings, if loaded) at masked positions
    for i, data in enumerate(src_dataset):
        counts[i] = counts[i] * (data != self.mask_idx)
        if not self.args.input_format == 'tokens':
            embs[i] = embs[i] * (data != self.mask_idx)[1:-1, None]

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_counts': PadDataset(
                        counts,
                        pad_idx=0,
                        left_pad=False,
                    ),
                    'src_embs': EmbeddingDataset(
                        embs,
                        pad_idx=0,
                        left_pad=False,
                    ) if not self.args.input_format == 'tokens' else None,
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
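# Illustration only (not part of the task above): sketch of the count handling for one example.
# Per-token counts are clipped at the threshold, framed with 0 for the prepended <s> and the
# trailing </s>, and zeroed wherever the input token was replaced by <mask> so the model cannot
# read counts at masked positions. Values are toy.
import numpy as np
import torch

thresh, mask_idx = 100, 50
raw_counts = [3, 250, 12]                                  # one count per original token
clipped = [c if c < thresh else thresh for c in raw_counts]
counts = torch.LongTensor(np.concatenate([[0], clipped, [0]]))
src = torch.tensor([0, 7, 50, 9, 2])                       # <s> tok <mask> tok </s>
counts = counts * (src != mask_idx)                        # -> tensor([0, 3, 0, 12, 0])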
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    input_options = [
        make_dataset(
            'input{idx}'.format(idx=idx + 1),
            self.source_dictionary
        )
        for idx in range(self.args.num_classes)
    ]

    if self.args.separator_token is not None:
        input0 = PrependTokenDataset(input0, self.args.separator_token)

    src_tokens = []
    for input_option in input_options:
        if self.args.init_token is not None:
            input_option = PrependTokenDataset(input_option, self.args.init_token)
        if self.args.max_option_length is not None:
            input_option = TruncateDataset(input_option, self.args.max_option_length)
        src_token = ConcatSentencesDataset(input_option, input0)
        if self.args.truncate_sequence:
            src_token = TruncateDataset(src_token, self.args.max_positions)
        src_tokens.append(src_token)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens[0]))

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for src_token_idx in range(len(src_tokens)):
        dataset.update(
            {
                'net_input{idx}'.format(idx=src_token_idx + 1): {
                    'src_tokens': RightPadDataset(
                        src_tokens[src_token_idx],
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False),
                }
            }
        )

    label_path = '{}.label'.format(get_path('label', split))
    if os.path.exists(label_path):
        with open(label_path) as h:
            dataset.update(
                target=RawLabelDataset([
                    int(x.strip()) for x in h.readlines()
                ])
            )

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
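# Illustration only (not part of the task above): sketch of the sizes computation used for
# sorting/batching in the ranking tasks. The example "size" is the element-wise maximum of the
# per-option lengths, so a batch is budgeted by its longest candidate. Lengths are toy values.
import numpy as np

option_sizes = [
    np.array([12, 7, 30]),   # lengths of option 1 for examples 0..2
    np.array([10, 9, 28]),   # lengths of option 2 for examples 0..2
]
sizes = np.maximum.reduce(option_sizes)   # -> array([12,  9, 30])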