def lang_dataset(lang):
    """Build the dataset for a single language.

    This is a closure: ``self``, ``split``, ``get_path`` and ``make_dataset``
    come from the enclosing scope, which is not shown here.

    Args:
        lang: language identifier forwarded to ``make_dataset``/``get_path``.

    Returns:
        A ``NestedDictionaryDataset`` over the inputs (and targets, if a label
        source exists), wrapped in a seeded ``SortDataset`` unless
        ``self.args.no_shuffle`` is set.
    """
    input0 = make_dataset('input0', lang, self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', lang, split))
    input1 = make_dataset('input1', lang, self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    # A single-sentence task has no input1; otherwise the two inputs are
    # concatenated, optionally with a separator token in between.
    if input1 is None:
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if not self.args.regression_target:
        label_dataset = make_dataset('label', lang, self.target_dictionary)
        if label_dataset is not None:
            # Strip the trailing EOS and shift ids down by the number of
            # special symbols (presumably so real labels start at 0).
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.target_dictionary.eos(),
                ),
                offset=-self.target_dictionary.nspecial,
            ))
    else:
        label_path = "{0}.label".format(get_path('label', lang, split))
        if os.path.exists(label_path):
            # One float per line; close the file instead of leaking the handle.
            with open(label_path) as label_file:
                dataset.update(target=RawLabelDataset([
                    float(x.strip()) for x in label_file.readlines()
                ]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print("| Loaded {0} with #samples: {1}".format(
        split, len(dataset)))
    return dataset
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test).

    Reads the ``input0`` (required), ``input1`` (optional) and ``label``
    datasets from ``<args.data>/<key>/<split>``, builds model inputs and
    targets, and registers the result in ``self.datasets[split]``.

    Args:
        split (str): name of the split (e.g., train, valid, test).
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Returns:
        The dataset stored under ``self.datasets[split]``.
    """
    def get_path(key, split):
        return os.path.join(self.args.data, key, split)

    def make_dataset(key, dictionary):
        split_path = get_path(key, split)
        # Honour the caller's dictionary: labels must be read with the
        # target dictionary, not the source dictionary.
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    # Report the actual path of the missing input0 dataset.
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    if input1 is None:
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if self.args.add_prev_output_tokens:
        # Decoder-style input: the source rolled right by one position.
        prev_tokens_dataset = RightPadDataset(
            RollDataset(src_tokens, 1),
            pad_idx=self.dictionary.pad(),
        )
        dataset['net_input'].update(
            prev_output_tokens=prev_tokens_dataset,
        )

    if not self.args.regression_target:
        label_dataset = make_dataset('label', self.target_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.target_dictionary.eos(),
                ),
                offset=-self.target_dictionary.nspecial,
            ))
    else:
        label_path = "{0}.label".format(get_path('label', split))
        if os.path.exists(label_path):
            # One float per line; the file is closed when done.
            with open(label_path) as label_file:
                dataset.update(target=RawLabelDataset(
                    [float(x.strip()) for x in label_file.readlines()]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Builds a masked-LM dataset from ``<data_path>/<split>`` plus the
    ``<split>.counts`` side file and, unless ``args.input_format`` is
    ``'tokens'``, the ``<split>.features`` embedding file. Counts (and
    embeddings) are zeroed at masked positions. The result is stored in
    ``self.datasets[split]``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; selects the data shard and is added to
            ``args.seed`` for the shuffle.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the indexed dataset for ``split`` is missing.
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
        if self.args.mask_whole_words else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    use_embs = self.args.input_format != 'tokens'

    # load counts: one line of space-separated integers per example (indexed
    # in parallel with src_dataset below), clipped at ``count_threshold`` and
    # padded with a single 0 on each side.
    count_threshold = 100
    with open(split_path + '.counts') as count_file:
        counts = [
            torch.LongTensor(np.concatenate([
                [0],
                [min(int(el), count_threshold) for el in line.rstrip().split(' ')],
                [0],
            ]))
            for line in count_file
        ]

    # load embeddings (only when the input format is not plain tokens).
    # NOTE: torch.load unpickles its input; only load trusted feature files.
    embs = torch.load(split_path + '.features') if use_embs else None

    # mask counts and embeddings at the positions MaskTokensDataset masked.
    # Indexing explicitly avoids relying on the __getitem__ iteration
    # fallback of a map-style dataset.
    for i in range(len(src_dataset)):
        not_masked = src_dataset[i] != self.mask_idx
        counts[i] = counts[i] * not_masked
        if embs is not None:
            # the [1:-1] slice implies embedding rows skip the first and last
            # token positions — TODO confirm against the .features producer
            embs[i] = embs[i] * not_masked[1:-1, None]

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_counts': PadDataset(
                        counts,
                        pad_idx=0,
                        left_pad=False,
                    ),
                    # None when no embeddings are loaded; presumably skipped
                    # by NestedDictionaryDataset — confirm.
                    'src_embs': EmbeddingDataset(
                        embs,
                        pad_idx=0,
                        left_pad=False,
                    ) if use_embs else None,
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    For each field in ``configs.fields`` an indexed dataset is loaded from
    ``<data_path>/<field>/<split>``, shortened, chunked into token blocks and
    prefixed with ``<s>``. Then, depending on the field:

    - ``configs.static_field`` is masked with ``MaskTokensDataset`` and
      yields token targets.
    - Fields in ``configs.byte_fields`` are masked with ``MaskValuesDataset``
      and yield value targets.
    - Every other field is used unmasked as input.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; selects the data shard when
            ``args.data`` lists several paths.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the dataset of any field is missing.
        ValueError: if ``configs.static_field`` is not among
            ``configs.fields`` (its masked dataset defines sizes/lengths).
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    # Rotate through shards like the other masked-LM loaders. Assumes
    # utils.split_paths returns [args.data] for a single path — confirm.
    data_path = paths[(epoch - 1) % len(paths)]

    src_tokens = {}
    tgt_tokens = {}
    tgt_values = {}
    src_dataset_code = None

    for field in configs.fields:
        dictionary = self.source_dictionary[field]
        split_path = os.path.join(data_path, field, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.tokens_per_sample,
            self.args.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=dictionary.pad(),
            eos=dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, dictionary.bos())

        if field == configs.static_field:
            src_dataset_code, tgt_dataset_code = MaskTokensDataset.apply_mask(
                dataset,
                dictionary,
                pad_idx=dictionary.pad(),
                mask_idx=self.mask_idx_dict[field],
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
            )
            src_tokens[field] = RightPadDataset(
                src_dataset_code, pad_idx=dictionary.pad()
            )
            tgt_tokens[field] = RightPadDataset(
                tgt_dataset_code, pad_idx=dictionary.pad()
            )
        elif field in configs.byte_fields:
            src_dataset_value, tgt_dataset_value = MaskValuesDataset.apply_mask(
                dataset,
                dictionary,
                pad_idx=dictionary.pad(),
                mask_idx=self.mask_idx_dict[field],
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
            )
            src_tokens[field] = RightPadDataset(
                src_dataset_value, pad_idx=dictionary.pad()
            )
            # dummy tokens are treated as 1
            # TODO: assert there should not be any dummy tokens here
            tgt_values[field] = BytevalueDataset(tgt_dataset_value, dictionary)
        else:
            src_tokens[field] = RightPadDataset(
                dataset, pad_idx=dictionary.pad()
            )

    # The static field's masked dataset drives shuffling, lengths and sizes.
    if src_dataset_code is None:
        raise ValueError(
            "static field {!r} not found in configs.fields".format(
                configs.static_field)
        )

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_dataset_code))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": NumelDataset(src_dataset_code, reduce=False),
                },
                "target": {
                    "tgt_tokens": tgt_tokens,
                    "tgt_values": tgt_values
                },
                "nsentences": NumSamplesDataset(),
                "ntokens": NumelDataset(src_dataset_code, reduce=True),
            },
            sizes=[src_dataset_code.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset_code.sizes,
        ],
    )
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a WSC-style JSONL split.

    For every example, the query (if any) is binarized together with its
    mask, and every candidate noun phrase is binarized and collated into a
    single padded tensor (with a matching mask tensor).

    Args:
        split (str): name of the split (e.g., train, valid, test)
        data_path (str, optional): explicit JSONL file; defaults to
            ``<args.data>/<split>.jsonl``.
        return_only (bool): when True, return the dataset without storing it
            in ``self.datasets``.

    Raises:
        FileNotFoundError: if the JSONL file does not exist.
    """
    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    q_tok_rows, q_mask_rows, q_lens = [], [], []
    c_tok_rows, c_mask_rows, c_lens = [], [], []
    gold = []

    for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
        before_pronoun = sentence[:pronoun_span.start]
        prefix = before_pronoun.text
        suffix = sentence[pronoun_span.end:].text_with_ws

        # spaCy spans include trailing spaces, but we need to know about
        # leading spaces for the GPT-2 BPE
        leading_space = ' ' if before_pronoun.text_with_ws.endswith(' ') else ''
        trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''

        # get noun phrases, excluding pronouns and anything overlapping with the query
        cand_spans = wsc_utils.filter_noun_chunks(
            wsc_utils.extended_noun_chunks(sentence),
            exclude_pronouns=True,
            exclude_query=query,
            exact_match=False,
        )

        if query is None:
            q_toks, q_mask, q_len = None, None, 0
        else:
            q_toks, q_mask = self.binarize_with_mask(
                query, prefix, suffix, leading_space, trailing_space)
            q_len = len(q_toks)
        q_tok_rows.append(q_toks)
        q_mask_rows.append(q_mask)
        q_lens.append(q_len)

        binarized = [
            self.binarize_with_mask(
                span.text, prefix, suffix, leading_space, trailing_space,
            )
            for span in cand_spans
        ]

        # collate candidates
        cand_toks = data_utils.collate_tokens(
            [toks for toks, _ in binarized], pad_idx=self.vocab.pad())
        cand_masks = data_utils.collate_tokens(
            [mask for _, mask in binarized], pad_idx=0)
        assert cand_toks.size() == cand_masks.size()

        c_tok_rows.append(cand_toks)
        c_mask_rows.append(cand_masks)
        c_lens.append(cand_toks.size(1))
        gold.append(label)

    q_lens = np.array(q_lens)
    query_tokens = ListDataset(q_tok_rows, q_lens)
    query_masks = ListDataset(q_mask_rows, q_lens)

    c_lens = np.array(c_lens)
    candidate_tokens = ListDataset(c_tok_rows, c_lens)
    candidate_masks = ListDataset(c_mask_rows, c_lens)

    labels = ListDataset(gold, [1] * len(gold))

    nested_dataset = NestedDictionaryDataset(
        {
            'id': IdDataset(),
            'query_tokens': query_tokens,
            'query_masks': query_masks,
            'candidate_tokens': candidate_tokens,
            'candidate_masks': candidate_masks,
            'labels': labels,
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(query_tokens, reduce=True),
        },
        sizes=[q_lens],
    )

    with data_utils.numpy_seed(self.args.seed):
        order = np.random.permutation(len(query_tokens))

    # shuffle
    dataset = SortDataset(nested_dataset, sort_order=[order])

    if not return_only:
        self.datasets[split] = dataset
    return dataset
def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
    """Load a given dataset split.

    Loads instruction and state indexed datasets plus position data, groups
    and pads them with ``IRPadDataset``, and pairs them with integer labels
    read from ``<data>/label/<split>.txt`` (one JSON value per line).

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): added to ``args.seed`` to seed the shuffle.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
        data_selector: accepted for interface compatibility; unused.

    Raises:
        FileNotFoundError: if the instruction or state dataset is missing.
    """
    print('Loading dataset')

    data_path = os.path.join(self.args.data)
    inst_path = os.path.join(data_path, 'insts', split)
    state_path = os.path.join(data_path, 'states', split)

    dataset_inst = data_utils.load_indexed_dataset(
        inst_path,
        self.instruction_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    dataset_state = data_utils.load_indexed_dataset(
        state_path,
        self.state_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset_inst is None or dataset_state is None:
        raise FileNotFoundError('Dataset not found: {}'.format(split))

    dataset_inst = SeqOfSeqDataset(dataset_inst, self.instruction_dictionary)
    dataset_state = SeqOfSeqDataset(dataset_state, self.state_dictionary)
    dataset_pos = IRPositionDataset(os.path.join(data_path, 'pos', split))
    dataset = IRDataset(dataset_inst, dataset_state, dataset_pos)

    block_size = self.args.function_length

    dataset = IRPadDataset(
        dataset,
        inst_pad_idx=self.instruction_dictionary.pad(),
        state_pad_idx=self.state_dictionary.pad(),
        inst_mask_idx=self.inst_mask_idx,
        state_mask_idx=self.state_mask_idx,
        inst_cls_idx=self.instruction_dictionary.bos(),
        state_cls_idx=self.state_dictionary.bos(),
        smallbert_insts_per_input=self.args.smallbert_insts_per_group,
        # NOTE(review): states-per-input reuses smallbert_insts_per_group;
        # confirm this is intended rather than a states-specific option.
        smallbert_states_per_input=self.args.smallbert_insts_per_group,
        max_length=block_size,
        inst_pad_length=32,
        state_pad_length=16,
        pair=True,
    )

    # Read labels with the file closed afterwards. Values may be JSON ints or
    # JSON strings; the "- 1" shift suggests they are stored 1-based.
    label_path = os.path.join(data_path, 'label', split + '.txt')
    with open(label_path) as label_file:
        labels_str = [json.loads(line) for line in label_file]
    labels = torch.tensor([
        x - 1 if isinstance(x, int) else int(x.strip()) - 1
        for x in labels_str
    ])

    print('| loaded {} batches from: {} and {}'.format(len(dataset), inst_path, state_path))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    # NOTE(review): self.labels uses sort_order=[shuffle] while self.datasets
    # below uses [shuffle, dataset.sizes]; confirm consumers do not rely on
    # the two sharing an order.
    self.labels[split] = SortDataset(RawLabelDataset(labels), sort_order=[shuffle])
    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src': dataset,
                },
                'target': RawLabelDataset(labels),
                'indices': RawLabelDataset(torch.arange(len(dataset))),
                'subset': ListDataset([split] * len(dataset)),
            },
            sizes=[dataset.sizes],
        ),
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a masked-LM split and register it in ``self.datasets[split]``.

    The shard is chosen from the colon-separated ``args.data`` list by
    ``epoch`` (0-based). Tokens are packed into blocks, prefixed with
    ``<s>``, masked, and shuffled with ``args.seed + epoch``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 0-based epoch used for shard selection and shuffling.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if no indexed dataset exists for ``split``.
    """
    shards = self.args.data.split(':')
    assert len(shards) > 0
    split_path = os.path.join(shards[epoch % len(shards)], split)

    vocab = self.source_dictionary

    raw = data_utils.load_indexed_dataset(
        split_path, vocab, self.args.dataset_impl, combine=combine,
    )
    if raw is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    # pack into contiguous blocks, leaving one slot for the <s> prepended below
    blocks = TokenBlockDataset(
        raw,
        raw.sizes,
        self.args.tokens_per_sample - 1,
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode=self.args.sample_break_mode,
    )
    print('| loaded {} blocks from: {}'.format(len(blocks), split_path))

    # <s> plays the role of BERT's [CLS]
    blocks = PrependTokenDataset(blocks, vocab.bos())

    if self.args.mask_whole_words:
        mask_whole_words = get_whole_word_mask(self.args, vocab)
    else:
        mask_whole_words = None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        blocks,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        order = np.random.permutation(len(src_dataset))

    net_input = {
        'src_tokens': PadDataset(src_dataset, pad_idx=vocab.pad(), left_pad=False),
        'src_lengths': NumelDataset(src_dataset, reduce=False),
    }
    fields = {
        'id': IdDataset(),
        'net_input': net_input,
        'target': PadDataset(tgt_dataset, pad_idx=vocab.pad(), left_pad=False),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_dataset, reduce=True),
    }
    nested = NestedDictionaryDataset(fields, sizes=[src_dataset.sizes])

    self.datasets[split] = SortDataset(
        nested, sort_order=[order, src_dataset.sizes],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Reads a source dataset and a tag dataset from
    ``<data_path>/<split>.<src>-<tgt>.<lang>``, strips their EOS tokens and
    registers a ``NestedDictionaryDataset`` in ``self.datasets[split]``.
    Non-training splits use only the first shard.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; selects the shard for the training split.
        combine (bool): accepted for interface compatibility; unused.

    Raises:
        FileNotFoundError: if the source or tag dataset is missing.
    """
    paths = utils.split_paths(self.cfg.data)
    assert len(paths) > 0
    if split != getattr(self.cfg, "train_subset", None):
        # if not training data set, use the first shard for valid and test
        paths = paths[:1]
    data_path = paths[(epoch - 1) % len(paths)]

    # infer langcode
    src, tgt = self.cfg.source_lang, self.cfg.target_lang

    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
    src_dataset = data_utils.load_indexed_dataset(prefix + src, self.src_dict,
                                                  self.cfg.dataset_impl)
    tag_dataset = data_utils.load_indexed_dataset(prefix + tgt, self.tag_dict,
                                                  self.cfg.dataset_impl)
    # Fail with a clear message instead of inside StripTokenDataset.
    for loaded, path in ((src_dataset, prefix + src), (tag_dataset, prefix + tgt)):
        if loaded is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, path))

    src_dataset = StripTokenDataset(
        src_dataset, id_to_strip=self.source_dictionary.eos())
    tag_dataset = StripTokenDataset(tag_dataset,
                                    id_to_strip=self.tag_dictionary.eos())

    # Tag ids are shifted so the first non-special tag lands at
    # tag_offset = tag_pad + 1, keeping tag_pad free for padding.
    tag_pad = self.source_dictionary.pad()
    tag_offset = tag_pad + 1
    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens':
                RightPadDataset(src_dataset,
                                pad_idx=self.source_dictionary.pad()),
            'src_lengths': NumelDataset(src_dataset, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_dataset, reduce=True),
        'target': RightPadDataset(
            OffsetTokensDataset(tag_dataset,
                                offset=-self.tag_dictionary.nspecial +
                                tag_offset),
            pad_idx=tag_pad,
        ),
    }
    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_dataset.sizes],
    )

    # Log the first example (undoing the tag shift) as a sanity check; skip
    # it for an empty split instead of raising IndexError.
    if len(dataset) > 0:
        logger.info(
            str([self.src_dict[k] for k in dataset[0]['net_input.src_tokens']]))
        logger.info(
            str([
                self.tag_dict[k + self.tag_dictionary.nspecial - tag_offset]
                for k in dataset[0]['target']
            ]))
    self.datasets[split] = dataset
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test).

    Builds ``<init_token> question <separator_token> context`` inputs from
    ``input0`` (question) and ``input1`` (context). When
    ``<data>/label/<split>.label`` exists, each line holds a
    ``start end`` span, which is remapped to positions in the concatenated
    input.

    Args:
        split (str): name of the split.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Returns:
        The dataset stored under ``self.datasets[split]``.
    """
    def get_path(key, split):
        return os.path.join(self.args.data, key, split)

    def make_dataset(key, dictionary):
        split_path = get_path(key, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    # inputs are loaded similarly to sentence_prediction
    input0 = make_dataset("input0", self.source_dictionary)  # question
    assert input0 is not None, "could not find dataset: {}".format(
        get_path("input0", split))
    input1 = make_dataset("input1", self.source_dictionary)  # context
    assert input1 is not None, "could not find dataset: {}".format(
        get_path("input1", split))

    # src_tokens: <init_token> input0 <separator_token> input1 <eos_token>
    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)
    if self.args.separator_token is not None:
        input1 = PrependTokenDataset(input1, self.args.separator_token)
    if self.args.max_context_length is not None:
        # Truncate the context by the same option that gates this branch.
        input1 = TruncateDataset(input1, self.args.max_context_length)
    src_tokens = ConcatSentencesDataset(input0, input1)

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            "src_lengths": NumelDataset(src_tokens, reduce=False),
            "input0_lengths": NumelDataset(
                input0, reduce=False
            ),  # question length (init_token possibly included)
        },
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
    }

    # labels (spans) are loaded similarly to sentence_ranking
    label_path = "{}.label".format(get_path("label", split))

    def _process_label(positions, input0_length, truncate_sequence,
                       max_positions):
        """Process a span [start:end] to the input range.

        After processing, tokens can be accessed by tokens[start:end+1].
        TODO: change inputs to reflect this change in the first place.
        """
        # Shift context-relative positions past the question (and the
        # separator token, if one was prepended to the context).
        start, end = [
            pos + input0_length + (self.args.separator_token is not None)
            for pos in positions
        ]
        end -= 1  # [0, 511]
        if truncate_sequence:
            if start >= max_positions:
                start, end = max_positions - 1, max_positions - 1  # not predictable
            elif end >= max_positions:
                end = max_positions - 1
        return start, end

    if os.path.exists(label_path):
        input0_lengths = dataset["net_input"]["input0_lengths"]
        with open(label_path) as h:
            dataset.update(target=RawLabelDataset([
                _process_label(
                    # (start_position, end_position)
                    tuple(int(pos) for pos in x.split()),
                    input0_lengths[i],
                    self.args.truncate_sequence,
                    self.max_positions(),
                )
                for i, x in enumerate(h.readlines())
            ]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(
        split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a given dataset split.

    Looks up ``'<split>_input'`` and ``'<split>_answer'`` in
    ``self.data_path_table``, which is (re)assigned here.
    - The input CSV holds an id, ``args.num_classes`` statements and their
      evidence.
    - The answer CSV holds ``id,label`` rows.
    Each choice becomes its own ``net_input{k}`` entry.

    Args:
        split (str): name of the split; only ``'train'`` and ``'valid'`` have
            table entries.
        epoch, combine, data_path, return_only: accepted for interface
            compatibility; unused here.

    Raises:
        FileNotFoundError: if either CSV file is missing.
    """
    def binarize(s, append_bos=False):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=True,
            add_if_not_exist=False,
        ).long()
        if append_bos and self.args.init_token is not None:
            tokens = torch.cat(
                [tokens.new([self.args.init_token]), tokens])
        return tokens

    self.data_path_table = {
        'train_input': os.path.join(self.args.data, 'trainvaldev', 'subtaskA_data_all_plusplus.csv'),
        'train_answer': os.path.join(self.args.data, 'trainvaldev', 'subtaskA_answers_all.csv'),
        'valid_input': os.path.join(self.args.data, 'Dev Data', 'subtaskA_dev_data_plusplus.csv'),
        'valid_answer': os.path.join(self.args.data, 'Dev Data', 'subtaskA_gold_answers.csv'),
    }

    data_path_input = self.data_path_table[split + '_input']
    data_path_answer = self.data_path_table[split + '_answer']

    if not os.path.exists(data_path_input):
        raise FileNotFoundError(
            'Cannot find data: {}'.format(data_path_input))
    if not os.path.exists(data_path_answer):
        raise FileNotFoundError(
            'Cannot find data: {}'.format(data_path_answer))

    src_tokens = [[] for i in range(self.args.num_classes)]
    src_lengths = [[] for i in range(self.args.num_classes)]
    src_ids = []
    labels = []
    label_ids = []

    # newline='' lets the csv module handle line endings in quoted fields.
    with open(data_path_input, newline='') as f:
        reader = csv.reader(f)
        for row in islice(reader, 1, None):  # skip the header row
            src_ids.append(row[0])
            for i in range(self.args.num_classes):
                src = row[i + 1]
                # NOTE(review): evidence at row[i + 3] lines up with the
                # statements only when num_classes == 2 — confirm.
                evidence = row[i + 3]
                if src.isupper():
                    src = src.capitalize()
                src = src + ' Context: ' + evidence
                src_bin = binarize(src, append_bos=True)
                src_tokens[i].append(src_bin)
                src_lengths[i].append(len(src_bin))
        assert all(
            len(src_tokens[0]) == len(src_tokens[i])
            for i in range(self.args.num_classes))
        assert len(src_tokens[0]) == len(src_lengths[0])

    with open(data_path_answer, newline='') as f:
        reader = csv.reader(f)
        for row in reader:
            label_ids.append(row[0])
            # stored label is inverted (1 - x)
            label = 1 - int(row[1])
            labels.append(label)

    assert len(labels) == 0 or len(labels) == len(src_tokens[0])
    # Ids must align row for row; skip the check when there are no answers
    # (previously this indexed label_ids and raised IndexError).
    assert len(label_ids) == 0 or src_ids == label_ids

    for i in range(self.args.num_classes):
        src_lengths[i] = np.array(src_lengths[i])
        src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
        src_lengths[i] = ListDataset(src_lengths[i])

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for i in range(self.args.num_classes):
        dataset.update({
            'net_input{}'.format(i + 1): {
                'src_tokens': RightPadDataset(
                    src_tokens[i],
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': src_lengths[i],
            }
        })

    if len(labels) > 0:
        dataset.update({'target': RawLabelDataset(labels)})

    # Each sample's size is the longest of its num_classes candidates.
    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[
            np.maximum.reduce(
                [src_token.sizes for src_token in src_tokens])
        ],
    )

    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            # shuffle
            sort_order=[np.random.permutation(len(dataset))],
        )

    print('| Loaded {} with {} samples'.format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a masked-LM split and register it in ``self.datasets[split]``.

    Tokens come from ``self._load_dataset_split``. They are masked with
    ``MaskTokensDataset`` (configured from ``self.cfg``) and shuffled with
    ``cfg.seed``; the padded targets are optionally also exposed to the
    model as ``net_input.target_tokens``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): forwarded to ``_load_dataset_split``.
        combine (bool): forwarded to ``_load_dataset_split``.
    """
    base = self._load_dataset_split(split, epoch, combine)
    cfg = self.cfg
    vocab = self.source_dictionary

    mask_whole_words = None
    if cfg.mask_whole_words:
        mask_whole_words = get_whole_word_mask(self.args, vocab)

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        base,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=cfg.seed,
        mask_prob=cfg.mask_prob,
        leave_unmasked_prob=cfg.leave_unmasked_prob,
        random_token_prob=cfg.random_token_prob,
        freq_weighted_replacement=cfg.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
        mask_multiple_length=cfg.mask_multiple_length,
        mask_stdev=cfg.mask_stdev,
    )

    with data_utils.numpy_seed(cfg.seed):
        order = np.random.permutation(len(src_dataset))

    padded_target = RightPadDataset(tgt_dataset, pad_idx=vocab.pad())

    net_input = {
        "src_tokens": RightPadDataset(src_dataset, pad_idx=vocab.pad()),
        "src_lengths": NumelDataset(src_dataset, reduce=False),
    }
    if cfg.include_target_tokens:
        net_input["target_tokens"] = padded_target

    nested = NestedDictionaryDataset(
        {
            "id": IdDataset(),
            "net_input": net_input,
            "target": padded_target,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_dataset, reduce=True),
        },
        sizes=[src_dataset.sizes],
    )

    self.datasets[split] = SortDataset(
        nested, sort_order=[order, src_dataset.sizes],
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test).

    Inputs are read from the dataset names given by ``args.input0``
    (source), ``args.input1`` (synthetic target) and optionally
    ``args.input2`` (reference).
    - The reference may be concatenated into ``src_tokens`` (not for
      ``'valid'``).
    - With ``args.add_tran_loss``, the reference is also masked to build
      ``parallel_src_tokens`` and a ``parallel_target``.

    Args:
        split (str): name of the split.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Returns:
        The dataset stored under ``self.datasets[split]``.
    """
    def get_path(key, split):
        return os.path.join(self.args.data, key, split)

    def make_dataset(key, dictionary):
        split_path = get_path(key, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    # input0 is source, input1 is synthetic target, input2 is reference
    input0 = make_dataset(self.args.input0, self.source_dictionary)
    # Report the actual path of the missing source dataset.
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path(self.args.input0, split))
    input1 = make_dataset(self.args.input1, self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    has_ref = self.args.input2 is not None
    if has_ref:
        input2 = make_dataset(self.args.input2, self.source_dictionary)

    # The reference is mixed into src_tokens only when enabled and never for
    # the 'valid' split.
    use_ref_in_src = has_ref and self.add_ref_prob > 0 and split != 'valid'
    if use_ref_in_src:
        input3 = PrependTokenDataset(input2, self.args.separator_token)
    else:
        input3 = None

    if input1 is None:
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        if use_ref_in_src:
            src_tokens = ConcatSentencesDataset(
                input0, input3, input1,
                add_ref_prob=self.add_ref_prob,
                drop_ref_rate=self.args.dropout_ref,
                pad_idx=self.source_dictionary.pad(),
                eos_idx=self.source_dictionary.eos(),
                bos_idx=self.source_dictionary.bos())
        else:
            src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    add_tran_loss = has_ref and self.args.add_tran_loss
    if add_tran_loss:
        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None
        ref_dataset, ref_target_dataset = MaskTokensDataset.apply_mask(
            input2,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )
        # Always build the parallel input from the *masked* reference, with
        # or without a separator, so the unmasked reference (whose masked
        # tokens are the parallel_target) is never fed to the model.
        masked_ref = ref_dataset
        if self.args.separator_token is not None:
            masked_ref = PrependTokenDataset(ref_dataset, self.args.separator_token)
        parallel_src_tokens = ConcatSentencesDataset(input0, masked_ref)
        if self.args.truncate_sequence:
            parallel_src_tokens = TruncateDataset(parallel_src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if add_tran_loss:
        dataset['net_input']['parallel_src_tokens'] = RightPadDataset(
            parallel_src_tokens,
            pad_idx=self.source_dictionary.pad(),
        )

    if self.args.add_prev_output_tokens:
        prev_tokens_dataset = RightPadDataset(
            RollDataset(src_tokens, 1),
            pad_idx=self.dictionary.pad(),
        )
        dataset['net_input'].update(
            prev_output_tokens=prev_tokens_dataset,
        )

    if not self.args.regression_target:
        label_dataset = make_dataset('label', self.label_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.label_dictionary.eos(),
                ),
                offset=-self.label_dictionary.nspecial,
            ))
        if add_tran_loss:
            # used as translation target when calculating loss
            dataset.update(parallel_target=RightPadDataset(
                ref_target_dataset,
                pad_idx=self.source_dictionary.pad(),
            ))
    else:
        label_path = "{0}.label".format(get_path('label', split))
        if os.path.exists(label_path):
            def parse_regression_target(i, line):
                values = line.split()
                assert len(values) == self.args.num_classes, \
                    f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                return [float(x) for x in values]

            # Close the label file once it has been parsed.
            with open(label_path) as label_file:
                dataset.update(target=RawLabelDataset([
                    parse_regression_target(i, line.strip())
                    for i, line in enumerate(label_file.readlines())
                ]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
        all_sizes=src_tokens.all_sizes if self.args.add_target_num_tokens else None,
        padding_idx=self.source_dictionary.pad(),
        add_ref_prob=self.add_ref_prob if split != 'valid' else 0.,
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(
        split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a token-labelling split (e.g., train, valid, test).

    Token ids are read from ``<args.data>/data/<split>`` and per-token labels
    from ``<args.data>/label/<split>``. Both are truncated to
    ``args.max_positions``; labels are shifted down by the target
    dictionary's number of special symbols. The result is stored in
    ``self.datasets[split]`` and returned.

    Raises:
        AssertionError: if either indexed dataset cannot be found.
    """
    def locate(subdir):
        return os.path.join(self.args.data, subdir, split)

    def load_required(subdir, vocab):
        found = data_utils.load_indexed_dataset(
            locate(subdir), vocab, self.args.dataset_impl, combine=combine,
        )
        assert found is not None, 'could not find dataset: {}'.format(locate(subdir))
        return found

    tokens = TruncateDataset(
        load_required('data', self.source_dictionary), self.args.max_positions
    )

    with data_utils.numpy_seed(self.args.seed):
        order = np.random.permutation(len(tokens))

    labels = TruncateDataset(
        load_required('label', self.target_dictionary), self.args.max_positions
    )
    labels = OffsetTokensDataset(labels, offset=-self.target_dictionary.nspecial)

    net_input = {
        'src_tokens': tokens,
        'src_lengths': NumelDataset(tokens, reduce=False),
    }
    items = {
        'id': IdDataset(),
        'net_input': net_input,
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(tokens, reduce=True),
        'target': labels,
    }

    combined = NestedDictionaryDataset(items, sizes=[tokens.sizes])
    if self.args.no_shuffle:
        result = combined
    else:
        result = SortDataset(combined, sort_order=[order])

    print("| Loaded {0} with #samples: {1}".format(split, len(result)))
    self.datasets[split] = result
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test).

    Loads binarized tokens from ``<args.data>/<split>`` and labelled spans
    from ``<args.data>/<split>.nonterm``, checking that both exist.

    NOTE(review): the method raises ``NotImplementedError`` unconditionally
    right after both inputs are loaded, so every statement after that
    ``raise`` is unreachable. The dead code sketches a span-labelling
    pipeline; confirm whether it is meant to be enabled before removing the
    ``raise``.

    Raises:
        AssertionError: if the token dataset or the label dataset is missing.
        NotImplementedError: always, once both datasets were found.
    """
    inputs_path = Path(self.args.data) / "{split}".format(split=split)
    src_tokens = data_utils.load_indexed_dataset(
        str(inputs_path),
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert src_tokens is not None, "could not find dataset: {}".format(
        inputs_path)

    # fixed-seed permutation used as the shuffle order (in the unreachable part)
    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    # prepend the source dictionary's BOS symbol to every sequence
    src_tokens = PrependTokenDataset(src_tokens, self.source_dictionary.bos())

    targets_path = Path(self.args.data) / "{}.nonterm".format(split)
    labelled_spans = data_utils.load_indexed_dataset(
        str(targets_path),
        self.label_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    assert labelled_spans is not None, "could not find labels: {}".format(
        targets_path)

    # NOTE(review): unconditional -- everything below this line never runs.
    raise NotImplementedError

    target_spans = LabelledSpanDataset(labelled_spans, return_spans=True)
    labels = LabelledSpanDataset(labelled_spans, return_spans=False)

    # all possible word spans in each sequence
    word_spans = WordSpanDataset(src_tokens, self.source_dictionary, self.is_word_initial)
    all_spans = ProductSpanDataset(word_spans)

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(src_tokens, pad_idx=self.source_dictionary.pad()),
            "nsrc_tokens": NumelDataset(src_tokens),
            "src_spans": RightPadDataset(all_spans, pad_idx=self.label_dictionary.pad()),
            "nsrc_spans": NumSpanDataset(all_spans),
        },
        "targets": RightPadDataset(labels, pad_idx=self.label_dictionary.pad()),
        "target_spans": RightPadDataset(target_spans, pad_idx=self.label_dictionary.pad()),
        "ntargets": NumelDataset(labels),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
        # NOTE(review): uses self.dictionary where the rest uses
        # self.source_dictionary -- presumably the same object; confirm.
        "nwords": NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial),
        "word_spans": RightPadDataset(word_spans, pad_idx=self.label_dictionary.pad()),
    }

    nested_dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes])

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(nested_dataset, sort_order=[shuffle])

    logger.info("Loaded {0} with #samples: {1}".format(
        split, len(dataset)))
    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a multiple-choice (sentence-ranking) split (e.g., train, valid, test).

    Reads the shared context ``input0`` and ``args.num_classes`` candidate
    datasets ``input1`` ... ``input{num_classes}`` from
    ``<args.data>/<inputN>/<split>``. Each candidate is combined with the
    context as ``[init_token] option [separator_token] context``. Each pair is
    exposed under ``net_input1`` ... ``net_input{num_classes}``. Integer
    targets are read from ``<args.data>/label/<split>.label`` when that file
    exists.

    Raises:
        AssertionError: if the context or any candidate dataset is missing.
    """
    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,  # honour the dictionary the caller asked for
            self.args.dataset_impl,
            combine=combine,
        )
        # fail here with a clear message instead of an obscure error later on
        assert dataset is not None, 'could not find dataset: {}'.format(split_path)
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    input_options = [
        make_dataset('input{idx}'.format(idx=idx + 1), self.source_dictionary)
        for idx in range(self.args.num_classes)
    ]

    if self.args.separator_token is not None:
        input0 = PrependTokenDataset(input0, self.args.separator_token)

    src_tokens = []
    for input_option in input_options:
        if self.args.init_token is not None:
            input_option = PrependTokenDataset(input_option, self.args.init_token)
        if self.args.max_option_length is not None:
            input_option = TruncateDataset(input_option, self.args.max_option_length)
        src_token = ConcatSentencesDataset(input_option, input0)
        if self.args.truncate_sequence:
            src_token = TruncateDataset(src_token, self.args.max_positions)
        src_tokens.append(src_token)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens[0]))

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }
    for src_token_idx, src_token in enumerate(src_tokens):
        dataset['net_input{idx}'.format(idx=src_token_idx + 1)] = {
            'src_tokens': RightPadDataset(
                src_token,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_token, reduce=False),
        }

    label_path = '{}.label'.format(get_path('label', split))
    if os.path.exists(label_path):
        with open(label_path) as h:
            dataset.update(
                target=RawLabelDataset([int(x.strip()) for x in h.readlines()])
            )

    nested_dataset = NestedDictionaryDataset(
        dataset,
        # an example is as long as its longest (option, context) pair
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a masked-LM split, prepending <s> to every token block.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; selects shard ``(epoch - 1) % len(paths)``
            of ``utils.split_paths(args.data)``.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the selected shard holds no data for *split*.
    """
    args = self.args
    vocab = self.source_dictionary

    data_dirs = utils.split_paths(args.data)
    assert len(data_dirs) > 0
    split_path = os.path.join(data_dirs[(epoch - 1) % len(data_dirs)], split)

    tokens = data_utils.load_indexed_dataset(
        split_path, vocab, args.dataset_impl, combine=combine,
    )
    if tokens is None:
        raise FileNotFoundError("Dataset not found: {} ({})".format(split, split_path))

    tokens = maybe_shorten_dataset(
        tokens,
        split,
        args.shorten_data_split_list,
        args.shorten_method,
        args.tokens_per_sample,
        args.seed,
    )

    # contiguous blocks, one token short so that the <s> added below still fits
    tokens = TokenBlockDataset(
        tokens,
        tokens.sizes,
        args.tokens_per_sample - 1,
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode=args.sample_break_mode,
    )
    logger.info("loaded {} blocks from: {}".format(len(tokens), split_path))

    # <s> plays the role of BERT's [CLS]
    tokens = PrependTokenDataset(tokens, vocab.bos())

    if args.mask_whole_words:
        word_mask = get_whole_word_mask(args, vocab)
    else:
        word_mask = None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        tokens,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=args.seed,
        mask_prob=args.mask_prob,
        leave_unmasked_prob=args.leave_unmasked_prob,
        random_token_prob=args.random_token_prob,
        freq_weighted_replacement=args.freq_weighted_replacement,
        mask_whole_words=word_mask,
        mask_multiple_length=args.mask_multiple_length,
        mask_stdev=args.mask_stdev,
    )

    with data_utils.numpy_seed(args.seed):
        shuffle = np.random.permutation(len(src_dataset))

    pad_idx = vocab.pad()
    net_input = {
        "src_tokens": RightPadDataset(src_dataset, pad_idx=pad_idx),
        "src_lengths": NumelDataset(src_dataset, reduce=False),
    }
    fields = {
        "id": IdDataset(),
        "net_input": net_input,
        "target": RightPadDataset(tgt_dataset, pad_idx=pad_idx),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_dataset, reduce=True),
    }
    # random order first, block length as secondary sort key
    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(fields, sizes=[src_dataset.sizes]),
        sort_order=[shuffle, src_dataset.sizes],
    )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a masked-LM split whose token blocks pass through RemoveTailDataset.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 1-based epoch; selects the data shard and is added to
            ``args.seed`` for the shuffle, so the order changes every epoch.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the selected shard holds no data for *split*.
    """
    vocab = self.source_dictionary
    shard_dirs = utils.split_paths(self.args.data)
    assert len(shard_dirs) > 0
    shard = shard_dirs[(epoch - 1) % len(shard_dirs)]
    split_path = os.path.join(shard, split)

    corpus = data_utils.load_indexed_dataset(
        split_path, vocab, self.args.dataset_impl, combine=combine,
    )
    if corpus is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    corpus = maybe_shorten_dataset(
        corpus,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )

    # pack the corpus into contiguous blocks of up to tokens_per_sample tokens
    corpus = TokenBlockDataset(
        corpus,
        corpus.sizes,
        self.args.tokens_per_sample,
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode=self.args.sample_break_mode,
    )
    logger.info('loaded {} blocks from: {}'.format(len(corpus), split_path))

    # remove tail (semantics defined by RemoveTailDataset elsewhere)
    corpus = RemoveTailDataset(corpus)

    mask_options = dict(
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=(
            get_whole_word_mask(self.args, vocab) if self.args.mask_whole_words else None
        ),
    )
    inputs, labels = MaskTokensDataset.apply_mask(corpus, vocab, **mask_options)

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(inputs))

    pad = vocab.pad()
    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': RightPadDataset(inputs, pad_idx=pad),
                    'src_lengths': NumelDataset(inputs, reduce=False),
                },
                'target': RightPadDataset(labels, pad_idx=pad),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(inputs, reduce=True),
            },
            sizes=[inputs.sizes],
        ),
        sort_order=[shuffle, inputs.sizes],
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a sentence(-pair) prediction split that also exposes number values.

    Reads ``input0`` (required) and ``input1`` (optional) from
    ``<args.data>/<key>/<split>``. Besides padded tokens and lengths,
    ``net_input`` carries ``src_number_token_values`` built by
    ``NumberValueDataset``. Targets are either classification labels from
    the ``label`` indexed dataset or, with ``args.regression_target``,
    ``args.num_classes`` floats per line of ``label/<split>.label``.

    A load error whose message contains
    "StorageException: [404] Path not found" is logged and the dataset is
    treated as absent; any other error propagates.

    Raises:
        AssertionError: if ``input0`` is missing, or a regression line does
            not have ``args.num_classes`` values.
    """
    args = self.args

    def locate(key):
        return os.path.join(args.data, key, split)

    def try_load(key, vocab):
        path = locate(key)
        try:
            return data_utils.load_indexed_dataset(
                path, vocab, args.dataset_impl, combine=combine,
            )
        except Exception as e:
            if "StorageException: [404] Path not found" not in str(e):
                raise
            logger.warning(f"dataset {e} not found")
            return None

    first = try_load("input0", self.source_dictionary)
    assert first is not None, "could not find dataset: {}".format(locate("input0"))
    second = try_load("input1", self.source_dictionary)

    if args.init_token is not None:
        first = PrependTokenDataset(first, args.init_token)
    if second is None:
        src_tokens = first
    else:
        if args.separator_token is not None:
            second = PrependTokenDataset(second, args.separator_token)
        src_tokens = ConcatSentencesDataset(first, second)

    with data_utils.numpy_seed(args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    src_tokens = maybe_shorten_dataset(
        src_tokens,
        split,
        args.shorten_data_split_list,
        args.shorten_method,
        self.max_positions(),
        args.seed,
    )

    net_input = {
        "src_tokens": RightPadDataset(src_tokens, pad_idx=self.source_dictionary.pad()),
        "src_number_token_values": NumberValueDataset(
            src_tokens,
            vocab=self.dictionary,
            tokenizer=self.tokenizer,
            number_value_cutoff=args.number_value_cutoff,
            send_log_value=args.send_log_value,
        ),
        "src_lengths": NumelDataset(src_tokens, reduce=False),
    }
    fields = {
        "id": IdDataset(),
        "net_input": net_input,
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
    }

    if args.add_prev_output_tokens:
        # decoder input: the source sequence rolled by one position
        net_input["prev_output_tokens"] = RightPadDataset(
            RollDataset(src_tokens, 1), pad_idx=self.dictionary.pad(),
        )

    if args.regression_target:
        label_path = "{0}.label".format(locate("label"))
        if os.path.exists(label_path):
            def parse_regression_target(i, line):
                values = line.split()
                assert (
                    len(values) == self.args.num_classes
                ), f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                return [float(x) for x in values]

            with open(label_path) as h:
                fields["target"] = RawLabelDataset(
                    [parse_regression_target(i, line.strip()) for i, line in enumerate(h)]
                )
    else:
        labels = try_load("label", self.label_dictionary)
        if labels is not None:
            # drop EOS and shift class ids past the special symbols
            fields["target"] = OffsetTokensDataset(
                StripTokenDataset(labels, id_to_strip=self.label_dictionary.eos()),
                offset=-self.label_dictionary.nspecial,
            )

    nested = NestedDictionaryDataset(fields, sizes=[src_tokens.sizes])
    result = nested if args.no_shuffle else SortDataset(nested, sort_order=[shuffle])

    logger.info("Loaded {0} with #samples: {1}".format(split, len(result)))
    self.datasets[split] = result
    return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a masked-LM split with optional BPE-based whole-word masking.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): 0-based epoch; selects shard ``epoch % len(paths)`` of
            the ':'-separated ``args.data`` and is added to the shuffle seed.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the selected shard holds no data for *split*.
    """
    vocab = self.source_dictionary
    shards = self.args.data.split(':')
    assert len(shards) > 0
    split_path = os.path.join(shards[epoch % len(shards)], split)

    raw = data_utils.load_indexed_dataset(
        split_path, vocab, self.args.dataset_impl, combine=combine,
    )
    if raw is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

    # contiguous blocks one token shorter than tokens_per_sample, leaving room for <s>
    blocks = TokenBlockDataset(
        raw,
        raw.sizes,
        self.args.tokens_per_sample - 1,
        pad=vocab.pad(),
        eos=vocab.eos(),
        break_mode=self.args.sample_break_mode,
    )
    print('| loaded {} blocks from: {}'.format(len(blocks), split_path))

    # prepend <s> (the equivalent of BERT's [CLS])
    blocks = PrependTokenDataset(blocks, vocab.bos())

    word_starts = None
    if self.args.mask_whole_words:
        bpe = encoders.build_bpe(self.args)
        assert bpe is not None

        def starts_word(idx):
            # special symbols and 'madeupword*' entries always count as word starts
            if idx < vocab.nspecial:
                return True
            symbol = vocab[idx]
            if symbol.startswith('madeupword'):
                return True
            try:
                return bpe.is_beginning_of_word(symbol)
            except ValueError:
                return True

        word_starts = torch.ByteTensor([starts_word(idx) for idx in range(len(vocab))])

    masked, targets = MaskTokensDataset.apply_mask(
        blocks,
        vocab,
        pad_idx=vocab.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=word_starts,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        order = np.random.permutation(len(masked))

    def right_padded(ds):
        return PadDataset(ds, pad_idx=vocab.pad(), left_pad=False)

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': right_padded(masked),
                    'src_lengths': NumelDataset(masked, reduce=False),
                },
                'target': right_padded(targets),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(masked, reduce=True),
            },
            sizes=[masked.sizes],
        ),
        sort_order=[order, masked.sizes],
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a CoQA split (e.g., train, valid, test).

    Features come from ``get_CoQA_features`` run with a freshly initialised
    BPE encoder. The raw examples and features are kept on
    ``self.examples[split]`` and ``self.features[split]``; every per-feature
    field is exposed as its own dataset entry.
    """
    # Build the BPE encoder object
    encoder = MultiprocessingEncoder(self.args.encoder_json, self.args.vocab_bpe)
    encoder.initializer()

    pad = self.dictionary.pad()

    # Run the CoQA preprocessing
    examples, features = get_CoQA_features(
        self.args, encoder, self.args.init_token, self.args.separator_token, pad,
        split=split,
    )
    self.examples[split] = examples
    self.features[split] = features

    column_names = (
        'qas_id', 'tokens', 'lengths', 'p_mask', 'start', 'end',
        'is_unk', 'is_yes', 'is_no', 'number', 'option',
    )
    cols = {name: [] for name in column_names}
    for feat in features:
        tokens = torch.IntTensor(feat.input_tokens).long()
        cols['tokens'].append(tokens)
        cols['lengths'].append(len(tokens))
        cols['p_mask'].append(torch.IntTensor(feat.p_mask).long())
        cols['qas_id'].append(feat.qas_id)
        cols['start'].append(feat.start_position)
        cols['end'].append(feat.end_position)
        cols['is_unk'].append(feat.is_unk)
        cols['is_yes'].append(feat.is_yes)
        cols['is_no'].append(feat.is_no)
        cols['number'].append(feat.number)
        cols['option'].append(feat.option)

    token_ds = ListDataset(cols['tokens'], cols['lengths'])
    length_ds = ListDataset(cols['lengths'])

    fields = {
        "id": IdDataset(),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(token_ds, reduce=True),
        "qas_id": RawLabelDataset(cols['qas_id']),
        "net_input": {
            "src_tokens": RightPadDataset(token_ds, pad_idx=pad),
            "src_lengths": length_ds,
            "start_position": RawLabelDataset(cols['start']),
            "p_mask": RightPadDataset(cols['p_mask'], pad_idx=pad),
        },
        "start_position": RawLabelDataset(cols['start']),
        "end_position": RawLabelDataset(cols['end']),
        "is_unk": RawLabelDataset(cols['is_unk']),
        "is_yes": RawLabelDataset(cols['is_yes']),
        "is_no": RawLabelDataset(cols['is_no']),
        "number": RawLabelDataset(cols['number']),
        "option": RawLabelDataset(cols['option']),
    }

    nested = NestedDictionaryDataset(
        fields,
        sizes=[np.maximum.reduce([token_ds.sizes])],
    )

    # deterministic shuffle
    with data_utils.numpy_seed(self.args.seed):
        shuffled = SortDataset(
            nested,
            sort_order=[np.random.permutation(len(nested))],
        )

    print("| Loaded {} with {} samples".format(split, len(shuffled)))
    self.datasets[split] = shuffled
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a paired-input split (e.g., train, valid, test).

    ``input0`` and ``input1`` are read from ``<args.data>/<key>/<split>`` and
    kept as two separate inputs (``net_input0`` / ``net_input1``) rather
    than concatenated. Float targets, one per line, are read from
    ``<args.data>/label/<split>.label`` when that file exists.

    Raises:
        AssertionError: if either input is missing or the two inputs differ
            in length.
    """
    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)
    assert input1 is not None, 'could not find dataset: {}'.format(
        get_path('input1', split))
    assert len(input0) == len(input1), 'input pair different length'

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)
        input1 = PrependTokenDataset(input1, self.args.init_token)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(input0))

    if self.args.truncate_sequence:
        input0 = TruncateDataset(input0, self.args.max_positions)
        input1 = TruncateDataset(input1, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input0': {
            'src_tokens': RightPadDataset(
                input0,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(input0, reduce=False),
        },
        'net_input1': {
            'src_tokens': RightPadDataset(
                input1,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(input1, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens0': NumelDataset(input0, reduce=True),
        'ntokens1': NumelDataset(input1, reduce=True),
    }

    label_path = "{0}.label".format(get_path('label', split))
    if os.path.exists(label_path):
        # read inside `with` so the file handle is closed deterministically
        with open(label_path) as label_file:
            dataset.update(target=RawLabelDataset(
                [float(x.strip()) for x in label_file.readlines()]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        # a pair is as long as its longer member
        sizes=[np.maximum(input0.sizes, input1.sizes)],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(
        split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a SQuAD-style extractive QA split (e.g., train, valid).

    Examples are read via ``self.processor`` (dev examples for ``valid``,
    train examples otherwise) and converted with
    ``squad_convert_examples_to_features``. Rank 0 caches the
    ``(examples, features)`` pair at
    ``<args.data>/cached_<split>_<bpe>_<max_seq_len>.pth``, and any call that
    finds that file re-uses it.

    Input format: ``<s> question here ? </s> Passage </s>``
    """
    cache = os.path.join(
        self.args.data,
        "cached_{}_{}_{}.pth".format(split, self.args.bpe, self.args.max_seq_len))
    if os.path.exists(cache):
        # NOTE(review): torch.load unpickles the cache file; only point
        # args.data at trusted directories.
        examples, features = torch.load(cache)
    else:
        if split == 'valid':
            examples = self.processor.get_dev_examples(
                self.args.data, self.train_or_dev_file[split])
        else:
            examples = self.processor.get_train_examples(
                self.args.data, self.train_or_dev_file[split])
        features = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=self.tokenizer,
            max_seq_length=self.args.max_seq_len,
            doc_stride=128,
            max_query_length=64,
            is_training=(split != 'valid'),
            return_dataset=False,
        )
        # only rank 0 writes the cache
        if self.args.distributed_rank == 0:
            torch.save((examples, features), cache)

    if split == 'valid' and self.do_evaluate:
        # retained on the task when do_evaluate is set (presumably for evaluation)
        self.examples = examples
        self.features = features

    src_dataset = BaseWrapperDataset(
        [np.array(f.input_ids) for f in features])
    starts = BaseWrapperDataset(
        np.array([f.start_position for f in features]))
    ends = BaseWrapperDataset(np.array([f.end_position for f in features]))
    sizes = np.array([len(f.input_ids) for f in features])
    src_lengths = NumelDataset(src_dataset, reduce=False)

    dataset = NestedDictionaryDataset(
        {
            'id': IdDataset(),
            'net_input': {
                'src_tokens': src_dataset,
                'src_lengths': src_lengths,
            },
            'targets': {
                'starts': starts,
                'ends': ends,
            },
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_dataset, reduce=True),
        },
        sizes=[sizes],
    )

    # deterministic shuffle
    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            sort_order=[np.random.permutation(len(dataset))],
        )

    print('| Loaded {} with {} samples'.format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a multiple-choice QA split whose questions carry an explanation.

    Each JSONL example is encoded once per choice as
    ``question_tokens + explanation_tokens + answer_tokens``:

    - the question is ``"Q: <stem>"``, with ``args.init_token`` prepended when
      set and an EOS appended;
    - the explanation is ``example['question']['cose']``, with no BOS/EOS;
    - the answer is ``"A: <choice text>"``, with an EOS appended.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch, combine: accepted for interface compatibility; unused here.
        data_path (str, optional): JSONL file to read; defaults to
            ``<args.data>/<split>.jsonl``.
        return_only (bool): if True, return the dataset without registering
            it in ``self.datasets``.

    Raises:
        FileNotFoundError: if *data_path* does not exist.
        KeyError: if an example has no ``question['cose']`` explanation.
    """

    def binarize(s, append_bos=False, append_eos=True):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s, append_eos=append_eos, add_if_not_exist=False,
        ).long()
        if append_bos and self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    src_tokens = [[] for _ in range(self.args.num_classes)]
    src_lengths = [[] for _ in range(self.args.num_classes)]
    labels = []

    with open(data_path) as h:
        for line in h:
            example = json.loads(line.strip())
            if 'answerKey' in example:
                # 'A' -> 0, 'B' -> 1, ...
                label = ord(example['answerKey']) - ord('A')
                labels.append(label)
            question = example['question']['stem']
            assert len(example['question']['choices']) == self.args.num_classes
            # format: `<s> Q: Where would I not want a fox? </s> <explanation> A: hen house </s>`
            question = 'Q: ' + question.strip()
            question_toks = binarize(question, append_bos=True)
            explanation_toks = binarize(
                example['question']['cose'], append_eos=False, append_bos=False,
            )
            for i, choice in enumerate(example['question']['choices']):
                src = 'A: ' + choice['text'].strip()
                src_bin = torch.cat([
                    question_toks,
                    explanation_toks,
                    binarize(src, append_bos=False, append_eos=True),
                ])
                src_tokens[i].append(src_bin)
                src_lengths[i].append(len(src_bin))

    assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
    assert len(src_tokens[0]) == len(src_lengths[0])
    assert len(labels) == 0 or len(labels) == len(src_tokens[0])

    for i in range(self.args.num_classes):
        src_lengths[i] = np.array(src_lengths[i])
        src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
        src_lengths[i] = ListDataset(src_lengths[i])

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for i in range(self.args.num_classes):
        dataset.update({
            'net_input{}'.format(i + 1): {
                'src_tokens': RightPadDataset(
                    src_tokens[i],
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': src_lengths[i],
            }
        })

    if len(labels) > 0:
        dataset.update({'target': RawLabelDataset(labels)})

    dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            # shuffle
            sort_order=[np.random.permutation(len(dataset))],
        )

    print('| Loaded {} with {} samples'.format(split, len(dataset)))

    # honour return_only: hand the dataset back without registering it
    if return_only:
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a (multi-label) sentence classification split (e.g., train, valid, test).

    Inputs come from ``<args.data>/input0/<split>`` and the optional
    ``<args.data>/input1/<split>``. When present, targets are read with
    ``load_npz`` from ``<args.data>/label/<split>.npz`` and wrapped in
    ``CSRLabelDataset``.

    Raises:
        AssertionError: if ``input0`` cannot be found.
    """
    def path_for(key):
        return os.path.join(self.args.data, key, split)

    def load(key, vocab):
        return data_utils.load_indexed_dataset(
            path_for(key), vocab, self.args.dataset_impl, combine=combine,
        )

    first = load("input0", self.source_dictionary)
    assert first is not None, "could not find dataset: {}".format(path_for("input0"))
    second = load("input1", self.source_dictionary)

    if self.args.init_token is not None:
        first = PrependTokenDataset(first, self.args.init_token)

    if second is not None and self.args.separator_token is not None:
        second = PrependTokenDataset(second, self.args.separator_token)
    src_tokens = first if second is None else ConcatSentencesDataset(first, second)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    src_tokens = maybe_shorten_dataset(
        src_tokens,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.max_positions,
        self.args.seed,
    )

    net_input = {
        "src_tokens": RightPadDataset(src_tokens, pad_idx=self.source_dictionary.pad()),
        "src_lengths": NumelDataset(src_tokens, reduce=False),
    }
    if self.args.add_prev_output_tokens:
        net_input["prev_output_tokens"] = RightPadDataset(
            RollDataset(src_tokens, 1), pad_idx=self.dictionary.pad(),
        )

    fields = {
        "id": IdDataset(),
        "net_input": net_input,
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
    }

    label_path = "{0}.npz".format(path_for("label"))
    if os.path.exists(label_path):
        fields["target"] = CSRLabelDataset(load_npz(label_path))

    nested = NestedDictionaryDataset(fields, sizes=[src_tokens.sizes])
    if self.args.no_shuffle:
        result = nested
    else:
        result = SortDataset(nested, sort_order=[shuffle])

    logger.info("Loaded {0} with #samples: {1}".format(split, len(result)))
    self.datasets[split] = result
    return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a WinoGrande split from JSONL.

    Every sample yields a (possibly absent) query and one candidate, each
    binarised with ``self.binarize_with_mask`` around the pronoun span.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch, combine: accepted for interface compatibility; unused here.
        data_path (str, optional): JSONL file to read; defaults to
            ``<args.data>/<split>.jsonl``.
        return_only (bool): if True, return the dataset without registering
            it in ``self.datasets``.

    Raises:
        FileNotFoundError: if *data_path* does not exist.
    """
    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    q_toks, q_masks, q_lens = [], [], []
    c_toks, c_masks, c_lens = [], [], []

    samples = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test'))
    for sentence, pronoun_span, query, cand_text in samples:
        before = sentence[:pronoun_span[0]]
        # (prefix, suffix, leading_space, trailing_space) around the pronoun
        context = (
            before.rstrip(),
            sentence[pronoun_span[1]:],
            ' ' if before.endswith(' ') else '',
            '',
        )

        if query is None:
            # no query available (e.g. unlabeled test data): keep placeholders
            q_toks.append(None)
            q_masks.append(None)
            q_lens.append(0)
        else:
            toks, mask = self.binarize_with_mask(query, *context)
            q_toks.append(toks)
            q_masks.append(mask)
            q_lens.append(len(toks))

        cand, cand_mask = self.binarize_with_mask(cand_text, *context)
        c_toks.append(cand)
        c_masks.append(cand_mask)
        c_lens.append(cand.size(0))

    q_lens = np.array(q_lens)
    c_lens = np.array(c_lens)

    def padded(items, lengths, pad_idx):
        return PadDataset(ListDataset(items, lengths), pad_idx=pad_idx, left_pad=False)

    query_tokens = padded(q_toks, q_lens, self.vocab.pad())
    fields = {
        'id': IdDataset(),
        'query_tokens': query_tokens,
        'query_masks': padded(q_masks, q_lens, 0),
        'candidate_tokens': padded(c_toks, c_lens, self.vocab.pad()),
        'candidate_masks': padded(c_masks, c_lens, 0),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(query_tokens, reduce=True),
    }
    nested = NestedDictionaryDataset(fields, sizes=[q_lens])

    with data_utils.numpy_seed(self.args.seed):
        order = np.random.permutation(len(query_tokens))
    dataset = SortDataset(nested, sort_order=[order])

    if return_only:
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a two-target token tagging split (e.g., train, valid, test).

    Reads ``input0`` (tokens), ``input1`` (segment/POS tokens), ``label0`` and
    ``label1`` from ``<args.data>/<name>/<split>``. Every field is
    right-padded to ``self._max_positions``.

    Raises:
        AssertionError: if any of the four indexed datasets is missing.
    """
    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        assert dataset is not None, "could not find dataset: {}".format(
            get_path(type, split))
        return dataset

    def make_target(label_dataset, label_dictionary):
        """Build padded per-token targets in 0, 1, 2, ... with -1 for ignored positions."""
        nspecial = label_dictionary.nspecial
        # pad with -1; used to mask out padding when calculating the loss
        return RightPadDataset(
            # replace eos and existing padding (tokens that should not be predicted) with -1
            ReplaceDataset(
                # offset tokens to get the targets into the range (0, 1, 2, ...)
                OffsetTokensDataset(label_dataset, offset=-nspecial),
                replace_map={
                    label_dictionary.eos() - nspecial: -1,
                    label_dictionary.pad() - nspecial: -1,
                },
                # `np.int` was removed in NumPy 1.24; builtin int is the same dtype
                offsets=np.zeros(len(label_dataset), dtype=int),
            ),
            pad_idx=-1,
            pad_to_length=self._max_positions,
        )

    src_tokens = make_dataset("input0", self.source_dictionary)
    pos_tokens = make_dataset("input1", self.pos_dictionary)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    label0_dataset = make_dataset("label0", self.label0_dictionary)
    label1_dataset = make_dataset("label1", self.label1_dictionary)

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
                pad_to_length=self._max_positions,
            ),
            "src_lengths": NumelDataset(src_tokens, reduce=False),
        },
        "segments": {
            "seg_tokens": RightPadDataset(
                pos_tokens,
                pad_idx=self.pos_dictionary.pad(),
                pad_to_length=self._max_positions,
            ),
            "seg_lengths": NumelDataset(pos_tokens, reduce=False),
        },
        "target0": make_target(label0_dataset, self.label0_dictionary),
        "target1": make_target(label1_dataset, self.label1_dictionary),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
    """Load a given multiple-choice dataset split from a JSONL file.

    Each JSON line must contain ``question.stem`` and exactly
    ``self.args.num_classes`` entries in ``question.choices`` (each with a
    ``text`` field).  The ``answerKey`` letter is optional.  For every choice
    the input is ``<init_token> Q: <stem> </s> A: <choice> </s>``, and the
    choice is exposed as ``net_input{i+1}``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): unused.
        combine (bool): unused.
        data_path (str, optional): explicit path to the ``.jsonl`` file.
            Defaults to ``<self.args.data>/<split>.jsonl``.
        return_only (bool): if True, build and return the dataset without
            registering it in ``self.datasets``.

    Returns:
        The shuffled ``SortDataset`` wrapping the nested dataset.

    Raises:
        FileNotFoundError: if the data file does not exist.
    """

    def binarize(s, append_bos=False):
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=True,
            add_if_not_exist=False,
        ).long()
        if append_bos and self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    if data_path is None:
        data_path = os.path.join(self.args.data, split + ".jsonl")
    if not os.path.exists(data_path):
        raise FileNotFoundError("Cannot find data: {}".format(data_path))

    num_classes = self.args.num_classes
    src_tokens = [[] for _ in range(num_classes)]
    src_lengths = [[] for _ in range(num_classes)]
    labels = []

    # JSON Lines files are UTF-8; do not depend on the platform's locale encoding.
    with open(data_path, encoding="utf-8") as h:
        for line in h:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            example = json.loads(line)
            if "answerKey" in example:
                # Answer letters "A", "B", ... map to class indices 0, 1, ...
                label = ord(example["answerKey"]) - ord("A")
                labels.append(label)
            question = example["question"]["stem"]
            assert len(example["question"]["choices"]) == num_classes
            # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
            question = "Q: " + question
            question_toks = binarize(question, append_bos=True)
            for i, choice in enumerate(example["question"]["choices"]):
                src = "A: " + choice["text"]
                src_bin = torch.cat([question_toks, binarize(src)])
                src_tokens[i].append(src_bin)
                src_lengths[i].append(len(src_bin))

    assert all(
        len(src_tokens[0]) == len(src_tokens[i]) for i in range(num_classes)
    )
    assert len(src_tokens[0]) == len(src_lengths[0])
    # Labels are either absent everywhere (e.g. test split) or present for every example.
    assert len(labels) == 0 or len(labels) == len(src_tokens[0])

    for i in range(num_classes):
        src_lengths[i] = np.array(src_lengths[i])
        src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
        src_lengths[i] = ListDataset(src_lengths[i])

    dataset = {
        "id": IdDataset(),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens[0], reduce=True),
    }

    for i in range(num_classes):
        dataset.update(
            {
                "net_input{}".format(i + 1): {
                    "src_tokens": RightPadDataset(
                        src_tokens[i],
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    "src_lengths": src_lengths[i],
                }
            }
        )

    if len(labels) > 0:
        dataset.update({"target": RawLabelDataset(labels)})

    dataset = NestedDictionaryDataset(
        dataset,
        # An example's size is the length of its longest choice.
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    with data_utils.numpy_seed(self.args.seed):
        dataset = SortDataset(
            dataset,
            # shuffle
            sort_order=[np.random.permutation(len(dataset))],
        )

    print("| Loaded {} with {} samples".format(split, len(dataset)))

    if return_only:
        # Previously ignored: the dataset was always registered in self.datasets.
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
    """Load a given dataset split.

    For each data variant ``i`` in ``range(self.args.augmented_variants)``
    this loads ``<data>/<i>/insts/<split>`` and ``<data>/<i>/states/<split>``
    (indexed datasets) plus ``<data>/<i>/pos/<split>``.  It combines them
    into an ``IRDataset`` and masks it with ``IRMaskTokensDataset.apply_mask``
    to get a source/target pair.

    The variants are merged with ``RandomChoiceMultipleDataset``, padded, and
    stored (shuffled and sorted) in ``self.datasets[split]``.

    Args:
        split (str): name of the split (e.g., train, valid, test)
        epoch (int): added to ``self.args.seed`` to seed the shuffle order.
        combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
        data_selector: unused.

    Raises:
        FileNotFoundError: if the instruction or state dataset of a variant
            cannot be loaded.
    """
    print('Loading dataset')

    src_datasets = []
    tgt_datasets = []
    # Loop-invariant; also keeps the name defined after the loop.
    block_size = self.args.function_length

    for i in range(self.args.augmented_variants):
        data_path = os.path.join(self.args.data, str(i))
        inst_path = os.path.join(data_path, 'insts', split)
        state_path = os.path.join(data_path, 'states', split)

        dataset_inst = data_utils.load_indexed_dataset(
            inst_path,
            self.instruction_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        dataset_state = data_utils.load_indexed_dataset(
            state_path,
            self.state_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )

        if dataset_inst is None or dataset_state is None:
            # Report the real missing path(s).  The old message referenced an
            # undefined `split_path`, which raised NameError instead.
            missing = [
                path
                for path, ds in ((inst_path, dataset_inst), (state_path, dataset_state))
                if ds is None
            ]
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, ', '.join(missing)))

        dataset_inst = SeqOfSeqDataset(dataset_inst, self.instruction_dictionary)
        dataset_state = SeqOfSeqDataset(dataset_state, self.state_dictionary)
        dataset_pos = IRPositionDataset(os.path.join(data_path, 'pos', split))
        dataset = IRDataset(dataset_inst, dataset_state, dataset_pos)

        print('| loaded {} batches from: {} and {}'.format(
            len(dataset), inst_path, state_path))

        src_dataset, tgt_dataset = IRMaskTokensDataset.apply_mask(
            dataset,
            self.instruction_dictionary,
            self.state_dictionary,
            inst_pad_idx=self.instruction_dictionary.pad(),
            state_pad_idx=self.state_dictionary.pad(),
            inst_mask_idx=self.inst_mask_idx,
            state_mask_idx=self.state_mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
        )

        src_datasets.append(src_dataset)
        tgt_datasets.append(tgt_dataset)

    src_dataset = RandomChoiceMultipleDataset(src_datasets, seed=self.args.seed)
    tgt_dataset = RandomChoiceMultipleDataset(tgt_datasets, seed=self.args.seed, only_first=True)

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    # Padding configuration shared by the source and target sides.
    # NOTE(review): smallbert_states_per_input is fed from
    # smallbert_insts_per_group (as in the original code) -- confirm this is
    # intended rather than a separate "states per group" option.
    pad_kwargs = dict(
        inst_pad_idx=self.instruction_dictionary.pad(),
        state_pad_idx=self.state_dictionary.pad(),
        inst_mask_idx=self.inst_mask_idx,
        state_mask_idx=self.state_mask_idx,
        inst_cls_idx=self.instruction_dictionary.index('<t>'),
        state_cls_idx=self.state_dictionary.index('<t>'),
        smallbert_insts_per_input=self.args.smallbert_insts_per_group,
        smallbert_states_per_input=self.args.smallbert_insts_per_group,
        max_length=block_size,
        inst_pad_length=32,
        state_pad_length=16,
    )

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src': IRPairPadDataset(src_dataset, **pad_kwargs),
                },
                'target': IRPadDataset(tgt_dataset, **pad_kwargs),
            },
            sizes=[src_dataset.sizes[:, 0]],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes[:, 0],
        ],
    )