def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): src_dataset = RightPadDataset( TokenBlockDataset( src_tokens, src_lengths, self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode="eos", ), pad_idx=self.source_dictionary.pad(), ) src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) src_dataset = NestedDictionaryDataset( { "id": IdDataset(), "net_input": { "src_tokens": src_dataset, "src_lengths": NumelDataset(src_dataset, reduce=False), }, }, sizes=src_lengths, ) if sort: src_dataset = SortDataset(src_dataset, sort_order=[src_lengths]) return src_dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): src_dataset = PadDataset( TokenBlockDataset( src_tokens, src_lengths, self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode='eos', ), pad_idx=self.source_dictionary.pad(), left_pad=False, ) src_dataset = PrependTokenDataset( src_dataset, self.source_dictionary.bos()) src_dataset = NestedDictionaryDataset( { 'id': IdDataset(), 'net_input': { 'src_tokens': src_dataset, 'src_lengths': NumelDataset(src_dataset, reduce=False), }, }, sizes=src_lengths, ) if sort: src_dataset = SortDataset(src_dataset, sort_order=[src_lengths]) return src_dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): src_dataset = RightPadDataset( TokenBlockDataset( src_tokens, src_lengths, self.args.tokens_per_sample, pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode='eos', ), pad_idx=self.source_dictionary.pad(), ) # remove tail src_dataset = RemoveTailDataset(src_dataset) src_dataset = NestedDictionaryDataset( { 'id': IdDataset(), 'net_input': { 'src_tokens': src_dataset, 'src_lengths': NumelDataset(src_dataset, reduce=False), }, }, sizes=src_lengths, ) if sort: src_dataset = SortDataset(src_dataset, sort_order=[src_lengths]) return src_dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): """ Generate batches for inference. We assume that the input begins with a bos symbol (`<s>`) and ends with an eos symbol (`</s>`). """ pad = self.source_dictionary.pad() eos = self.source_dictionary.eos() src_dataset = TokenBlockDataset( src_tokens, src_lengths, block_size=self.args.tokens_per_sample - 2, # for <s> and </s> pad=pad, eos=eos, break_mode=self.args.sample_break_mode, document_sep_len=0, ) prev_output_tokens = PrependTokenDataset( StripTokenDataset(src_dataset, eos), eos) src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False) return NestedDictionaryDataset( { "id": IdDataset(), "net_input": { "src_tokens": src_dataset, "src_lengths": NumelDataset(src_dataset, reduce=False), "prev_output_tokens": PadDataset(prev_output_tokens, pad_idx=pad, left_pad=False), }, "target": src_dataset, }, sizes=[np.array(src_lengths)], )
def build_dataset_for_inference(self, src_tokens, src_lengths, language="en_XX", **kwargs): """ Generate batches for inference. We prepend an eos token to src_tokens (or bos if `--add-bos-token` is set) and we append a <pad> to target. This is convenient both for generation with a prefix and LM scoring. """ dataset = StripTokenDataset( TokenBlockDataset( src_tokens, src_lengths, block_size=None, # ignored for "eos" break mode pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode="eos", ), # remove eos from (end of) target sequence self.source_dictionary.eos(), ) src_lang_idx = self.dictionary.index(lang_token(language)) src_dataset = PrependTokenDataset( dataset, token=((src_lang_idx or self.source_dictionary.bos()) if getattr( self.args, "add_bos_token", False) else self.source_dictionary.eos()), ) max_seq_len = max(src_lengths) + 1 tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad()) return NestedDictionaryDataset( { "id": IdDataset(), "net_input": { "src_tokens": PadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, pad_length=max_seq_len, ), "src_lengths": NumelDataset(src_dataset, reduce=False), }, "target": PadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, pad_length=max_seq_len, ), }, sizes=[np.array(src_lengths)], )
def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ path = self.args.data + '.' + split tokens = [] starts = [] ends = [] unanswerables = [] lengths = [] try: data = from_records(path, self.args.max_seq_length) for inp, start, end, unanswerable in data: tokens.append(inp) lengths.append(len(inp)) starts.append(start) ends.append(end) unanswerables.append(unanswerable) except Exception: data = [] tokens = BaseWrapperDataset(tokens) starts = BaseWrapperDataset(np.array(starts, dtype=np.int64)) ends = BaseWrapperDataset(np.array(ends, dtype=np.int64)) lengths = np.array(lengths, dtype=np.int64) unanswerables = BaseWrapperDataset( np.array(unanswerables, dtype=np.float32)) print('| loaded {} examples from: {}'.format(len(lengths), path)) dataset = NestedDictionaryDataset( { 'id': IdDataset(), 'tokens': tokens, 'starts': starts, 'ends': ends, 'unanswerables': unanswerables, 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(tokens, reduce=True), }, sizes=[lengths], ) self.datasets[split] = SortDataset( dataset, sort_order=[ np.random.permutation(len(lengths)), ], ) if self.args.do_shuffle else dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): assert not self.cfg.include_src or len(src_tokens[0]) == 2 input_src = None if self.cfg.include_src: input_src = TokenBlockDataset( [t[0] for t in src_tokens], [l[0] for l in src_lengths], block_size=None, # ignored for "eos" break mode pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode="eos", ) input_src = PrependTokenDataset(input_src, self.dictionary.bos()) input_src = TruncateDataset(input_src, self.cfg.max_positions) input_tgt = TokenBlockDataset( [t[-1] for t in src_tokens], [l[-1] for l in src_lengths], block_size=None, # ignored for "eos" break mode pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode="eos", ) input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions) if self.cfg.include_src: src_tokens = ConcatSentencesDataset(input_src, input_tgt) src_lengths = NumelDataset(input_src, reduce=False) else: input_tgt = PrependTokenDataset(input_tgt, self.dictionary.bos()) src_tokens = input_tgt src_lengths = NumelDataset(src_tokens, reduce=False) dataset = { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), "src_lengths": src_lengths, }, "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_tokens, reduce=True), } return NestedDictionaryDataset( dataset, sizes=[src_tokens.sizes], )
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): dataset = StripTokenDataset( TokenBlockDataset( src_tokens, src_lengths, block_size=None, # ignored for "eos" break mode pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode="eos", ), # remove eos from (end of) target sequence self.source_dictionary.eos(), ) src_dataset = PrependTokenDataset( dataset, token=(self.source_dictionary.bos() if getattr( self.args, "add_bos_token", False) else self.source_dictionary.eos()), ) tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad()) return NestedDictionaryDataset( { "id": IdDataset(), "net_input": { "src_tokens": PadDataset(src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False), "src_lengths": NumelDataset(src_dataset, reduce=False), }, "target": PadDataset(tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False), }, sizes=[np.array(src_lengths)], )
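A minimal usage sketch for the variant above, assuming a fairseq `task` object that exposes this method together with the standard `source_dictionary` and `get_batch_iterator` APIs; the raw `lines` and the batch size are illustrative, not part of the snippet.

# Sketch only: `task` and `lines` are assumptions, not part of the snippet above.
import torch

lines = ["the quick brown fox", "hello world"]
src_tokens = [
    task.source_dictionary.encode_line(line, add_if_not_exist=False).long()
    for line in lines
]
src_lengths = torch.tensor([t.numel() for t in src_tokens])

dataset = task.build_dataset_for_inference(src_tokens, src_lengths)
itr = task.get_batch_iterator(dataset=dataset, max_sentences=2).next_epoch_itr(shuffle=False)
for sample in itr:
    # "net_input.src_tokens" is right-padded; "target" pairs each source position
    # with the next token, with a pad appended at the end (see the dataset above)
    print(sample["net_input"]["src_tokens"].shape, sample["target"].shape)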
def load_dataset(self, split, combine=False, **kwargs): cache = os.path.join( self.args.data, "cached_{}_{}_{}.pth".format(split, self.args.bpe, self.args.max_seq_len)) if os.path.exists(cache): examples, features = torch.load(cache) else: if split == 'valid': examples = self.processor.get_dev_examples( self.args.data, self.train_or_dev_file[split]) else: examples = self.processor.get_train_examples( self.args.data, self.train_or_dev_file[split]) features = squad_convert_examples_to_features( examples=examples, tokenizer=self.tokenizer, max_seq_length=self.args.max_seq_len, doc_stride=128, max_query_length=64, is_training=(split != 'valid'), return_dataset=False, ) if self.args.distributed_rank == 0: torch.save((examples, features), cache) if split == 'valid' and self.do_evaluate: self.examples = examples self.features = features src_dataset = BaseWrapperDataset( [np.array(f.input_ids) for f in features]) starts = BaseWrapperDataset( np.array([f.start_position for f in features])) ends = BaseWrapperDataset(np.array([f.end_position for f in features])) sizes = np.array([len(f.input_ids) for f in features]) src_lengths = NumelDataset(src_dataset, reduce=False) ''' Input format: <s> question here ? </s> Passage </s> ''' dataset = NestedDictionaryDataset( { 'id': IdDataset(), 'net_input': { 'src_tokens': src_dataset, 'src_lengths': NumelDataset(src_dataset, reduce=False), }, 'targets': { 'starts': starts, 'ends': ends, }, 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_dataset, reduce=True), }, sizes=[sizes], ) with data_utils.numpy_seed(self.args.seed): dataset = SortDataset( dataset, sort_order=[np.random.permutation(len(dataset))], ) print('| Loaded {} with {} samples'.format(split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" def get_path(type, field, split): return os.path.join(self.args.data, type, field, split) def make_dataset(type, field, dictionary): split_path = get_path(type, field, split) dataset = data_utils.load_indexed_dataset( split_path, dictionary, self.args.dataset_impl, combine=combine, ) return dataset input0 = {} input1 = {} for field in configs.fields: input0[field] = make_dataset('input0', field, self.source_dictionary[field]) assert input0[ field] is not None, 'could not find dataset: {}'.format( get_path('input0', field, split)) input1[field] = make_dataset('input1', field, self.source_dictionary[field]) assert input1[ field] is not None, 'could not find dataset: {}'.format( get_path('input1', field, split)) assert len(input0[field]) == len( input1[field]), 'input pair different length' if self.args.init_token is not None: input0[field] = PrependTokenDataset(input0[field], self.args.init_token) input1[field] = PrependTokenDataset(input1[field], self.args.init_token) if self.args.truncate_sequence: input0[field] = TruncateDataset(input0[field], self.args.max_positions) input1[field] = TruncateDataset(input1[field], self.args.max_positions) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(input0[field])) dataset = { 'id': IdDataset(), 'net_input0': { 'src_tokens': { field: RightPadDataset( input0[field], pad_idx=self.source_dictionary[field].pad()) for field in configs.fields }, 'src_lengths': NumelDataset(input0[field], reduce=False), }, 'net_input1': { 'src_tokens': { field: RightPadDataset( input1[field], pad_idx=self.source_dictionary[field].pad()) for field in configs.fields }, 'src_lengths': NumelDataset(input1[field], reduce=False), }, 'nsentences': NumSamplesDataset(), 'ntokens0': NumelDataset(input0[field], reduce=True), 'ntokens1': NumelDataset(input1[field], reduce=True), } label_path = "{0}.label".format(get_path('label', '', split)) if os.path.exists(label_path): dataset.update(target=RawLabelDataset( [float(x.strip()) for x in open(label_path).readlines()])) nested_dataset = NestedDictionaryDataset( dataset, sizes=[np.maximum(input0[field].sizes, input1[field].sizes)], ) if self.args.no_shuffle: dataset = nested_dataset else: dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) logger.info("Loaded {0} with #samples: {1}".format( split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, data_selector=None): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = self.args.data.split(':') assert len(paths) > 0 data_path = paths[epoch % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) if dataset is None: raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode=self.args.sample_break_mode, ) print('| loaded {} batches from: {}'.format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) # create masked input and targets if self.args.mask_whole_words: bpe = encoders.build_bpe(self.args) if bpe is not None: def is_beginning_of_word(i): if i < self.source_dictionary.nspecial: # special elements are always considered beginnings return True tok = self.source_dictionary[i] if tok.startswith('madeupword'): return True try: return bpe.is_beginning_of_word(tok) except ValueError: return True mask_whole_words = torch.ByteTensor(list( map(is_beginning_of_word, range(len(self.source_dictionary))) )) else: mask_whole_words = None src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( dataset, self.source_dictionary, pad_idx=self.source_dictionary.pad(), mask_idx=self.mask_idx, seed=self.args.seed, mask_prob=self.args.mask_prob, leave_unmasked_prob=self.args.leave_unmasked_prob, random_token_prob=self.args.random_token_prob, freq_weighted_replacement=self.args.freq_weighted_replacement, mask_whole_words=mask_whole_words, ) with data_utils.numpy_seed(self.args.seed + epoch): shuffle = np.random.permutation(len(src_dataset)) self.datasets[split] = SortDataset( NestedDictionaryDataset( { 'id': IdDataset(), 'net_input': { 'src_tokens': PadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, ), 'src_lengths': NumelDataset(src_dataset, reduce=False), }, 'target': PadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, ), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_dataset, reduce=True), }, sizes=[src_dataset.sizes], ), sort_order=[ shuffle, src_dataset.sizes, ], )
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" def get_path(type, split): return os.path.join(self.args.data, type, split) def make_dataset(type, dictionary): split_path = get_path(type, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) return dataset input0 = make_dataset('input0', self.source_dictionary) assert input0 is not None, 'could not find dataset: {}'.format( get_path(type, split)) input1 = make_dataset('input1', self.source_dictionary) if self.args.init_token is not None: input0 = PrependTokenDataset(input0, self.args.init_token) if input1 is None: src_tokens = input0 else: if self.args.separator_token is not None: input1 = PrependTokenDataset(input1, self.args.separator_token) src_tokens = ConcatSentencesDataset(input0, input1) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens)) if self.args.truncate_sequence: src_tokens = TruncateDataset(src_tokens, self.args.max_positions) dataset = { 'id': IdDataset(), 'net_input': { 'src_tokens': RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), 'src_lengths': NumelDataset(src_tokens, reduce=False), }, 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_tokens, reduce=True), } if not self.args.regression_target: label_dataset = make_dataset('label', self.target_dictionary) if label_dataset is not None: dataset.update(target=OffsetTokensDataset( StripTokenDataset( label_dataset, id_to_strip=self.target_dictionary.eos(), ), offset=-self.target_dictionary.nspecial, )) else: label_path = f"{get_path('label', split)}.label" if os.path.exists(label_path): dataset.update(target=RawLabelDataset( [float(x.strip()) for x in open(label_path).readlines()])) nested_dataset = NestedDictionaryDataset( dataset, sizes=[src_tokens.sizes], ) if self.args.no_shuffle: dataset = nested_dataset else: dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) print(f"| Loaded {split} with #samples: {len(dataset)}") self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" def get_path(type, split): return os.path.join(self.args.data, type, split) def make_dataset(type, dictionary): split_path = get_path(type, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) return dataset input0 = make_dataset('input0', self.source_dictionary) input_options = [ make_dataset('input{idx}'.format(idx=idx + 1), self.source_dictionary) for idx in range(self.args.num_classes) ] if self.args.separator_token is not None: input0 = PrependTokenDataset(input0, self.args.separator_token) src_tokens = [] for input_option in input_options: if self.args.init_token is not None: input_option = PrependTokenDataset(input_option, self.args.init_token) if self.args.max_option_length is not None: input_option = TruncateDataset(input_option, self.args.max_option_length) src_token = ConcatSentencesDataset(input_option, input0) if self.args.truncate_sequence: src_token = TruncateDataset(src_token, self.args.max_positions) src_tokens.append(src_token) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens[0])) dataset = { 'id': IdDataset(), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_tokens[0], reduce=True), } for src_token_idx in range(len(src_tokens)): dataset.update({ 'net_input{idx}'.format(idx=src_token_idx + 1): { 'src_tokens': RightPadDataset( src_tokens[src_token_idx], pad_idx=self.source_dictionary.pad(), ), 'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False), } }) label_path = '{}.label'.format(get_path('label', split)) if os.path.exists(label_path): with open(label_path) as h: dataset.update(target=RawLabelDataset( [int(x.strip()) for x in h.readlines()])) nested_dataset = NestedDictionaryDataset( dataset, sizes=[ np.maximum.reduce( [src_token.sizes for src_token in src_tokens]) ], ) if self.args.no_shuffle: dataset = nested_dataset else: dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) logger.info("Loaded {0} with #samples: {1}".format( split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = self.args.data.split(':') assert len(paths) > 0 data_path = paths[epoch % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) if dataset is None: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, split_path)) # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode=self.args.sample_break_mode, ) print('| loaded {} blocks from: {}'.format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) # create masked input and targets mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \ if self.args.mask_whole_words else None src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( dataset, self.source_dictionary, pad_idx=self.source_dictionary.pad(), mask_idx=self.mask_idx, seed=self.args.seed, mask_prob=self.args.mask_prob, leave_unmasked_prob=self.args.leave_unmasked_prob, random_token_prob=self.args.random_token_prob, freq_weighted_replacement=self.args.freq_weighted_replacement, mask_whole_words=mask_whole_words, ) with data_utils.numpy_seed(self.args.seed + epoch): shuffle = np.random.permutation(len(src_dataset)) self.datasets[split] = SortDataset( NestedDictionaryDataset( { 'id': IdDataset(), 'net_input': { 'src_tokens': PadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, ), 'src_lengths': NumelDataset(src_dataset, reduce=False), }, 'target': PadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, ), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_dataset, reduce=True), }, sizes=[src_dataset.sizes], ), sort_order=[ shuffle, src_dataset.sizes, ], )
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ def binarize(s: str, append_eos: bool = False): if self.tokenizer is not None: s = self.tokenizer.encode(s) if self.bpe is not None: s = self.bpe.encode(s) tokens = self.vocab.encode_line( s, append_eos=append_eos, add_if_not_exist=False, ).long() if self.args.init_token is not None: tokens = torch.cat( [tokens.new([self.args.init_token]), tokens]) return tokens if data_path is None: data_path = os.path.join(self.args.data, split + '.jsonl') if not os.path.exists(data_path): raise FileNotFoundError('Cannot find data: {}'.format(data_path)) query_tokens = [] query_masks = [] query_lengths = [] candidate_tokens = [] candidate_masks = [] candidate_lengths = [] labels = [] for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator( data_path): prefix = sentence[:pronoun_span.start].text suffix = sentence[pronoun_span.end:].text_with_ws # spaCy spans include trailing spaces, but we need to know about # leading spaces for the GPT-2 BPE leading_space = ' ' if sentence[:pronoun_span.start].text_with_ws.endswith( ' ') else '' trailing_space = ' ' if pronoun_span.text_with_ws.endswith( ' ') else '' # get noun phrases, excluding pronouns and anything overlapping with the query cand_spans = wsc_utils.filter_noun_chunks( wsc_utils.extended_noun_chunks(sentence), exclude_pronouns=True, exclude_query=query, exact_match=False, ) def binarize_with_mask(txt): toks = binarize( prefix + leading_space + txt + trailing_space + suffix, append_eos=True, ) mask = torch.zeros_like(toks, dtype=torch.uint8) mask_start = len(binarize(prefix)) mask_size = len(binarize(leading_space + txt)) mask[mask_start:mask_start + mask_size] = 1 return toks, mask if query is not None: query_toks, query_mask = binarize_with_mask(query) query_len = len(query_toks) else: query_toks, query_mask, query_len = None, None, 0 query_tokens.append(query_toks) query_masks.append(query_mask) query_lengths.append(query_len) cand_toks, cand_masks = [], [] for cand_span in cand_spans: toks, mask = binarize_with_mask(cand_span.text) cand_toks.append(toks) cand_masks.append(mask) # collate candidates cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad()) cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0) assert cand_toks.size() == cand_masks.size() candidate_tokens.append(cand_toks) candidate_masks.append(cand_masks) candidate_lengths.append(cand_toks.size(1)) labels.append(label) query_lengths = np.array(query_lengths) query_tokens = ListDataset(query_tokens, query_lengths) query_masks = ListDataset(query_masks, query_lengths) candidate_lengths = np.array(candidate_lengths) candidate_tokens = ListDataset(candidate_tokens, candidate_lengths) candidate_masks = ListDataset(candidate_masks, candidate_lengths) labels = ListDataset(labels, [1] * len(labels)) dataset = { 'id': IdDataset(), 'query_tokens': query_tokens, 'query_masks': query_masks, 'candidate_tokens': candidate_tokens, 'candidate_masks': candidate_masks, 'labels': labels, 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(query_tokens, reduce=True), } nested_dataset = NestedDictionaryDataset( dataset, sizes=[query_lengths], ) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(query_tokens)) dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) if return_only: return dataset self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) if dataset is None: raise FileNotFoundError("Dataset not found: {} ({})".format( split, split_path)) dataset = maybe_shorten_dataset( dataset, split, self.args.shorten_data_split_list, self.args.shorten_method, self.args.tokens_per_sample, self.args.seed, ) # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode=self.args.sample_break_mode, ) logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) # create masked input and targets mask_whole_words = (get_whole_word_mask(self.args, self.source_dictionary) if self.args.mask_whole_words else None) src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( dataset, self.source_dictionary, pad_idx=self.source_dictionary.pad(), mask_idx=self.mask_idx, seed=self.args.seed, mask_prob=self.args.mask_prob, leave_unmasked_prob=self.args.leave_unmasked_prob, random_token_prob=self.args.random_token_prob, freq_weighted_replacement=self.args.freq_weighted_replacement, mask_whole_words=mask_whole_words, mask_multiple_length=self.args.mask_multiple_length, mask_stdev=self.args.mask_stdev, ) with data_utils.numpy_seed(self.args.seed + epoch): shuffle = np.random.permutation(len(src_dataset)) self.datasets[split] = SortDataset( NestedDictionaryDataset( { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), ), "src_lengths": NumelDataset(src_dataset, reduce=False), }, "target": RightPadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), ), "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_dataset, reduce=True), }, sizes=[src_dataset.sizes], ), sort_order=[ shuffle, src_dataset.sizes, ], )
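A hedged sketch of consuming the masked-LM split loaded above in a training-style loop; the `task` handle and the batching parameters are assumptions, while the nested keys come directly from the dataset definition above.

# Sketch only: assumes `task` is an instance of the task defining load_dataset above.
task.load_dataset("train")
epoch_itr = task.get_batch_iterator(
    dataset=task.dataset("train"),
    max_tokens=4096,  # illustrative token budget per batch
    seed=1,
)
for sample in epoch_itr.next_epoch_itr(shuffle=True):
    src = sample["net_input"]["src_tokens"]  # right-padded, with <mask> at sampled positions
    tgt = sample["target"]                   # pad everywhere except the masked positions
    break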
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" def get_path(type, split): return os.path.join(self.args.data, type, split) def make_dataset(type, dictionary): split_path = get_path(type, split) dataset = data_utils.load_indexed_dataset( split_path, dictionary, self.args.dataset_impl, combine=combine, ) return dataset src_tokens = make_dataset('data', self.source_dictionary) if self.args.init_token is not None: src_tokens = PrependTokenDataset(src_tokens, self.args.init_token) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens)) if self.args.truncate_sequence: src_tokens = TruncateDataset(src_tokens, self.args.max_positions) dataset = { 'id': IdDataset(), 'net_input': { 'src_tokens': RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), 'src_lengths': NumelDataset(src_tokens, reduce=False), }, 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_tokens, reduce=True), } label_dataset = make_dataset('label', self.label_dictionary) if label_dataset is not None: dataset.update(target=OffsetTokensDataset( StripTokenDataset( label_dataset, id_to_strip=self.label_dictionary.eos(), ), offset=-self.label_dictionary.nspecial, )) nested_dataset = NestedDictionaryDataset( dataset, sizes=[src_tokens.sizes], ) if self.args.no_shuffle: dataset = nested_dataset else: dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) logger.info("Loaded {0} with #samples: {1}".format( split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" def get_path(type, split): return os.path.join(self.args.data, type, split) def make_dataset(type, dictionary): split_path = get_path(type, split) dataset = data_utils.load_indexed_dataset( split_path, dictionary, self.args.dataset_impl, combine=combine, ) return dataset input0 = make_dataset('input0', self.source_dictionary) assert input0 is not None, 'could not find dataset: {}'.format(get_path(type, split)) input1 = make_dataset('input1', self.source_dictionary) if self.args.init_token is not None: input0 = PrependTokenDataset(input0, self.args.init_token) if input1 is None: src_tokens = input0 else: if self.args.separator_token is not None: input1 = PrependTokenDataset(input1, self.args.separator_token) src_tokens = ConcatSentencesDataset(input0, input1) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens)) src_tokens = maybe_shorten_dataset( src_tokens, split, self.args.shorten_data_split_whitelist, self.args.shorten_method, self.args.max_positions, self.args.seed, ) dataset = { 'id': IdDataset(), 'net_input': { 'src_tokens': RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), 'src_lengths': NumelDataset(src_tokens, reduce=False), }, 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_tokens, reduce=True), } if self.args.add_prev_output_tokens: prev_tokens_dataset = RightPadDataset( RollDataset(src_tokens, 1), pad_idx=self.dictionary.pad(), ) dataset['net_input'].update( prev_output_tokens=prev_tokens_dataset, ) if not self.args.regression_target: label_dataset = make_dataset('label', self.label_dictionary) if label_dataset is not None: dataset.update( target=OffsetTokensDataset( StripTokenDataset( label_dataset, id_to_strip=self.label_dictionary.eos(), ), offset=-self.label_dictionary.nspecial, ) ) else: label_path = "{0}.label".format(get_path('label', split)) if os.path.exists(label_path): def parse_regression_target(i, line): values = line.split() assert len(values) == self.args.num_classes, \ f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"' return [float(x) for x in values] dataset.update( target=RawLabelDataset([ parse_regression_target(i, line.strip()) for i, line in enumerate(open(label_path).readlines()) ]) ) nested_dataset = NestedDictionaryDataset( dataset, sizes=[src_tokens.sizes], ) if self.args.no_shuffle: dataset = nested_dataset else: dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" def get_path(key, split): return os.path.join(self.cfg.data, key, split) def make_dataset(key, dictionary): split_path = get_path(key, split) try: dataset = data_utils.load_indexed_dataset( split_path, dictionary, combine=combine, ) except Exception as e: if "StorageException: [404] Path not found" in str(e): logger.warning(f"dataset {e} not found") dataset = None else: raise e return dataset input0 = make_dataset("input0", self.source_dictionary) assert input0 is not None, "could not find dataset: {}".format( get_path("input0", split)) input1 = make_dataset("input1", self.source_dictionary) if self.cfg.init_token is not None: input0 = PrependTokenDataset(input0, self.cfg.init_token) if input1 is None: src_tokens = input0 else: if self.cfg.separator_token is not None: input1 = PrependTokenDataset(input1, self.cfg.separator_token) src_tokens = ConcatSentencesDataset(input0, input1) with data_utils.numpy_seed(self.cfg.seed): shuffle = np.random.permutation(len(src_tokens)) src_tokens = maybe_shorten_dataset( src_tokens, split, self.cfg.shorten_data_split_list, self.cfg.shorten_method, self.max_positions(), self.cfg.seed, ) dataset = { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), "src_lengths": NumelDataset(src_tokens, reduce=False), }, "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_tokens, reduce=True), } if self.cfg.add_prev_output_tokens: prev_tokens_dataset = RightPadDataset( RollDataset(src_tokens, 1), pad_idx=self.dictionary.pad(), ) dataset["net_input"].update( prev_output_tokens=prev_tokens_dataset, ) if not self.cfg.regression_target: label_dataset = make_dataset("label", self.label_dictionary) if label_dataset is not None: dataset.update(target=OffsetTokensDataset( StripTokenDataset( label_dataset, id_to_strip=self.label_dictionary.eos(), ), offset=-self.label_dictionary.nspecial, )) else: label_path = "{0}.label".format(get_path("label", split)) if os.path.exists(label_path): def parse_regression_target(i, line): values = line.split() assert ( len(values) == self.cfg.num_classes ), f'expected num_classes={self.cfg.num_classes} regression target values on line {i}, found: "{line}"' return [float(x) for x in values] with open(label_path) as h: dataset.update(target=RawLabelDataset([ parse_regression_target(i, line.strip()) for i, line in enumerate(h.readlines()) ])) nested_dataset = NestedDictionaryDataset( dataset, sizes=[src_tokens.sizes], ) if self.cfg.no_shuffle: dataset = nested_dataset else: dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) logger.info("Loaded {0} with #samples: {1}".format( split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
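The comment block below sketches the on-disk layout implied by `get_path`/`make_dataset` above; the `.bin`/`.idx` extensions are an assumption (they depend on the indexed-dataset implementation), while the directory names and the `.label` file come straight from the code.

# Sketch only: layout inferred from get_path() above; extensions are assumptions.
#   <cfg.data>/input0/train.bin + train.idx   first input segment
#   <cfg.data>/input1/train.bin + train.idx   optional second segment
#   <cfg.data>/label/train.bin  + train.idx   classification targets
#   <cfg.data>/label/train.label              regression targets (plain text, one line per example)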
def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" if self.cfg.data.endswith("1"): data_shard = (epoch - 1) % self.cfg.num_data_splits + 1 data_path = self.cfg.data[:-1] + str(data_shard) else: data_path = self.cfg.data def get_path(type, data_split): return os.path.join(data_path, str(type), data_split) def make_dataset(type, dictionary, data_split, combine): split_path = get_path(type, data_split) dataset = data_utils.load_indexed_dataset( split_path, dictionary, combine=combine, ) return dataset def load_split(data_split, metric): input_src = None if self.cfg.include_src: input_src = make_dataset("input_src", self.dictionary, data_split, combine=False) assert input_src is not None, "could not find dataset: {}".format( get_path("input_src", data_split)) input_tgt = make_dataset("input_tgt", self.dictionary, data_split, combine=False) assert input_tgt is not None, "could not find dataset: {}".format( get_path("input_tgt", data_split)) label_path = f"{get_path(metric, data_split)}.{metric}" assert os.path.exists( label_path), f"could not find dataset: {label_path}" np_labels = np.loadtxt(label_path) if self.cfg.target_metric == "ter": np_labels = -np_labels label = RawLabelDataset(np_labels) return input_src, input_tgt, label src_datasets = [] tgt_datasets = [] label_datasets = [] if split == self.cfg.train_subset: for k in itertools.count(): split_k = "train" + (str(k) if k > 0 else "") prefix = os.path.join(data_path, "input_tgt", split_k) if not indexed_dataset.dataset_exists(prefix, impl=None): if k > 0: break else: raise FileNotFoundError(f"Dataset not found: {prefix}") input_src, input_tgt, label = load_split( split_k, self.cfg.target_metric) src_datasets.append(input_src) tgt_datasets.append(input_tgt) label_datasets.append(label) else: input_src, input_tgt, label = load_split(split, self.cfg.target_metric) src_datasets.append(input_src) tgt_datasets.append(input_tgt) label_datasets.append(label) if len(tgt_datasets) == 1: input_tgt, label = tgt_datasets[0], label_datasets[0] if self.cfg.include_src: input_src = src_datasets[0] else: input_tgt = ConcatDataset(tgt_datasets) label = ConcatDataset(label_datasets) if self.cfg.include_src: input_src = ConcatDataset(src_datasets) input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions) if self.cfg.include_src: input_src = PrependTokenDataset(input_src, self.dictionary.bos()) input_src = TruncateDataset(input_src, self.cfg.max_positions) src_lengths = NumelDataset(input_src, reduce=False) src_tokens = ConcatSentencesDataset(input_src, input_tgt) else: src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos()) src_lengths = NumelDataset(src_tokens, reduce=False) dataset = { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), "src_lengths": src_lengths, }, "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_tokens, reduce=True), "target": label, } dataset = NestedDictionaryDataset( dataset, sizes=[src_tokens.sizes], ) assert len(dataset) % self.cfg.mt_beam == 0, ( "dataset size (%d) is not a multiple of beam size (%d)" % (len(dataset), self.cfg.mt_beam)) # no need to shuffle valid/test sets if not self.cfg.no_shuffle and split == self.cfg.train_subset: # need to keep all hypotheses together start_idx = np.arange(0, len(dataset), self.cfg.mt_beam) with data_utils.numpy_seed(self.cfg.seed + epoch): np.random.shuffle(start_idx) idx = np.arange(0, self.cfg.mt_beam) shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile( start_idx, (self.cfg.mt_beam, 1)).transpose().reshape(-1) dataset = SortDataset( dataset, sort_order=[shuffle], ) logger.info(f"Loaded {split} with #samples: {len(dataset)}") self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ def getIns(bped,bpeTokens,tokens,L,R): resL=0 tkL=" ".join(tokens[:L]) bped_tkL=self.bpe.encode(tkL) if bped.find(bped_tkL)==0: resL=len(bped_tkL.split()) else: tkL+=" " bped_tkL=self.bpe.encode(tkL) if bped.find(bped_tkL)==0: resL=len(bped_tkL.split()) resR=0 tkR=" ".join(tokens[R:]) bped_tkR=self.bpe.encode(tkR) if bped.rfind(bped_tkR)+len(bped_tkR)==len(bped): resR=len(bpeTokens)-len(bped_tkR.split()) else: tkR=" "+tkR bped_tkR=self.bpe.encode(tkR) if bped.rfind(bped_tkR)+len(bped_tkR)==len(bped): resR=len(bpeTokens)-len(bped_tkR.split()) return resL, resR def getExample(a,bias): s=" ".join(a["token"]) ss=self.bpe.encode(s) sst=ss.split() headL=a['h']['pos'][0] headR=a['h']['pos'][1] hiL, hiR=getIns(ss,sst,a["token"],headL,headR) tailL=a['t']['pos'][0] tailR=a['t']['pos'][1] tiL, tiR=getIns(ss,sst,a["token"],tailL,tailR) E1b='1' E1e='2' E2b='3' E2e='4' ins=[(hiL, E1b), (hiR, E1e), (tiL, E2b), (tiR, E2e)] ins=sorted(ins) pE1=0 pE2=0 pE1_=0 pE2_=0 for i in range(0,4): sst.insert(ins[i][0]+i,ins[i][1]) if ins[i][1]==E1b: pE1=ins[i][0]+i elif ins[i][1]==E2b: pE2=ins[i][0]+i elif ins[i][1]==E1e: pE1_=ins[i][0]+i else: pE2_=ins[i][0]+i if pE1_-pE1==1 or pE2_-pE2==1: return "???", -1, -1 else: return " ".join(sst), pE1+bias, pE2+bias def get_example_bert(item): if 'text' in item: sentence = item['text'] is_token = False else: sentence = item['token'] is_token = True pos_head = item['h']['pos'] pos_tail = item['t']['pos'] pos_min = pos_head pos_max = pos_tail if pos_head[0] > pos_tail[0]: pos_min = pos_tail pos_max = pos_head rev = True else: rev = False if not is_token: sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) else: sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]])) ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]])) sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]])) ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]])) sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:])) ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]'] ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]'] re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]'] pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1) pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0) #pos1 = min(self.max_length - 1, pos1) #pos2 = min(self.max_length - 1, pos2) indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens) avai_len = len(indexed_tokens) # Position #pos1 = torch.tensor([[pos1]]).long() #pos2 = torch.tensor([[pos2]]).long() #indexed_tokens = indexed_tokens[:self.max_length] indexed_tokens = torch.tensor(indexed_tokens).long() return indexed_tokens, pos1, pos2 def binarize(s, append_bos=False): #if self.bpe is not None: # s = self.bpe.encode(s) tokens = self.vocab.encode_line( s, append_eos=True, add_if_not_exist=False, ).long() if append_bos and self.args.init_token is not None: tokens = torch.cat([tokens.new([self.args.init_token]), tokens]) return tokens if data_path is None: data_path = os.path.join(self.args.data, split + '.jsonl') rel2id_path=os.path.join(self.args.data, "rel2id.json") if not os.path.exists(data_path): raise FileNotFoundError('Cannot find data: {}'.format(data_path)) if not os.path.exists(rel2id_path): raise FileNotFoundError('Cannot find rel2id: {}'.format(rel2id_path)) rel2id=json.load(open(rel2id_path,"r")) labels = [] src_tokens = [] src_lengths = [] src_idx = [] with open(data_path) as h: for line in h: example = json.loads(line.strip()) if 'relation' in example: label = rel2id[example['relation']] labels.append(label) #bped=self.bpe.encode(" ".join(example["token"])) if getattr(self.args, 'bert', False): src_bin, pE1, pE2 = get_example_bert(example) else: bped, pE1, pE2 = getExample(example,1) if pE1==-1: continue src_bin = binarize(bped, append_bos=True) src_tokens.append(src_bin) src_lengths.append(len(src_bin)) #pE1=0 #pE2=0 src_idx.append([[pE1 for i in range(0,self.args.encoder_embed_dim)], [pE2 for i in range(0,self.args.encoder_embed_dim)]]) src_lengths = np.array(src_lengths) src_tokens = ListDataset(src_tokens, src_lengths) src_lengths = ListDataset(src_lengths) print("src_len", len(src_lengths)) print("src_tokens", len(src_tokens)) dataset = { 'id': IdDataset(), 'net_input':{ 'src_tokens':RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad() ), 'src_lengths': src_lengths, }, 'index': RawLabelDataset(src_idx), 'target': RawLabelDataset(labels), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_tokens, reduce=True), } dataset = NestedDictionaryDataset( dataset, sizes=src_tokens.sizes, ) with data_utils.numpy_seed(self.args.seed+epoch): dataset = SortDataset( dataset, # shuffle sort_order=[np.random.permutation(len(dataset))], ) print('| Loaded {} with {} samples'.format(split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ def get_path(type, split): return os.path.join(self.args.data, type, split) def make_dataset(type, dictionary): split_path = get_path(type, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) return dataset input0 = make_dataset('input0', self.source_dictionary) assert input0 is not None, 'could not find dataset: {}'.format( get_path(type, split)) input1 = make_dataset('input1', self.source_dictionary) input0 = PrependTokenDataset(input0, self.source_dictionary.bos()) if input1 is None: src_tokens = input0 else: input1 = PrependTokenDataset(input1, self.source_dictionary.eos()) src_tokens = ConcatSentencesDataset(input0, input1) src_tokens = TruncateDataset(src_tokens, self.args.max_positions) assert not self.args.mask_whole_words src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( src_tokens, self.source_dictionary, pad_idx=self.source_dictionary.pad(), mask_idx=self.mask_idx, seed=self.args.seed, mask_prob=self.args.mask_prob, leave_unmasked_prob=self.args.leave_unmasked_prob, random_token_prob=self.args.random_token_prob, freq_weighted_replacement=self.args.freq_weighted_replacement, mask_whole_words=None, ) with data_utils.numpy_seed(self.args.seed + epoch): shuffle = np.random.permutation(len(src_dataset)) self.datasets[split] = SortDataset( NestedDictionaryDataset( { 'id': IdDataset(), 'net_input': { 'src_tokens': PadToLenDataset(src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, pad_len=self.args.max_positions), 'src_lengths': NumelDataset(src_dataset, reduce=False), }, 'target': PadToLenDataset(tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, pad_len=self.args.max_positions), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_dataset, reduce=True), }, sizes=[src_dataset.sizes], ), sort_order=[ shuffle, src_dataset.sizes, ], )
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ if data_path is None: data_path = os.path.join(self.args.data, split + ".jsonl") if not os.path.exists(data_path): raise FileNotFoundError("Cannot find data: {}".format(data_path)) query_tokens = [] query_masks = [] query_lengths = [] candidate_tokens = [] candidate_masks = [] candidate_lengths = [] itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test")) for sample in itr: sentence, pronoun_span, query, cand_text = sample prefix = sentence[:pronoun_span[0]].rstrip() suffix = sentence[pronoun_span[1]:] leading_space = " " if sentence[:pronoun_span[0]].endswith( " ") else "" trailing_space = "" if query is not None: query_toks, query_mask = self.binarize_with_mask( query, prefix, suffix, leading_space, trailing_space, ) query_len = len(query_toks) else: query_toks, query_mask, query_len = None, None, 0 query_tokens.append(query_toks) query_masks.append(query_mask) query_lengths.append(query_len) cand_toks, cand_mask = self.binarize_with_mask( cand_text, prefix, suffix, leading_space, trailing_space, ) candidate_tokens.append(cand_toks) candidate_masks.append(cand_mask) candidate_lengths.append(cand_toks.size(0)) query_lengths = np.array(query_lengths) def get_pad_dataset_fn(tokens, length, pad_idx): return PadDataset( ListDataset(tokens, length), pad_idx=pad_idx, left_pad=False, ) query_tokens = get_pad_dataset_fn(query_tokens, query_lengths, self.vocab.pad()) query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0) candidate_lengths = np.array(candidate_lengths) candidate_tokens = get_pad_dataset_fn(candidate_tokens, candidate_lengths, self.vocab.pad()) candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0) dataset = { "id": IdDataset(), "query_tokens": query_tokens, "query_masks": query_masks, "candidate_tokens": candidate_tokens, "candidate_masks": candidate_masks, "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(query_tokens, reduce=True), } nested_dataset = NestedDictionaryDataset( dataset, sizes=[query_lengths], ) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(query_tokens)) dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) if return_only: return dataset self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ if data_path is None: data_path = os.path.join(self.args.data, split + ".jsonl") if not os.path.exists(data_path): raise FileNotFoundError("Cannot find data: {}".format(data_path)) query_tokens = [] query_masks = [] query_lengths = [] candidate_tokens = [] candidate_masks = [] candidate_lengths = [] labels = [] for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator( data_path): prefix = sentence[:pronoun_span.start].text suffix = sentence[pronoun_span.end:].text_with_ws # spaCy spans include trailing spaces, but we need to know about # leading spaces for the GPT-2 BPE leading_space = ( " " if sentence[:pronoun_span.start].text_with_ws.endswith(" ") else "") trailing_space = " " if pronoun_span.text_with_ws.endswith( " ") else "" # get noun phrases, excluding pronouns and anything overlapping with the query cand_spans = wsc_utils.filter_noun_chunks( wsc_utils.extended_noun_chunks(sentence), exclude_pronouns=True, exclude_query=query, exact_match=False, ) if query is not None: query_toks, query_mask = self.binarize_with_mask( query, prefix, suffix, leading_space, trailing_space) query_len = len(query_toks) else: query_toks, query_mask, query_len = None, None, 0 query_tokens.append(query_toks) query_masks.append(query_mask) query_lengths.append(query_len) cand_toks, cand_masks = [], [] for cand_span in cand_spans: toks, mask = self.binarize_with_mask( cand_span.text, prefix, suffix, leading_space, trailing_space, ) cand_toks.append(toks) cand_masks.append(mask) # collate candidates cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad()) cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0) assert cand_toks.size() == cand_masks.size() candidate_tokens.append(cand_toks) candidate_masks.append(cand_masks) candidate_lengths.append(cand_toks.size(1)) labels.append(label) query_lengths = np.array(query_lengths) query_tokens = ListDataset(query_tokens, query_lengths) query_masks = ListDataset(query_masks, query_lengths) candidate_lengths = np.array(candidate_lengths) candidate_tokens = ListDataset(candidate_tokens, candidate_lengths) candidate_masks = ListDataset(candidate_masks, candidate_lengths) labels = ListDataset(labels, [1] * len(labels)) dataset = { "id": IdDataset(), "query_tokens": query_tokens, "query_masks": query_masks, "candidate_tokens": candidate_tokens, "candidate_masks": candidate_masks, "labels": labels, "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(query_tokens, reduce=True), } nested_dataset = NestedDictionaryDataset( dataset, sizes=[query_lengths], ) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(query_tokens)) dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) if return_only: return dataset self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split: str, combine: bool = False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" inputs_path = Path(self.args.data) / "{split}".format(split=split) src_tokens = data_utils.load_indexed_dataset( str(inputs_path), self.source_dictionary, self.args.dataset_impl, combine=combine, ) assert src_tokens is not None, "could not find dataset: {}".format( inputs_path) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens)) src_tokens = PrependTokenDataset(src_tokens, self.source_dictionary.bos()) targets_path = Path(self.args.data) / "{}.term".format(split) term_labels = data_utils.load_indexed_dataset( str(targets_path), self.label_dictionary, self.args.dataset_impl, combine=combine, ) assert term_labels is not None, "could not find labels: {}".format( targets_path) term_cats, term_attrs = POSDataset.make_both(term_labels, self.dictionary, self.label_dictionary) def print_terms(term_cats, term_attrs): # Debug function cat_labels = [self.label_dictionary[t] for t in term_cats] attr_data = [t.nonzero().T for t in term_attrs if t.numel()] attr_labels = [] for word_attr in attr_data: if not word_attr.numel(): attr_labels.append([]) continue attr_labels.append([ self.label_dictionary[t + self.label_dictionary.nspecial] for t in word_attr[0] ]) return cat_labels, attr_labels word_mask = WordEndMaskDataset(src_tokens, self.dictionary, self.is_word_initial, bos_value=0, eos_value=0) exclude_cats_mask = IgnoreLabelsDataset(term_cats, self.ignore_cats) dataset = { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset(src_tokens, pad_idx=self.source_dictionary.pad()), "nsrc_tokens": NumelDataset(src_tokens), "word_mask": RightPadDataset(word_mask, pad_idx=0), }, "exclude_cats_mask": RightPadDataset(exclude_cats_mask, pad_idx=1), "target_cats": RightPadDataset(term_cats, pad_idx=self.label_dictionary.pad()), "target_attrs": RightPad2dDataset(term_attrs, pad_idx=0), "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_tokens, reduce=True), "nwords": NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial), } nested_dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes]) if self.args.no_shuffle: dataset = nested_dataset else: dataset = SortDataset(nested_dataset, sort_order=[shuffle]) logger.info("Loaded {0} with #samples: {1}".format( split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_KE_dataset(self, split, kedata_path, epoch=0, combine=False): paths = kedata_path.split(':') assert len(paths) > 0 data_path = paths[epoch % len(paths)] def get_path(type): return os.path.join(data_path,type,split) def desc_dataset(type, dictionary, relation_desc=None): now_path=get_path(type) #print(now_path) dataset=data_utils.load_indexed_dataset( now_path, dictionary, self.args.dataset_impl, combine=combine, ) if self.args.init_token is not None: dataset = PrependTokenDataset(dataset, self.args.init_token) if relation_desc is not None: dataset = ConcatSentencesDataset(dataset, relation_desc) dataset = TruncateDataset(dataset, self.args.tokens_per_sample) #??? dataset = RightPadDataset(dataset, pad_idx=self.source_dictionary.pad()) return dataset assert(not (self.args.relation_desc and self.args.relemb_from_desc)) if self.args.relation_desc or self.args.relemb_from_desc: now_path=get_path('relation_desc') relation_desc=data_utils.load_indexed_dataset( now_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) if self.args.relation_desc: if self.args.separator_token is not None: relation_desc = PrependTokenDataset(relation_desc, self.args.separator_token) else: raise Exception("separator_token is None") elif self.args.relemb_from_desc: relation_desc = PrependTokenDataset(relation_desc, self.args.init_token) relation_desc = TruncateDataset(relation_desc, self.args.tokens_per_sample // 8) # 64 relation_desc = RightPadDataset(relation_desc, pad_idx=self.source_dictionary.pad()) else: relation_desc = None head=desc_dataset("head",self.source_dictionary) tail=desc_dataset("tail",self.source_dictionary) nHead=desc_dataset("negHead",self.source_dictionary) nTail=desc_dataset("negTail",self.source_dictionary) head_r=desc_dataset("head",self.source_dictionary, relation_desc if self.args.relation_desc else None) tail_r=desc_dataset("tail",self.source_dictionary, relation_desc if self.args.relation_desc else None) assert len(nHead)%len(head)==0, "check the KE positive and negative instances' number" self.negative_sample_size=len(nHead)/len(head) relation=np.load(get_path("relation")+".npy") sizes=np.load(get_path("sizes")+".npy") with data_utils.numpy_seed(self.args.seed + epoch): shuffle=np.random.permutation(len(head)) net_input = { 'heads': head, 'tails': tail, 'nHeads': KeNegDataset(nHead,self.args), 'nTails': KeNegDataset(nTail,self.args), 'heads_r': head_r, 'tails_r': tail_r, 'src_lengths': FakeNumelDataset(sizes, reduce=False), } if self.args.relemb_from_desc: net_input['relation_desc'] = relation_desc dataset=SortDataset( NestedDictionaryDataset( { 'id':IdDataset(), 'net_input': net_input, 'target': RawLabelDataset(relation), 'nsentences':NumSamplesDataset(), 'ntokens': FakeNumelDataset(sizes, reduce=True), }, sizes=[sizes], ), sort_order=[shuffle], ) return dataset
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ print("Split type --> " + str(split)) def binarize(s, append_bos=False): if self.bpe is not None: s = self.bpe.encode(s) tokens = self.vocab.encode_line( s, append_eos=True, add_if_not_exist=False, ).long() if append_bos and self.args.init_token is not None: tokens = torch.cat([tokens.new([self.args.init_token]), tokens]) return tokens if data_path is None: data_path = os.path.join(self.args.data, split + '.jsonl') if not os.path.exists(data_path): raise FileNotFoundError('Cannot find data: {}'.format(data_path)) src_tokens = [[] for i in range(self.args.num_classes)] src_lengths = [[] for i in range(self.args.num_classes)] labels = [] with open(data_path) as h: for line in h: example = json.loads(line.strip()) if 'answerKey' in example: label = ord(example['answerKey']) - ord('A') labels.append(label) question = example['question']['stem'] if(self.args.num_classes != len(example['question']['choices'])): print("Class size = " + str(self.args.num_classes) + ". Length of sample size = " + str(len(example['question']['choices']))) assert len(example['question']['choices']) == self.args.num_classes # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>` question = 'Q: ' + question question_toks = binarize(question, append_bos=True) for i, choice in enumerate(example['question']['choices']): src = 'A: ' + choice['text'] src_bin = torch.cat([question_toks, binarize(src)]) src_tokens[i].append(src_bin) src_lengths[i].append(len(src_bin)) assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes)) assert len(src_tokens[0]) == len(src_lengths[0]) assert len(labels) == 0 or len(labels) == len(src_tokens[0]) for i in range(self.args.num_classes): src_lengths[i] = np.array(src_lengths[i]) src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i]) src_lengths[i] = ListDataset(src_lengths[i]) dataset = { 'id': IdDataset(), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_tokens[0], reduce=True), } for i in range(self.args.num_classes): dataset.update({ 'net_input{}'.format(i + 1): { 'src_tokens': RightPadDataset( src_tokens[i], pad_idx=self.source_dictionary.pad(), ), 'src_lengths': src_lengths[i], } }) if len(labels) > 0: dataset.update({'target': RawLabelDataset(labels)}) dataset = NestedDictionaryDataset( dataset, sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])], ) with data_utils.numpy_seed(self.args.seed): dataset = SortDataset( dataset, # shuffle sort_order=[np.random.permutation(len(dataset))], ) print('| Loaded {} with {} samples'.format(split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
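For reference, a hedged illustration of one JSONL record the loader above expects; the field names are taken directly from the accesses in the code, while the choice texts (beyond the question echoed in the format comment) are invented.

# Sketch only: example record; every key below is read by the loader above.
example = {
    "answerKey": "A",
    "question": {
        "stem": "Where would I not want a fox?",
        "choices": [
            {"text": "hen house"},
            {"text": "england"},
            # ... exactly --num-classes entries, each with a "text" field
        ],
    },
}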
def load_dataset(self, split: str, combine: bool = False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" inputs_path = Path(self.args.data) / "{split}".format(split=split) src_tokens = data_utils.load_indexed_dataset(str(inputs_path), self.source_dictionary, self.args.dataset_impl, combine=combine) assert src_tokens is not None, "could not find dataset: {}".format( inputs_path) src_tokens = PrependTokenDataset(src_tokens, self.source_dictionary.bos()) targets_path = Path(self.args.data) / "{}.term".format(split) labels = data_utils.load_indexed_dataset(str(targets_path), self._label_dictionary, self.args.dataset_impl, combine=combine) assert labels is not None, "could not find labels: {}".format( targets_path) clean_labels = NoBosEosDataset(labels, self.label_dictionary) word_mask = WordEndMaskDataset(src_tokens, self.dictionary, self.is_word_initial, bos_value=0, eos_value=0) dataset = { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset(src_tokens, pad_idx=self.source_dictionary.pad()), "nsrc_tokens": NumelDataset(src_tokens), "word_mask": RightPadDataset(word_mask, pad_idx=0), # pad is zero since mask }, "target_attrs": RightPadDataset(clean_labels, pad_idx=self.label_dictionary.pad()), "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_tokens, reduce=True), "nwords": NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial), } nested_dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes]) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens)) dataset = SortDataset(nested_dataset, sort_order=[shuffle]) logger.info("Loaded {0} with #samples: {1}".format( split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split: str, combine: bool = False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" inputs_path = Path(self.args.data) / f"{split}.text" src_tokens = data_utils.load_indexed_dataset( str(inputs_path), self.source_dictionary, self.args.dataset_impl, combine=combine, ) assert src_tokens is not None, "could not find dataset: {}".format( inputs_path) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens)) src_tokens = PrependTokenDataset(src_tokens, self.source_dictionary.bos()) word_masks_w_bos = WordEndMaskDataset(src_tokens, self.dictionary, self.is_word_initial, bos_value=1, eos_value=0) nterm_targets_path = Path(self.args.data) / "{}.nonterm".format(split) labelled_spans = data_utils.load_indexed_dataset( str(nterm_targets_path), self.nterm_dictionary, self.args.dataset_impl, combine=combine, ) assert labelled_spans is not None, "could not find nonterminal labels: {}".format( nterm_targets_path) target_spans, nterm_cats = DynamicLabelledSpanDataset.make_both( labelled_spans, self.nterm_dictionary, seed=self.args.seed, ) dataset = { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset(src_tokens, pad_idx=self.source_dictionary.pad()), "nsrc_tokens": NumelDataset(src_tokens), "word_mask_w_bos": RightPadDataset(word_masks_w_bos, pad_idx=0), }, "target_span_labels": RightPadDataset(nterm_cats, pad_idx=self.nterm_dictionary.pad()), "target_spans": RightPadDataset(target_spans, pad_idx=0), "ntarget_span_labels": NumelDataset(nterm_cats), "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_tokens, reduce=True), "nwords": NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial), } nested_dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes]) dataset = SortDataset(nested_dataset, sort_order=[shuffle]) logger.info("Loaded {0} with #samples: {1}".format( split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    print('Loading dataset')

    data_path = self.args.data
    dataset_inst = data_utils.load_indexed_dataset(
        os.path.join(data_path, 'insts', split),
        self.instruction_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    dataset_state = data_utils.load_indexed_dataset(
        os.path.join(data_path, 'states', split),
        self.state_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )

    if dataset_inst is None or dataset_state is None:
        raise FileNotFoundError('Dataset not found: {}'.format(split))

    dataset_inst = SeqOfSeqDataset(dataset_inst, self.instruction_dictionary)
    dataset_state = SeqOfSeqDataset(dataset_state, self.state_dictionary)
    dataset_pos = IRPositionDataset(os.path.join(data_path, 'pos', split))
    dataset = IRDataset(dataset_inst, dataset_state, dataset_pos)

    block_size = self.args.function_length

    dataset = IRPadDataset(
        dataset,
        inst_pad_idx=self.instruction_dictionary.pad(),
        state_pad_idx=self.state_dictionary.pad(),
        inst_mask_idx=self.inst_mask_idx,
        state_mask_idx=self.state_mask_idx,
        inst_cls_idx=self.instruction_dictionary.index('<t>'),
        state_cls_idx=self.state_dictionary.index('<t>'),
        smallbert_insts_per_input=self.args.smallbert_insts_per_group,
        smallbert_states_per_input=self.args.smallbert_insts_per_group,
        max_length=block_size,
        inst_pad_length=32,
        state_pad_length=16,
        pair=True,
    )

    with open(os.path.join(data_path, 'label', split + '.txt')) as f:
        labels_str = list(map(json.loads, f))
    labels = torch.tensor([x - 1 if isinstance(x, int) else int(x.strip()) - 1 for x in labels_str])

    print('| loaded {} examples from: {} and {}'.format(
        len(dataset), os.path.join(data_path, 'insts', split), os.path.join(data_path, 'states', split)))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.labels[split] = SortDataset(RawLabelDataset(labels), sort_order=[shuffle])

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src': dataset,
                },
                'label': RawLabelDataset(labels),
            },
            sizes=[dataset.sizes],
        ),
        sort_order=[
            shuffle,
            # dataset.sizes,
        ],
    )
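# Hedged sketch of the label file format the loader above consumes: one JSON value
# per line under 'label/<split>.txt', either an int or a numeric string, holding
# 1-based class ids that the loader shifts to 0-based. The round-trip below uses
# an in-memory file for illustration; the concrete values are made up.
import io
import json
import torch

label_file = io.StringIO('3\n"1"\n2\n')
labels_str = list(map(json.loads, label_file))
labels = torch.tensor(
    [x - 1 if isinstance(x, int) else int(x.strip()) - 1 for x in labels_str]
)
assert labels.tolist() == [2, 0, 1]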