def get_batch_iterator(
    self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
    ignore_invalid_inputs=False, required_batch_size_multiple=1,
    seed=1, num_shards=1, shard_id=0,
):
    """
    Get an iterator that yields batches of data from the given dataset.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to batch
        max_tokens (int, optional): max number of tokens in each batch.
            Default: ``None``
        max_sentences (int, optional): max number of sentences in each
            batch. Default: ``None``
        max_positions (optional): max sentence length supported by the
            model. Default: ``None``
        ignore_invalid_inputs (bool, optional): don't raise Exception for
            sentences that are too long. Default: ``False``
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N. Default: ``1``
        seed (int, optional): seed for random number generator for
            reproducibility. Default: ``1``
        num_shards (int, optional): shard the data iterator into N
            shards. Default: ``1``
        shard_id (int, optional): which shard of the data iterator to
            return. Default: ``0``

    Returns:
        ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
    """
    assert isinstance(dataset, FairseqDataset)

    # get indices ordered by example size
    with data_utils.numpy_seed(seed):
        indices = dataset.ordered_indices()

    # filter examples that are too large
    indices = data_utils.filter_by_size(
        indices, dataset.size, max_positions,
        raise_exception=(not ignore_invalid_inputs),
    )

    # create mini-batches with given size constraints
    batch_sampler = data_utils.batch_by_size(
        indices, dataset.num_tokens,
        max_tokens=max_tokens,
        max_sentences=max_sentences,
        required_batch_size_multiple=required_batch_size_multiple,
    )

    # return a reusable, sharded iterator
    return iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=num_shards,
        shard_id=shard_id,
    )
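# A minimal usage sketch for the iterator above (not part of the original
# source): `task`, the 'train' split, and the size limits below are
# hypothetical stand-ins for whatever FairseqTask subclass and data are
# actually in use.
#
#   task.load_dataset('train')
#   epoch_itr = task.get_batch_iterator(
#       dataset=task.dataset('train'),
#       max_tokens=4096,
#       max_positions=1024,
#       seed=1,
#   )
#   for batch in epoch_itr.next_epoch_itr(shuffle=True):
#       ...  # each batch is the output of dataset.collater on one index batch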
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    if input1 is None:
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if self.args.add_prev_output_tokens:
        prev_tokens_dataset = RightPadDataset(
            RollDataset(src_tokens, 1),
            pad_idx=self.dictionary.pad(),
        )
        dataset['net_input'].update(
            prev_output_tokens=prev_tokens_dataset,
        )

    if not self.args.regression_target:
        label_dataset = make_dataset('label', self.target_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.target_dictionary.eos(),
                ),
                offset=-self.target_dictionary.nspecial,
            ))
    else:
        label_path = "{0}.label".format(get_path('label', split))
        if os.path.exists(label_path):
            dataset.update(target=RawLabelDataset(
                [float(x.strip()) for x in open(label_path).readlines()]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample; periodically print out randomly
    sampled predictions if model is in training mode, otherwise aggregate word
    error stats for validation.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    dict = self.scorer.dict
    if model.training:
        if (
            (len(self.args.scheduled_sampling_probs) > 1
             or self.args.scheduled_sampling_probs[0] < 1.0)
            and self.epoch >= self.args.start_scheduled_sampling_epoch
        ):
            # scheduled sampling
            ss_prob = self.args.scheduled_sampling_probs[min(
                self.epoch - self.args.start_scheduled_sampling_epoch,
                len(self.args.scheduled_sampling_probs) - 1)]
            assert isinstance(model.decoder, FairseqIncrementalDecoder)
            incremental_states = {}
            encoder_input = {
                k: v for k, v in sample['net_input'].items()
                if k != 'prev_output_tokens'
            }
            encoder_out = model.encoder(**encoder_input)
            target = sample['target']
            tokens = sample['net_input']['prev_output_tokens']
            lprobs = []
            pred = None
            for step in range(target.size(1)):
                if step > 0:
                    sampling_mask = torch.rand(
                        [target.size(0), 1],
                        device=target.device,
                    ).lt(ss_prob)
                    feed_tokens = torch.where(
                        sampling_mask, tokens[:, step:step + 1], pred,
                    )
                else:
                    feed_tokens = tokens[:, step:step + 1]
                log_probs, _ = self._decode(
                    feed_tokens, model, encoder_out, incremental_states,
                )
                pred = log_probs.argmax(-1, keepdim=True)
                lprobs.append(log_probs)
            lprobs = torch.stack(lprobs, dim=1)
        else:
            # normal training
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
    else:
        assert isinstance(model.decoder, FairseqIncrementalDecoder)
        incremental_states = {}
        encoder_input = {
            k: v for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens'
        }
        encoder_out = model.encoder(**encoder_input)
        target = sample['target']
        # make the maximum decoding length equal to at least the length of
        # target, and the length of encoder_out if possible
        maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1))
        tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx)
        tokens[:, 0] = dict.eos()
        lprobs = []
        attn = [] if getattr(model.decoder, 'need_attn', False) else None
        dummy_log_probs = encoder_out['encoder_out'][0].new_full(
            [target.size(0), len(dict)], -np.log(len(dict)))
        for step in range(maxlen + 1):  # one extra step for EOS marker
            is_eos = tokens[:, step].eq(dict.eos())
            # if all predictions are finished (i.e., ended with eos),
            # pad lprobs to target length with dummy log probs,
            # truncate tokens up to this step and break
            if step > 0 and is_eos.sum() == is_eos.size(0):
                for _ in range(step, target.size(1)):
                    lprobs.append(dummy_log_probs)
                tokens = tokens[:, :step + 1]
                break
            log_probs, attn_scores = self._decode(
                tokens[:, :step + 1], model, encoder_out, incremental_states,
            )
            tokens[:, step + 1] = log_probs.argmax(-1)
            if step > 0:  # deal with finished predictions
                # make log_probs uniform if the previous output token is EOS
                # and add consecutive EOS to the end of prediction
                log_probs[is_eos, :] = -np.log(log_probs.size(1))
                tokens[is_eos, step + 1] = dict.eos()
            if step < target.size(1):
                lprobs.append(log_probs)
            if getattr(model.decoder, 'need_attn', False):
                attn.append(attn_scores)
        # bsz x min(tgtlen, maxlen + 1) x vocab_size
        lprobs = torch.stack(lprobs, dim=1)
        if getattr(model.decoder, 'need_attn', False):
            # bsz x (maxlen + 1) x (length of encoder_out)
            attn = torch.stack(attn, dim=1)

    # word error stats code starts
    if (
        not model.training
        or (self.num_updates // self.args.print_interval >
            (self.num_updates - 1) // self.args.print_interval)
    ):
        pred = lprobs.argmax(-1).cpu() if model.training else \
            tokens[:, 1:].data.cpu()  # bsz x len
        if not model.training:  # validation step, compute WER stats with scorer
            assert pred.size(0) == target.size(0)
            self.scorer.reset()
            for i in range(target.size(0)):
                utt_id = sample['utt_id'][i]
                id = sample['id'].data[i].item()
                # ref_tokens = dict.string(target.data[i])
                # if it is a dummy batch (e.g., a "padding" batch in a sharded
                # dataset), id might exceed the dataset size; in that case we
                # just skip it
                if id < len(self.valid_tgt_dataset):
                    ref_tokens = self.valid_tgt_dataset.get_original_tokens(id)
                    pred_tokens = dict.string(pred.data[i])
                    self.scorer.add_evaluation(
                        utt_id, ref_tokens, pred_tokens,
                        bpe_symbol=self.args.remove_bpe,
                    )
        else:  # print a randomly sampled result every print_interval updates
            assert pred.size() == target.size()
            with data_utils.numpy_seed(self.num_updates):
                i = np.random.randint(0, len(sample['id']))
            id = sample['id'].data[i].item()
            length = utils.strip_pad(target.data[i], self.padding_idx).size(0)
            # ref_one = dict.tokens_to_sentence(dict.string(target.data[i]))
            ref_one = self.train_tgt_dataset.get_original_text(
                id, dict, bpe_symbol=self.args.remove_bpe,
            )
            pred_one = dict.tokens_to_sentence(
                dict.string(pred.data[i][:length]),
                bpe_symbol=self.args.remove_bpe,
            )
            print('| sample REF: ' + ref_one)
            print('| sample PRD: ' + pred_one)
    # word error stats code ends

    lprobs = lprobs.view(-1, lprobs.size(-1))
    loss = F.nll_loss(
        lprobs,
        target.view(-1),
        ignore_index=self.padding_idx,
        reduction='sum' if reduce else 'none',
    )
    sample_size = sample['target'].size(0) \
        if self.args.sentence_avg else sample['ntokens']
    logging_output = {
        'loss': utils.item(loss.data) if reduce else loss.data,
        'nll_loss': utils.item(loss.data) if reduce else loss.data,
        'ntokens': sample['ntokens'],
        'nsentences': sample['target'].size(0),
        'sample_size': sample_size,
    }
    if not model.training:  # do not compute word error in training mode
        logging_output['word_error'] = self.scorer.tot_word_error()
        logging_output['word_count'] = self.scorer.tot_word_count()
        logging_output['char_error'] = self.scorer.tot_char_error()
        logging_output['char_count'] = self.scorer.tot_char_count()
    return loss, sample_size, logging_output
def get_batch_iterator(
    self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
    ignore_invalid_inputs=False, required_batch_size_multiple=1,
    seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
):
    """
    Get an iterator that yields batches of data from the given dataset.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to batch
        max_tokens (int, optional): max number of tokens in each batch
            (default: None). No longer used by this implementation;
            remaining call sites still need to be checked.
        max_sentences (int, optional): max number of sentences in each batch
            (default: None).
        max_positions (optional): max sentence length supported by the model
            (default: None). No longer used by this implementation;
            remaining call sites still need to be checked.
        ignore_invalid_inputs (bool, optional): don't raise Exception for
            sentences that are too long (default: False).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N shards
            (default: 1).
        shard_id (int, optional): which shard of the data iterator to return
            (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 0).

    Returns:
        ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
    """
    assert isinstance(dataset, FairseqDataset)

    # get indices ordered by example size; the ordering (and any filtering
    # by size) is assumed to already be handled by the dataset itself
    with data_utils.numpy_seed(seed):
        indices = dataset.ordered_indices()

    # filter_by_size was removed since we believe it is not needed here
    # (Christine, 18-12-2019)

    # create mini-batches with the given size constraints; this simply
    # groups indices by size via the batch_by_size implemented above, and
    # should be tested later to confirm it suits the framework
    # (Christine, 18-12-2019)
    # TODO: verify that the returned mini-batches are correct
    batch_sampler = batch_by_size(indices, max_sentences=max_sentences)

    # return a reusable, sharded iterator
    return iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=num_shards,
        shard_id=shard_id,
        num_workers=num_workers,
        epoch=epoch,
    )
def shuffle_batches(batches, seed):
    with data_utils.numpy_seed(seed):
        np.random.shuffle(batches)
    return batches
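def _demo_shuffle_batches_determinism():
    # A small, self-contained sketch (not from the original source) showing
    # why shuffle_batches above is reproducible: numpy_seed fixes the RNG
    # state, so identical seeds yield identical permutations. The batch
    # contents here are arbitrary placeholder index lists.
    batches = [[0, 1], [2, 3], [4, 5]]
    a = shuffle_batches(list(batches), seed=42)
    b = shuffle_batches(list(batches), seed=42)
    assert a == b  # same seed, same order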
def load_dataset(self, split, epoch=0, combine=False, data_path=None,
                 return_only=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    query_tokens = []
    query_masks = []
    query_lengths = []
    candidate_tokens = []
    candidate_masks = []
    candidate_lengths = []
    labels = []

    for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
        prefix = sentence[:pronoun_span.start].text
        suffix = sentence[pronoun_span.end:].text_with_ws

        # spaCy spans include trailing spaces, but we need to know about
        # leading spaces for the GPT-2 BPE
        leading_space = (
            ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ') else ''
        )
        trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''

        # get noun phrases, excluding pronouns and anything overlapping with the query
        cand_spans = wsc_utils.filter_noun_chunks(
            wsc_utils.extended_noun_chunks(sentence),
            exclude_pronouns=True,
            exclude_query=query,
            exact_match=False,
        )

        if query is not None:
            query_toks, query_mask = self.binarize_with_mask(
                query, prefix, suffix, leading_space, trailing_space)
            query_len = len(query_toks)
        else:
            query_toks, query_mask, query_len = None, None, 0

        query_tokens.append(query_toks)
        query_masks.append(query_mask)
        query_lengths.append(query_len)

        cand_toks, cand_masks = [], []
        for cand_span in cand_spans:
            toks, mask = self.binarize_with_mask(
                cand_span.text, prefix, suffix, leading_space, trailing_space,
            )
            cand_toks.append(toks)
            cand_masks.append(mask)

        # collate candidates
        cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
        cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
        assert cand_toks.size() == cand_masks.size()

        candidate_tokens.append(cand_toks)
        candidate_masks.append(cand_masks)
        candidate_lengths.append(cand_toks.size(1))

        labels.append(label)

    query_lengths = np.array(query_lengths)
    query_tokens = ListDataset(query_tokens, query_lengths)
    query_masks = ListDataset(query_masks, query_lengths)

    candidate_lengths = np.array(candidate_lengths)
    candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
    candidate_masks = ListDataset(candidate_masks, candidate_lengths)

    labels = ListDataset(labels, [1] * len(labels))

    dataset = {
        'id': IdDataset(),
        'query_tokens': query_tokens,
        'query_masks': query_masks,
        'candidate_tokens': candidate_tokens,
        'candidate_masks': candidate_masks,
        'labels': labels,
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(query_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[query_lengths],
    )

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(query_tokens))

    dataset = SortDataset(
        nested_dataset,
        # shuffle
        sort_order=[shuffle],
    )

    if return_only:
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch
        and not cfg.no_epoch_checkpoints
        and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
        checkpoint_conds["checkpoint.best_{}_{:.3f}{}{}.pt".format(
            cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
        )] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn)
        for fn, cond in checkpoint_conds.items()
        if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss, write_timer.sum)
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix))
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0] for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]
        for old_chk in checkpoints[cfg.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix))
        for old_chk in checkpoints[cfg.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)
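def _demo_checkpoint_pattern():
    # A minimal, standard-library sketch (an illustration, not the project's
    # own code) of how the regex patterns passed to checkpoint_paths above
    # pick the update count out of interval-checkpoint filenames.
    import re
    pattern = re.compile(r"checkpoint_\d+_(\d+)\.pt")
    match = pattern.fullmatch("checkpoint_3_40000.pt")
    assert match is not None and int(match.group(1)) == 40000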
def get_batch_iterator(
    self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
    ignore_invalid_inputs=False, required_batch_size_multiple=1,
    seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
    data_buffer_size=0, disable_iterator_cache=False,
):
    """
    Get an iterator that yields batches of data from the given dataset.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to batch
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each batch
            (default: None).
        max_positions (optional): max sentence length supported by the model
            (default: None).
        ignore_invalid_inputs (bool, optional): don't raise Exception for
            sentences that are too long (default: False).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N shards
            (default: 1).
        shard_id (int, optional): which shard of the data iterator to return
            (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).
        data_buffer_size (int, optional): number of batches to preload
            (default: 0).
        disable_iterator_cache (bool, optional): don't cache the
            EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
            (default: False).

    Returns:
        ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
    """
    can_reuse_epoch_itr = (
        not disable_iterator_cache and self.can_reuse_epoch_itr(dataset)
    )
    if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
        logger.debug('reusing EpochBatchIterator for epoch {}'.format(epoch))
        return self.dataset_to_epoch_iter[dataset]

    assert isinstance(dataset, FairseqDataset)

    # initialize the dataset with the correct starting epoch
    dataset.set_epoch(epoch)

    # get indices ordered by example size
    with data_utils.numpy_seed(seed):
        indices = dataset.ordered_indices()

    # filter examples that are too large
    if max_positions is not None:
        indices = self.filter_indices_by_size(
            indices, dataset, max_positions, ignore_invalid_inputs)

    # create mini-batches with given size constraints
    batch_sampler = dataset.batch_by_size(
        indices,
        max_tokens=max_tokens,
        max_sentences=max_sentences,
        required_batch_size_multiple=required_batch_size_multiple,
    )

    # return a reusable, sharded iterator
    epoch_iter = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=num_shards,
        shard_id=shard_id,
        num_workers=num_workers,
        epoch=epoch,
        buffer_size=data_buffer_size,
    )
    if can_reuse_epoch_itr:
        self.dataset_to_epoch_iter[dataset] = epoch_iter
    return epoch_iter
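# A brief usage note for the caching above (`task` and `ds` are hypothetical
# names): as long as the task reports that the epoch iterator can be reused
# for this dataset, repeated calls hand back the same cached
# EpochBatchIterator object rather than rebuilding the batch sampler.
#
#   itr1 = task.get_batch_iterator(dataset=ds, max_tokens=4096, epoch=1)
#   itr2 = task.get_batch_iterator(dataset=ds, max_tokens=4096, epoch=2)
#   assert itr1 is itr2  # unless disable_iterator_cache=True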
def load_dataset(self, split, combine=False):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(self.args.data, split_k)

        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
            tokens = [t for l in ds.tokens_list for t in l]
        elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
            ds = IndexedInMemoryDataset(path, fix_lua_indexing=False)
            tokens = ds.buffer
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, self.args.data))

        with data_utils.numpy_seed(self.seed + k):
            loaded_datasets.append(
                ModifiedBlockPairDataset(
                    tokens,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    pad=self.dictionary.pad(),
                    class_positive=self.dictionary.class_positive(),
                    class_negative=self.dictionary.class_negative(),
                    sep=self.dictionary.sep(),
                    vocab=self.dictionary,
                    break_mode=self.args.break_mode,
                    short_seq_prob=self.args.short_seq_prob,
                ))

        print('| {} {} {} examples'.format(
            self.args.data, split_k, len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    self.datasets[split] = ModifiedBertDataset(
        dataset,
        sizes,
        self.dictionary,
        shuffle=self.args.shuffle_instance,
        seed=self.seed,
        mask_ratio=self.args.mask_ratio,
        lower=self.args.span_lower,
        upper=self.args.span_upper,
        geometric_p=self.args.geometric_p,
    )
def __getitem__(self, idx):
    with data_utils.numpy_seed(43211, self.epoch, idx):
        return self.mmdataset[idx]
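# A minimal sketch (an assumption, not fairseq's exact code) of what the
# numpy_seed context manager used above does: seed numpy's global RNG from
# the given values and restore the previous state on exit, so the indexing
# in __getitem__ is deterministic per (epoch, index) without perturbing the
# RNG for the rest of the program.
import contextlib

@contextlib.contextmanager
def _numpy_seed_sketch(*seeds):
    state = np.random.get_state()            # remember the current RNG state
    np.random.seed(hash(seeds) % (2 ** 32))  # combine all seed values into one
    try:
        yield
    finally:
        np.random.set_state(state)           # restore, leaving no side effects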
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError(
            'Dataset not found: {} ({})'.format(split, split_path))

    dataset = maybe_shorten_dataset(
        dataset,
        split,
        self.args.shorten_data_split_list,
        self.args.shorten_method,
        self.args.tokens_per_sample,
        self.args.seed,
    )

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample,
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    # remove tail
    dataset = RemoveTailDataset(dataset)

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
        if self.args.mask_whole_words else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': RightPadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': RightPadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    input_options = [
        make_dataset('input{idx}'.format(idx=idx + 1), self.source_dictionary)
        for idx in range(self.args.num_classes)
    ]

    if self.args.separator_token is not None:
        input0 = PrependTokenDataset(input0, self.args.separator_token)

    src_tokens = []
    for input_option in input_options:
        if self.args.init_token is not None:
            input_option = PrependTokenDataset(input_option, self.args.init_token)
        if self.args.max_option_length is not None:
            input_option = TruncateDataset(input_option, self.args.max_option_length)
        src_token = ConcatSentencesDataset(input_option, input0)
        if self.args.truncate_sequence:
            src_token = TruncateDataset(src_token, self.args.max_positions)
        src_tokens.append(src_token)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens[0]))

    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for src_token_idx in range(len(src_tokens)):
        dataset.update({
            'net_input{idx}'.format(idx=src_token_idx + 1): {
                'src_tokens': RightPadDataset(
                    src_tokens[src_token_idx],
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False),
            }
        })

    label_path = '{}.label'.format(get_path('label', split))
    if os.path.exists(label_path):
        with open(label_path) as h:
            dataset.update(
                target=RawLabelDataset([int(x.strip()) for x in h.readlines()])
            )

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def lang_dataset(lang):
    input0 = make_dataset('input0', lang, self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', lang, split))
    input1 = make_dataset('input1', lang, self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    if input1 is None:
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    if self.args.truncate_sequence:
        src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if not self.args.regression_target:
        label_dataset = make_dataset('label', lang, self.target_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.target_dictionary.eos(),
                ),
                offset=-self.target_dictionary.nspecial,
            ))
    else:
        label_path = "{0}.label".format(get_path('label', lang, split))
        if os.path.exists(label_path):
            dataset.update(target=RawLabelDataset([
                float(x.strip()) for x in open(label_path).readlines()
            ]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

    return dataset
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        assert dataset is not None, "could not find dataset: {}".format(
            get_path(type, split))
        return dataset

    src_tokens = make_dataset("input0", self.source_dictionary)
    pos_tokens = make_dataset("input1", self.pos_dictionary)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    label0_dataset = make_dataset("label0", self.label0_dictionary)
    label1_dataset = make_dataset("label1", self.label1_dictionary)

    dataset = {
        "id": IdDataset(),
        "net_input": {
            "src_tokens": RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
                pad_to_length=self._max_positions,
            ),
            "src_lengths": NumelDataset(src_tokens, reduce=False),
        },
        "segments": {
            "seg_tokens": RightPadDataset(
                pos_tokens,
                pad_idx=self.pos_dictionary.pad(),
                pad_to_length=self._max_positions,
            ),
            "seg_lengths": NumelDataset(pos_tokens, reduce=False),
        },
        # use -1 as padding; it will be used to mask out padding when
        # calculating the loss
        "target0": RightPadDataset(
            # replace eos and existing padding (used when some tokens should
            # not be predicted) with -1
            ReplaceDataset(
                # offset tokens to get the targets to the correct range
                # (0, 1, 2, ...)
                OffsetTokensDataset(
                    label0_dataset,
                    offset=-self.label0_dictionary.nspecial,
                ),
                replace_map={
                    self.label0_dictionary.eos() - self.label0_dictionary.nspecial: -1,
                    self.label0_dictionary.pad() - self.label0_dictionary.nspecial: -1,
                },
                offsets=np.zeros(len(label0_dataset), dtype=int),
            ),
            pad_idx=-1,
            pad_to_length=self._max_positions,
        ),
        # use -1 as padding; it will be used to mask out padding when
        # calculating the loss
        "target1": RightPadDataset(
            # replace eos and existing padding (used when some tokens should
            # not be predicted) with -1
            ReplaceDataset(
                # offset tokens to get the targets to the correct range
                # (0, 1, 2, ...)
                OffsetTokensDataset(
                    label1_dataset,
                    offset=-self.label1_dictionary.nspecial,
                ),
                replace_map={
                    self.label1_dictionary.eos() - self.label1_dictionary.nspecial: -1,
                    self.label1_dictionary.pad() - self.label1_dictionary.nspecial: -1,
                },
                offsets=np.zeros(len(label1_dataset), dtype=int),
            ),
            pad_idx=-1,
            pad_to_length=self._max_positions,
        ),
        "nsentences": NumSamplesDataset(),
        "ntokens": NumelDataset(src_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
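def _demo_label_offset_and_replace():
    # A small, self-contained numpy sketch (illustrative values only) of the
    # OffsetTokensDataset + ReplaceDataset combination above: label ids are
    # shifted down by nspecial so real labels start at 0, then eos/pad are
    # remapped to -1 so the loss can ignore those positions.
    nspecial, eos, pad = 4, 2, 1
    labels = np.array([5, 6, eos, pad])  # raw dictionary ids
    shifted = labels - nspecial          # offset: -> [1, 2, -2, -3]
    shifted[labels == eos] = -1          # replace eos with the ignore index
    shifted[labels == pad] = -1          # replace pad with the ignore index
    assert shifted.tolist() == [1, 2, -1, -1]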
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(os.pathsep)
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]

    languages = [
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    ]
    print("| Training on {0} languages: {1}".format(len(languages), languages))
    print("| Language to id mapping: ",
          {lang: id for id, lang in enumerate(languages)})

    mask_whole_words = self._get_whole_word_mask()

    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                'Dataset not found: {} ({})'.format(split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        lang_dataset = NestedDictionaryDataset(
            {
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
                'lang_id': RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
            },
            sizes=[src_dataset.sizes],
        )
        lang_datasets.append(lang_dataset)

    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        sample_probs = self._get_sample_prob(dataset_lengths)
        print("| Sample probability by language: ", {
            lang: "{0:.4f}".format(sample_probs[id])
            for id, lang in enumerate(languages)
        })
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        print("| Up/Down Sampling ratio by language: ", {
            lang: "{0:.2f}".format(size_ratio[id])
            for id, lang in enumerate(languages)
        })

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + '_' + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ','.join(lang_splits))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = self.args.data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    split_path = os.path.join(data_path, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        self.source_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError(
            'Dataset not found: {} ({})'.format(split, split_path))

    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        self.args.tokens_per_sample - 1,  # one less for <s>
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode=self.args.sample_break_mode,
    )
    print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
        if self.args.mask_whole_words else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        self.source_dictionary,
        pad_idx=self.source_dictionary.pad(),
        mask_idx=self.mask_idx,
        seed=self.args.seed,
        mask_prob=self.args.mask_prob,
        leave_unmasked_prob=self.args.leave_unmasked_prob,
        random_token_prob=self.args.random_token_prob,
        freq_weighted_replacement=self.args.freq_weighted_replacement,
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
def __getitem_cached__(self, seed: int, epoch: int, index: int):
    with data_utils.numpy_seed(self.seed, self.epoch, index):
        item = self.dataset[index]
        sz = len(item)

        assert (
            self.mask_idx not in item
        ), "Dataset contains mask_idx (={}), this is not expected!".format(
            self.mask_idx,
        )

        if self.mask_whole_words is not None:
            word_begins_mask = self.mask_whole_words.gather(0, item)
            word_begins_idx = word_begins_mask.nonzero().view(-1)
            sz = len(word_begins_idx)
            words = np.split(word_begins_mask, word_begins_idx)[1:]
            assert len(words) == sz
            word_lens = list(map(len, words))

        # make sure this is a value sequence, which must contain the '##' symbol
        assert '##' in self.vocab
        # when predicting values (as classification), we predict the actual
        # values of all registers
        real_bytes = torch.where(
            (item != self.vocab.index('##'))
            & (item != self.vocab.bos())
            & (item != self.vocab.eos())
        )[0].cpu().numpy()

        # decide elements to mask
        mask = np.full(sz, False)
        num_mask = int(
            # add a random number for probabilistic rounding
            self.mask_prob * len(real_bytes) / float(self.mask_multiple_length)
            + np.random.rand())

        # multiple masking as described in the vq-wav2vec paper
        # (https://arxiv.org/abs/1910.05453)
        mask_idc = np.random.choice(real_bytes, num_mask, replace=False)
        if self.mask_stdev > 0.0:
            lengths = np.random.normal(
                self.mask_multiple_length, self.mask_stdev, size=num_mask)
            lengths = [max(0, int(round(x))) for x in lengths]
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ],
                dtype=np.int64,
            )
        else:
            mask_idc = np.concatenate(
                [mask_idc + i for i in range(self.mask_multiple_length)])
        mask_idc = mask_idc[mask_idc < len(mask)]
        try:
            mask[mask_idc] = True
        except Exception:  # something went wrong; report the offending indices
            print("Assigning mask indexes {} to mask {} failed!".format(
                mask_idc, mask))
            raise

        if self.return_masked_tokens:
            # exit early if we're just returning the masked tokens
            # (i.e., the targets for masked LM training)
            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)
            new_item = np.full(len(mask), self.pad_idx)
            new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
            return torch.from_numpy(new_item)

        # decide unmasking and random replacement
        rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
        if rand_or_unmask_prob > 0.0:
            rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
            if self.random_token_prob == 0.0:
                unmask = rand_or_unmask
                rand_mask = None
            elif self.leave_unmasked_prob == 0.0:
                unmask = None
                rand_mask = rand_or_unmask
            else:
                unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                decision = np.random.rand(sz) < unmask_prob
                unmask = rand_or_unmask & decision
                rand_mask = rand_or_unmask & (~decision)
        else:
            unmask = rand_mask = None

        if unmask is not None:
            mask = mask ^ unmask

        if self.mask_whole_words is not None:
            mask = np.repeat(mask, word_lens)

        new_item = np.copy(item)
        new_item[mask] = self.mask_idx
        if rand_mask is not None:
            num_rand = rand_mask.sum()
            if num_rand > 0:
                if self.mask_whole_words is not None:
                    rand_mask = np.repeat(rand_mask, word_lens)
                    num_rand = rand_mask.sum()
                new_item[rand_mask] = np.random.choice(
                    len(self.vocab),
                    num_rand,
                    p=self.weights,
                )
        return torch.from_numpy(new_item)
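def _demo_probabilistic_rounding():
    # A tiny sketch of the probabilistic rounding used above when computing
    # num_mask: adding a uniform random number before truncating to int makes
    # the expected count equal mask_prob * n even when that product is
    # fractional.
    rng = np.random.RandomState(0)
    mask_prob, n = 0.15, 10  # expectation: 1.5 masked positions
    counts = [int(mask_prob * n + rng.rand()) for _ in range(10000)]
    assert abs(np.mean(counts) - 1.5) < 0.05  # empirical mean is close to 1.5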
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""

    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset

    input0 = make_dataset('input0', self.source_dictionary)
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split))
    input1 = make_dataset('input1', self.source_dictionary)
    assert input1 is not None, 'could not find dataset: {}'.format(
        get_path('input1', split))
    assert len(input0) == len(input1), 'input pair different length'

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)
        input1 = PrependTokenDataset(input1, self.args.init_token)

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(input0))

    if self.args.truncate_sequence:
        input0 = TruncateDataset(input0, self.args.max_positions)
        input1 = TruncateDataset(input1, self.args.max_positions)

    dataset = {
        'id': IdDataset(),
        'net_input0': {
            'src_tokens': RightPadDataset(
                input0,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(input0, reduce=False),
        },
        'net_input1': {
            'src_tokens': RightPadDataset(
                input1,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(input1, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens0': NumelDataset(input0, reduce=True),
        'ntokens1': NumelDataset(input1, reduce=True),
    }

    label_path = "{0}.label".format(get_path('label', split))
    if os.path.exists(label_path):
        dataset.update(target=RawLabelDataset(
            [float(x.strip()) for x in open(label_path).readlines()]))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum(input0.sizes, input1.sizes)],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
    """
    Does the heavy lifting for creating a batch from the input list of
    examples. The logic is as follows:
    1. Mask the input blocks. If has_pairs is True then we have 2 blocks
       to mask.
    2. Prepend the first masked block tensor with the special token used as
       sentence embedding, e.g. CLS in BERT. This happens irrespective of
       the value of has_pairs.
    3. If has_pairs is True, then append the first masked block with the
       special separator token (e.g. SEP for BERT) and compute segment label
       accordingly. In this case, also append the second masked block with
       this special separator token and compute its segment label.
    4. For the targets tensor, prepend and append with padding index
       accordingly.
    5. Concatenate all tensors.
    """
    if len(samples) == 0:
        return {}
    # To ensure determinism, we reset the state of the PRNG after every
    # batch based on the seed and the first id of the batch. This ensures
    # that across epochs we get the same mask for the same example. This
    # is needed for reproducibility and is how BERT does masking.
    # TODO: Can we add determinism without this constraint?
    with data_utils.numpy_seed(self.seed + samples[0]["id"]):
        for s in samples:
            # token range is needed for replacing with random token during
            # masking
            token_range = (self.vocab.nspecial, len(self.vocab))

            # mask according to specified probabilities.
            masked_blk_one, masked_tgt_one = self._mask_block(
                s["block_one"], self.mask_idx, self.pad_idx, token_range,
            )

            tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
            targets = np.concatenate([[self.pad_idx], masked_tgt_one])
            segments = np.ones(len(tokens)) * self.segment_id

            # if has_pairs is True then we need to add the SEP token to both
            # the blocks after masking and re-compute segments based on the
            # new lengths.
            if self.has_pairs:
                tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                targets_one = np.concatenate([targets, [self.pad_idx]])

                masked_blk_two, masked_tgt_two = self._mask_block(
                    s["block_two"], self.mask_idx, self.pad_idx, token_range)
                tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]])
                targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                # block + 1 sep + 1 special (CLS)
                segments_one = np.zeros(len(tokens_one))
                # block + 1 sep
                segments_two = np.ones(len(tokens_two))

                tokens = np.concatenate([tokens_one, tokens_two])
                targets = np.concatenate([targets_one, targets_two])
                segments = np.concatenate([segments_one, segments_two])

            s["source"] = torch.LongTensor(tokens)
            s["segment_labels"] = torch.LongTensor(segments)
            s["lm_target"] = torch.LongTensor(targets)

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples], pad_idx, eos_idx, left_pad=False)

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": merge("source"),
            "segment_labels": merge("segment_labels"),
        },
        "lm_target": merge("lm_target"),
        "sentence_target": torch.LongTensor(
            [s["sentence_target"] for s in samples]
        ) if self.has_pairs else None,
        "nsentences": len(samples),
    }
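def _demo_pair_segments():
    # A compact sketch of the segment layout built above in the paired case:
    # [CLS] + block_one + [SEP] gets segment id 0, and block_two + [SEP]
    # gets segment id 1 (token values here are placeholders).
    blk_one, blk_two = np.array([11, 12]), np.array([21, 22, 23])
    segments = np.concatenate([
        np.zeros(len(blk_one) + 2),  # CLS + block_one + SEP
        np.ones(len(blk_two) + 1),   # block_two + SEP
    ])
    assert segments.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]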
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0

    src_tokens = {}
    tgt_tokens = {}
    tgt_values = {}
    for field in configs.fields:
        split_path = os.path.join(self.args.data, field, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary[field],
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path))

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.tokens_per_sample,
            self.args.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary[field].pad(),
            eos=self.source_dictionary[field].eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary[field].bos())

        if field == configs.static_field:
            src_dataset_code, tgt_dataset_code = MaskTokensDataset.apply_mask(
                dataset,
                self.source_dictionary[field],
                pad_idx=self.source_dictionary[field].pad(),
                mask_idx=self.mask_idx_dict[field],
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
            )
            src_tokens[field] = RightPadDataset(
                src_dataset_code, pad_idx=self.source_dictionary[field].pad())
            tgt_tokens[field] = RightPadDataset(
                tgt_dataset_code, pad_idx=self.source_dictionary[field].pad())
        elif field in configs.byte_fields:
            src_dataset_value, tgt_dataset_value = MaskValuesDataset.apply_mask(
                dataset,
                self.source_dictionary[field],
                pad_idx=self.source_dictionary[field].pad(),
                mask_idx=self.mask_idx_dict[field],
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
            )
            src_tokens[field] = RightPadDataset(
                src_dataset_value, pad_idx=self.source_dictionary[field].pad())

            # dummy tokens are treated as 1
            # TODO: assert there should not be any dummy tokens here
            tgt_values[field] = BytevalueDataset(
                tgt_dataset_value, self.source_dictionary[field])
        else:
            src_tokens[field] = RightPadDataset(
                dataset, pad_idx=self.source_dictionary[field].pad())

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_dataset_code))

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": NumelDataset(src_dataset_code, reduce=False),
                },
                "target": {
                    "tgt_tokens": tgt_tokens,
                    "tgt_values": tgt_values,
                },
                "nsentences": NumSamplesDataset(),
                "ntokens": NumelDataset(src_dataset_code, reduce=True),
            },
            sizes=[src_dataset_code.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset_code.sizes,
        ],
    )
def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    print('Loading dataset')

    data_path = os.path.join(self.args.data)
    dataset_inst = data_utils.load_indexed_dataset(
        os.path.join(data_path, 'insts', split),
        self.instruction_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    dataset_state = data_utils.load_indexed_dataset(
        os.path.join(data_path, 'states', split),
        self.state_dictionary,
        self.args.dataset_impl,
        combine=combine,
    )

    if dataset_inst is None or dataset_state is None:
        raise FileNotFoundError('Dataset not found: {}'.format(split))

    dataset_inst = SeqOfSeqDataset(dataset_inst, self.instruction_dictionary)
    dataset_state = SeqOfSeqDataset(dataset_state, self.state_dictionary)
    dataset_pos = IRPositionDataset(os.path.join(data_path, 'pos', split))
    dataset = IRDataset(dataset_inst, dataset_state, dataset_pos)

    block_size = self.args.function_length

    dataset = IRPadDataset(
        dataset,
        inst_pad_idx=self.instruction_dictionary.pad(),
        state_pad_idx=self.state_dictionary.pad(),
        inst_mask_idx=self.inst_mask_idx,
        state_mask_idx=self.state_mask_idx,
        inst_cls_idx=self.instruction_dictionary.bos(),
        state_cls_idx=self.state_dictionary.bos(),
        smallbert_insts_per_input=self.args.smallbert_insts_per_group,
        smallbert_states_per_input=self.args.smallbert_insts_per_group,
        max_length=block_size,
        inst_pad_length=32,
        state_pad_length=16,
        pair=True,
    )

    labels_str = list(map(
        json.loads, open(os.path.join(data_path, 'label', split + '.txt'))))
    labels = torch.tensor([
        x - 1 if isinstance(x, int) else int(x.strip()) - 1 for x in labels_str
    ])
    # function_indices = [torch.tensor(json.loads(x)) for x in
    #                     open(os.path.join(data_path, 'funcs', split + '.txt'))]
    # dataset = IRMultiFunctionDataset(dataset, function_indices,
    #                                  self.args.max_functions_per_program)

    print('| loaded {} batches from: {} and {}'.format(
        len(dataset),
        os.path.join(data_path, 'insts', split),
        os.path.join(data_path, 'states', split)))

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.labels[split] = SortDataset(RawLabelDataset(labels), sort_order=[shuffle])

    self.datasets[split] = SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src': dataset,
                },
                'target': RawLabelDataset(labels),
                'indices': RawLabelDataset(torch.arange(len(dataset))),
                'subset': ListDataset([split for _ in range(len(dataset))]),
            },
            sizes=[dataset.sizes],
        ),
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
def load_dataset(self, split, epoch=0, combine=False, data_path=None,
                 return_only=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    if data_path is None:
        data_path = os.path.join(self.args.data, split + '.jsonl')
    if not os.path.exists(data_path):
        raise FileNotFoundError('Cannot find data: {}'.format(data_path))

    query_tokens = []
    query_masks = []
    query_lengths = []
    candidate_tokens = []
    candidate_masks = []
    candidate_lengths = []

    itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test'))

    for sample in itr:
        sentence, pronoun_span, query, cand_text = sample
        prefix = sentence[:pronoun_span[0]].rstrip()
        suffix = sentence[pronoun_span[1]:]

        leading_space = ' ' if sentence[:pronoun_span[0]].endswith(' ') else ''
        trailing_space = ''

        if query is not None:
            query_toks, query_mask = self.binarize_with_mask(
                query, prefix, suffix, leading_space, trailing_space,
            )
            query_len = len(query_toks)
        else:
            query_toks, query_mask, query_len = None, None, 0

        query_tokens.append(query_toks)
        query_masks.append(query_mask)
        query_lengths.append(query_len)

        cand_toks, cand_mask = self.binarize_with_mask(
            cand_text, prefix, suffix, leading_space, trailing_space,
        )
        candidate_tokens.append(cand_toks)
        candidate_masks.append(cand_mask)
        candidate_lengths.append(cand_toks.size(0))

    query_lengths = np.array(query_lengths)

    def get_pad_dataset_fn(tokens, length, pad_idx):
        return PadDataset(
            ListDataset(tokens, length),
            pad_idx=pad_idx,
            left_pad=False,
        )

    query_tokens = get_pad_dataset_fn(query_tokens, query_lengths, self.vocab.pad())
    query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)

    candidate_lengths = np.array(candidate_lengths)
    candidate_tokens = get_pad_dataset_fn(
        candidate_tokens, candidate_lengths, self.vocab.pad())
    candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)

    dataset = {
        'id': IdDataset(),
        'query_tokens': query_tokens,
        'query_masks': query_masks,
        'candidate_tokens': candidate_tokens,
        'candidate_masks': candidate_masks,
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(query_tokens, reduce=True),
    }

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[query_lengths],
    )

    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(query_tokens))

    dataset = SortDataset(
        nested_dataset,
        # shuffle
        sort_order=[shuffle],
    )

    if return_only:
        return dataset

    self.datasets[split] = dataset
    return self.datasets[split]
def get_batch_iterator(
    self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
    ignore_invalid_inputs=False, required_batch_size_multiple=1,
    seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
):
    """
    Get an iterator that yields batches of data from the given dataset.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to batch
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        max_positions (optional): max sentence length supported by the
            model (default: None).
        ignore_invalid_inputs (bool, optional): don't raise Exception for
            sentences that are too long (default: False).
        required_batch_size_multiple (int, optional): require batch size
            to be a multiple of N (default: 1).
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
            shards (default: 1).
        shard_id (int, optional): which shard of the data iterator to
            return (default: 0).
        num_workers (int, optional): how many subprocesses to use for
            data loading. 0 means the data will be loaded in the main
            process (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).

    Returns:
        ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
    """
    # For the default fairseq task, return the same iterator across epochs,
    # since datasets are not dynamic; tasks with dynamic datasets can
    # override this behavior.
    if dataset in self.dataset_to_epoch_iter:
        return self.dataset_to_epoch_iter[dataset]

    assert isinstance(dataset, FairseqDataset)

    # initialize the dataset with the correct starting epoch
    dataset.set_epoch(epoch)

    # get indices ordered by example size
    with data_utils.numpy_seed(seed):
        indices = dataset.ordered_indices()

    # filter examples that are too large
    if max_positions is not None:
        indices = self.filter_indices_by_size(
            indices, dataset, max_positions, ignore_invalid_inputs
        )

    # create mini-batches with given size constraints
    batch_sampler = dataset.batch_by_size(
        indices,
        max_tokens=max_tokens,
        max_sentences=max_sentences,
        required_batch_size_multiple=required_batch_size_multiple,
    )

    # return a reusable, sharded iterator
    epoch_iter = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=num_shards,
        shard_id=shard_id,
        num_workers=num_workers,
        epoch=epoch,
        buffer_size=getattr(self.args, 'data_buffer_size', 0),
    )
    self.dataset_to_epoch_iter[dataset] = epoch_iter
    return epoch_iter
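# Hedged sketch of the per-dataset iterator cache used above: the first call
# builds an epoch iterator and memoizes it keyed by the dataset object, so
# later calls for the same (static) dataset reuse it instead of re-sorting
# and re-batching. Names below are illustrative, not fairseq's.

class CachingTask:
    def __init__(self, build_iterator_fn):
        self.build_iterator_fn = build_iterator_fn
        # keyed by dataset identity; fine for static datasets, but tasks
        # with per-epoch dynamic datasets must skip this cache
        self.dataset_to_epoch_iter = {}

    def get_batch_iterator(self, dataset):
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]
        epoch_iter = self.build_iterator_fn(dataset)  # the expensive part
        self.dataset_to_epoch_iter[dataset] = epoch_iter
        return epoch_iter


class ToyDataset:  # hashable stand-in for a FairseqDataset
    pass


task = CachingTask(build_iterator_fn=lambda ds: iter(range(3)))
ds = ToyDataset()
assert task.get_batch_iterator(ds) is task.get_batch_iterator(ds)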
def collate( samples, pad_idx, chunk_width, chunk_left_context, chunk_right_context, label_delay, seed, epoch, pad_to_length=None, pad_to_multiple=1, src_bucketed=False, random_chunking=True, ): if len(samples) == 0: return {} def merge(key, pad_to_length=None): if key == "source": return speech_utils.collate_frames( [s[key] for s in samples], 0.0, pad_to_length=pad_to_length, pad_to_multiple=pad_to_multiple, ) elif key == "target": return data_utils.collate_tokens( [s[key] for s in samples], pad_idx=pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False, pad_to_length=pad_to_length, pad_to_multiple=pad_to_multiple, ) else: raise ValueError("Invalid key.") def chunking(src_item, tgt_item, tgt_start): # make a src chunk in the range [begin_src, end_src) begin_src = max(0, tgt_start + label_delay - chunk_left_context) # ok if end_src past the end of utterance end_src = tgt_start + label_delay + chunk_width + chunk_right_context # replication pad if necessary left_pad = max(0, chunk_left_context - tgt_start - label_delay) right_pad = max(0, end_src - src_item.size(0)) src_item = src_item[begin_src:end_src] if left_pad > 0 or right_pad > 0: src_item = F.pad( src_item.t().unsqueeze(0), (left_pad, right_pad), mode="replicate", ).squeeze(0).t() if tgt_item is not None: # make a tgt chunk in the range [begin_tgt, end_tgt) begin_tgt = tgt_start end_tgt = tgt_start + chunk_width # ok if past the end of utterance # replication pad if necessary right_pad = max(0, end_tgt - tgt_item.size(0)) tgt_item = tgt_item[begin_tgt:end_tgt] if right_pad > 0: tgt_item = torch.cat( (tgt_item, tgt_item.new_full((right_pad, ), pad_idx)), 0) return src_item, tgt_item if chunk_width is None or random_chunking: if chunk_width is not None: # usually for chunk-wise train data # no need to sort as all chunks have exactly the same length for s in samples: with data_utils.numpy_seed(seed, epoch, s["id"]): # generate a chunk by sampling the index of its first label f = np.random.randint(s["source"].size(0) - chunk_width + 1) s["source"], s["target"] = chunking(s["source"], s["target"], f) elif label_delay != 0: # shift source according to label_delay if label_delay > 0: left_pad, right_pad = 0, label_delay else: left_pad, right_pad = -label_delay, 0 for s in samples: src_item = s["source"] src_item = F.pad( src_item.t().unsqueeze(0), (left_pad, right_pad), mode="replicate", ).squeeze(0).t() if label_delay > 0: s["source"] = src_item[label_delay:] else: s["source"] = src_item[:label_delay] if pad_to_length is not None or src_bucketed: src_lengths = torch.IntTensor( [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples]) else: src_lengths = torch.IntTensor( [s["source"].size(0) for s in samples]) id = torch.LongTensor([s["id"] for s in samples]) utt_id = [s["utt_id"] for s in samples] src_frames = merge( "source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, ) target = None if samples[0].get("target", None) is not None: target = merge( "target", pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, ) ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) else: ntokens = src_lengths.sum().item() text = None if samples[0].get("text", None) is not None: text = [s["text"] for s in samples] if chunk_width is None: # for whole utterances (i.e., no chunking) # sort by descending source length src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) utt_id = [utt_id[i] for i in sort_order.numpy()] 
src_frames = src_frames.index_select(0, sort_order) if target is not None: target = target.index_select(0, sort_order) if text is not None: text = [text[i] for i in sort_order.numpy()] batch = { "id": id, "utt_id": utt_id, "nsentences": len(samples), "ntokens": ntokens, "net_input": { "src_tokens": src_frames, "src_lengths": src_lengths, }, "target": target, "text": text, } return batch else: # sequential chunking, usually for chunk-wise test data if pad_to_length is not None or src_bucketed: src_lengths = torch.IntTensor( [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples]) else: src_lengths = torch.IntTensor( [s["source"].size(0) for s in samples]) id = torch.LongTensor([s["id"] for s in samples]) utt_id = [s["utt_id"] for s in samples] ori_source = [s["source"] for s in samples] ori_target = [s["target"] for s in samples] text = None if samples[0].get("text", None) is not None: text = [s["text"] for s in samples] max_length = max(src.size(0) for src in ori_source) num_chunks = (max_length + chunk_width - 1) // chunk_width batches = [] for k in range(num_chunks): f = k * chunk_width for i, s in enumerate(samples): if f < src_lengths[i].item(): s["source"], s["target"] = chunking( ori_source[i], ori_target[i], f) else: s["source"] = ori_source[i].new_zeros( chunk_width + chunk_left_context + chunk_right_context, ori_source[i].size(1)) s["target"] = (ori_target[i].new_full( (chunk_width, ), pad_idx) if ori_target[i] is not None else None) src_frames = merge( "source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, ) src_chunk_lengths = torch.IntTensor( [s["source"].size(0) for s in samples]) target = None if samples[0].get("target", None) is not None: target = merge( "target", pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, ) ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) else: ntokens = src_lengths.sum().item() batch = { "id": id, "utt_id": utt_id, "nsentences": len(samples) if k == 0 else 0, "ntokens": ntokens, "net_input": { "src_tokens": src_frames, "src_lengths": src_chunk_lengths, }, "target": target, "text": text, } batches.append(batch) return batches
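# Hedged, simplified sketch of the chunking helper in the collate function
# above: slice a (time, feat) source tensor around a chunk window and
# replication-pad wherever the window runs past either end of the utterance.
# Stand-alone and illustrative; the real helper also handles targets and
# label_delay.

import torch
import torch.nn.functional as F


def chunk_source(src, begin, end):
    """Return src[begin:end) along time, replication-padding out-of-range."""
    left_pad = max(0, -begin)
    right_pad = max(0, end - src.size(0))
    piece = src[max(0, begin):end]
    if left_pad > 0 or right_pad > 0:
        # F.pad pads the last dim, so transpose to (feat, time), pad along
        # time with edge replication, and transpose back
        piece = F.pad(
            piece.t().unsqueeze(0), (left_pad, right_pad), mode="replicate",
        ).squeeze(0).t()
    return piece


src = torch.arange(12, dtype=torch.float32).view(4, 3)  # 4 frames, 3 feats
out = chunk_source(src, begin=-1, end=5)  # 1 frame of pad on each side
assert out.size(0) == 6
assert torch.equal(out[0], src[0]) and torch.equal(out[-1], src[-1])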
def __getitem__(self, index: int): with data_utils.numpy_seed(self.seed, self.epoch, index): item = self.dataset[index] sz = len(item) assert self.mask_idx not in item, \ 'Dataset contains mask_idx (={}), this is not expected!'.format( self.mask_idx, ) if self.mask_whole_words is not None: word_begins_mask = self.mask_whole_words.gather(0, item) word_begins_idx = word_begins_mask.nonzero().view(-1) sz = len(word_begins_idx) words = np.split(word_begins_mask, word_begins_idx)[1:] assert len(words) == sz word_lens = list(map(len, words)) # decide elements to mask mask = np.full(sz, False) num_mask = int( # add a random number for probabilistic rounding self.mask_prob * sz + np.random.rand() ) mask[np.random.choice(sz, num_mask, replace=False)] = True if self.return_masked_tokens: # exit early if we're just returning the masked tokens # (i.e., the targets for masked LM training) if self.mask_whole_words is not None: mask = np.repeat(mask, word_lens) new_item = np.full(len(mask), self.pad_idx) new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] return torch.from_numpy(new_item) # decide unmasking and random replacement rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob if rand_or_unmask_prob > 0.0: rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) if self.random_token_prob == 0.0: unmask = rand_or_unmask rand_mask = None elif self.leave_unmasked_prob == 0.0: unmask = None rand_mask = rand_or_unmask else: unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob decision = np.random.rand(sz) < unmask_prob unmask = rand_or_unmask & decision rand_mask = rand_or_unmask & (~decision) else: unmask = rand_mask = None if unmask is not None: mask = mask ^ unmask if self.mask_whole_words is not None: mask = np.repeat(mask, word_lens) new_item = np.copy(item) new_item[mask] = self.mask_idx if rand_mask is not None: num_rand = rand_mask.sum() if num_rand > 0: if self.mask_whole_words is not None: rand_mask = np.repeat(rand_mask, word_lens) num_rand = rand_mask.sum() new_item[rand_mask] = np.random.choice( len(self.vocab), num_rand, p=self.weights, ) return torch.from_numpy(new_item)
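# Hedged numpy sketch of the corruption policy implemented above: choose
# roughly mask_prob of the positions with probabilistic rounding, then split
# the chosen set into keep-masked / leave-unmasked / replace-with-random
# according to leave_unmasked_prob and random_token_prob (the BERT-style
# 80/10/10 scheme when both are 0.1). Illustrative only.

import numpy as np


def sample_corruption(sz, mask_prob, leave_unmasked_prob, random_token_prob):
    # choose ~mask_prob * sz positions, with probabilistic rounding
    num_mask = int(mask_prob * sz + np.random.rand())
    chosen = np.full(sz, False)
    chosen[np.random.choice(sz, num_mask, replace=False)] = True

    # split the chosen positions into leave-unmasked vs random-replacement
    rand_or_unmask_prob = random_token_prob + leave_unmasked_prob
    unmask = np.full(sz, False)
    rand_mask = np.full(sz, False)
    if rand_or_unmask_prob > 0.0:
        rand_or_unmask = chosen & (np.random.rand(sz) < rand_or_unmask_prob)
        decision = np.random.rand(sz) < (
            leave_unmasked_prob / rand_or_unmask_prob
        )
        unmask = rand_or_unmask & decision
        rand_mask = rand_or_unmask & ~decision
    keep_masked = chosen & ~unmask & ~rand_mask
    return keep_masked, unmask, rand_mask


np.random.seed(0)
keep_masked, unmask, rand_mask = sample_corruption(1000, 0.15, 0.1, 0.1)
# the three outcomes partition the chosen positions
assert not (keep_masked & unmask).any()
assert not (keep_masked & rand_mask).any()
assert not (unmask & rand_mask).any()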
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ def binarize(s, append_bos=False): if self.bpe is not None: s = self.bpe.encode(s) tokens = self.vocab.encode_line( s, append_eos=True, add_if_not_exist=False, ).long() if append_bos and self.args.init_token is not None: tokens = torch.cat( [tokens.new([self.args.init_token]), tokens]) return tokens if data_path is None: data_path = os.path.join(self.args.data, split + '.jsonl') if not os.path.exists(data_path): raise FileNotFoundError('Cannot find data: {}'.format(data_path)) src_tokens = [[] for i in range(self.args.num_classes)] src_lengths = [[] for i in range(self.args.num_classes)] labels = [] with open(data_path) as h: for line in h: example = json.loads(line.strip()) if 'answerKey' in example: label = ord(example['answerKey']) - ord('A') labels.append(label) question = example['question']['stem'] assert len( example['question']['choices']) == self.args.num_classes # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>` question = 'Q: ' + question question_toks = binarize(question, append_bos=True) for i, choice in enumerate(example['question']['choices']): src = 'A: ' + choice['text'] src_bin = torch.cat([question_toks, binarize(src)]) src_tokens[i].append(src_bin) src_lengths[i].append(len(src_bin)) assert all( len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes)) assert len(src_tokens[0]) == len(src_lengths[0]) assert len(labels) == 0 or len(labels) == len(src_tokens[0]) for i in range(self.args.num_classes): src_lengths[i] = np.array(src_lengths[i]) src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i]) src_lengths[i] = ListDataset(src_lengths[i]) dataset = { 'id': IdDataset(), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_tokens[0], reduce=True), } for i in range(self.args.num_classes): dataset.update({ 'net_input{}'.format(i + 1): { 'src_tokens': RightPadDataset( src_tokens[i], pad_idx=self.source_dictionary.pad(), ), 'src_lengths': src_lengths[i], } }) if len(labels) > 0: dataset.update({'target': RawLabelDataset(labels)}) dataset = NestedDictionaryDataset( dataset, sizes=[ np.maximum.reduce( [src_token.sizes for src_token in src_tokens]) ], ) with data_utils.numpy_seed(self.args.seed): dataset = SortDataset( dataset, # shuffle sort_order=[np.random.permutation(len(dataset))], ) print('| Loaded {} with {} samples'.format(split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
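# Hedged sketch of the sizes computation above: with one input stream per
# answer choice, the example "size" used for batching and length filtering is
# the elementwise maximum over the per-choice lengths, which is exactly what
# np.maximum.reduce computes.

import numpy as np

choice_sizes = [
    np.array([7, 4, 9]),   # lengths of choice-0 inputs, one per example
    np.array([6, 5, 8]),   # choice 1
    np.array([7, 6, 10]),  # choice 2
]
sizes = np.maximum.reduce(choice_sizes)
assert sizes.tolist() == [7, 6, 10]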
def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = self.args.data.split(':') assert len(paths) > 0 data_path = paths[epoch % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) if dataset is None: raise FileNotFoundError('Dataset not found: {} ({})'.format( split, split_path)) # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode=self.args.sample_break_mode, ) print('| loaded {} blocks from: {}'.format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) # create masked input and targets if self.args.mask_whole_words: bpe = encoders.build_bpe(self.args) assert bpe is not None def is_beginning_of_word(i): if i < self.source_dictionary.nspecial: # special elements are always considered beginnings return True tok = self.source_dictionary[i] if tok.startswith('madeupword'): return True try: return bpe.is_beginning_of_word(tok) except ValueError: return True mask_whole_words = torch.ByteTensor( list( map(is_beginning_of_word, range(len(self.source_dictionary))))) else: mask_whole_words = None src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( dataset, self.source_dictionary, pad_idx=self.source_dictionary.pad(), mask_idx=self.mask_idx, seed=self.args.seed, mask_prob=self.args.mask_prob, leave_unmasked_prob=self.args.leave_unmasked_prob, random_token_prob=self.args.random_token_prob, freq_weighted_replacement=self.args.freq_weighted_replacement, mask_whole_words=mask_whole_words, ) with data_utils.numpy_seed(self.args.seed + epoch): shuffle = np.random.permutation(len(src_dataset)) self.datasets[split] = SortDataset( NestedDictionaryDataset( { 'id': IdDataset(), 'net_input': { 'src_tokens': PadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, ), 'src_lengths': NumelDataset(src_dataset, reduce=False), }, 'target': PadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False, ), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_dataset, reduce=True), }, sizes=[src_dataset.sizes], ), sort_order=[ shuffle, src_dataset.sizes, ], )
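# Hedged sketch of the whole-word masking bookkeeping used above and in the
# MaskTokensDataset __getitem__ earlier: a per-vocabulary "beginning of word"
# flag is gathered over the token sequence, masking decisions are made per
# word, and np.repeat expands them back to token resolution. Illustrative
# values only.

import numpy as np
import torch

# word-begin flags for a 6-token item: tokens 0, 2 and 5 start words
word_begins_mask = torch.tensor([1, 0, 1, 0, 0, 1], dtype=torch.uint8)
word_begins_idx = word_begins_mask.nonzero().view(-1)
words = np.split(word_begins_mask.numpy(), word_begins_idx.numpy())[1:]
word_lens = list(map(len, words))
assert word_lens == [2, 3, 1]

word_mask = np.array([True, False, True])     # mask words 0 and 2
token_mask = np.repeat(word_mask, word_lens)  # back to per-token resolution
assert token_mask.tolist() == [True, True, False, False, False, True]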
def get_batch_iterator( self, dataset, max_tokens=None, max_sentences=None, max_positions=None, ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1, num_shards=1, shard_id=0, num_workers=0, ): """ Get an iterator that yields batches of data from the given dataset. Args: dataset (~fairseq.data.FairseqDataset): dataset to batch max_tokens (int, optional): max number of tokens in each batch (default: None). max_sentences (int, optional): max number of sentences in each batch (default: None). max_positions (optional): max sentence length supported by the model (default: None). ignore_invalid_inputs (bool, optional): don't raise Exception for sentences that are too long (default: False). required_batch_size_multiple (int, optional): require batch size to be a multiple of N (default: 1). seed (int, optional): seed for random number generator for reproducibility (default: 1). num_shards (int, optional): shard the data iterator into N shards (default: 1). shard_id (int, optional): which shard of the data iterator to return (default: 0). num_workers (int, optional): how many subprocesses to use for data loading. 0 means the data will be loaded in the main process (default: 0). Returns: ~fairseq.iterators.EpochBatchIterator: a batched iterator over the given dataset split """ assert isinstance(dataset, FairseqDataset) # get indices ordered by example size with data_utils.numpy_seed(seed): indices = dataset.ordered_indices() # filter examples that are too large indices = data_utils.filter_by_size( indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs), ) # create mini-batches with given size constraints batch_sampler = data_utils.batch_by_size( indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences, required_batch_size_multiple=required_batch_size_multiple, ) # return a reusable, sharded iterator return iterators.EpochBatchIterator( dataset=dataset, collate_fn=dataset.collater, batch_sampler=batch_sampler, seed=seed, num_shards=num_shards, shard_id=shard_id, num_workers=num_workers, )
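# Hedged sketch of the size-constrained batching that data_utils.batch_by_size
# performs above: greedily fill each batch until the padded token count
# (batch size times the longest item so far) or max_sentences would be
# exceeded. Simplified and illustrative; the real implementation also honors
# required_batch_size_multiple.

def batch_by_size_sketch(indices, num_tokens_fn, max_tokens=None,
                         max_sentences=None):
    batches, batch, width = [], [], 0
    for idx in indices:
        width = max(width, num_tokens_fn(idx))
        too_big = (
            (max_tokens is not None and (len(batch) + 1) * width > max_tokens)
            or (max_sentences is not None and len(batch) + 1 > max_sentences)
        )
        if too_big and batch:
            batches.append(batch)
            batch, width = [], num_tokens_fn(idx)
        batch.append(idx)
    if batch:
        batches.append(batch)
    return batches


lengths = {0: 3, 1: 3, 2: 4, 3: 9}
batches = batch_by_size_sketch(range(4), lengths.__getitem__, max_tokens=8)
assert batches == [[0, 1], [2], [3]]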
def load_dataset(self, split, epoch=0, combine=False):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    paths = self.args.data.split(os.pathsep)
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]
    print('| data_path: {}'.format(data_path))

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_path, split_k)
        ds = indexed_dataset.make_dataset(
            path,
            impl=self.args.dataset_impl,
            fix_lua_indexing=True,
            dictionary=self.dictionary,
        )

        if ds is None:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, data_path))

        with data_utils.numpy_seed(self.seed + k):
            loaded_datasets.append(
                BlockPairDataset(
                    ds,
                    self.dictionary,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    break_mode=self.args.break_mode,
                    doc_break_size=1,
                ))

        print('| {} {} {} examples'.format(data_path, split_k,
                                           len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    self.datasets[split] = MaskedLMDataset(
        dataset=dataset,
        sizes=sizes,
        vocab=self.dictionary,
        pad_idx=self.dictionary.pad(),
        mask_idx=self.dictionary.mask(),
        classif_token_idx=self.dictionary.cls(),
        sep_token_idx=self.dictionary.sep(),
        shuffle=self.args.shuffle_dataset,
        seed=self.seed,
    )
def get_batch_iterator(self, dataset, assistant=None, max_tokens=None,
                       max_sentences=None, max_positions=None,
                       ignore_invalid_inputs=False,
                       required_batch_size_multiple=1, seed=1, num_shards=1,
                       shard_id=0, batch_method='sentences'):
    """
    Get an iterator that yields batches of data from the given dataset.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to batch
        assistant (optional): when provided, the dataset and its
            size-ordered indices are registered with the assistant and an
            ``AssistantEpochBatchIterator`` is returned in place of the
            default iterator. Default: ``None``
        max_tokens (int, optional): max number of tokens in each batch.
            Default: ``None``
        max_sentences (int, optional): max number of sentences in each
            batch. Default: ``None``
        max_positions (optional): max sentence length supported by the
            model. Default: ``None``
        ignore_invalid_inputs (bool, optional): don't raise Exception for
            sentences that are too long. Default: ``False``
        required_batch_size_multiple (int, optional): require batch size
            to be a multiple of N. Default: ``1``
        seed (int, optional): seed for random number generator for
            reproducibility. Default: ``1``
        num_shards (int, optional): shard the data iterator into N
            shards. Default: ``1``
        shard_id (int, optional): which shard of the data iterator to
            return. Default: ``0``
        batch_method (str, optional): batching strategy forwarded to the
            assistant-backed iterator. Default: ``'sentences'``

    Returns:
        ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
    """
    assert isinstance(dataset, FairseqDataset)

    # get indices ordered by example size
    with data_utils.numpy_seed(seed):
        indices = dataset.ordered_indices()

    # filter examples that are too large
    indices = data_utils.filter_by_size(
        indices,
        dataset.size,
        max_positions,
        raise_exception=(not ignore_invalid_inputs),
    )

    if assistant is not None:
        # the assistant batches on the fly, so only register the data and
        # return an assistant-backed, reusable, sharded iterator
        assistant.associate_data(dataset, indices)
        return iterators.AssistantEpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            assistant=assistant,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            shard_num=num_shards,
            shard_id=shard_id,
            batch_method=batch_method,
            seed=seed,
        )

    # create mini-batches with given size constraints
    batch_sampler = data_utils.batch_by_size(
        indices,
        dataset.num_tokens,
        max_tokens=max_tokens,
        max_sentences=max_sentences,
        required_batch_size_multiple=required_batch_size_multiple,
    )

    # return a reusable, sharded iterator
    return iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=num_shards,
        shard_id=shard_id,
    )
def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split (e.g., train, valid, test).""" def get_path(key, split): return os.path.join(self.args.data, key, split) def make_dataset(key, dictionary): split_path = get_path(key, split) dataset = data_utils.load_indexed_dataset( split_path, dictionary, self.args.dataset_impl, combine=combine, ) return dataset input0 = make_dataset("input0", self.source_dictionary) assert input0 is not None, "could not find dataset: {}".format( get_path("input0", split)) input1 = make_dataset("input1", self.source_dictionary) if self.args.init_token is not None: input0 = PrependTokenDataset(input0, self.args.init_token) if input1 is None: src_tokens = input0 else: if self.args.separator_token is not None: input1 = PrependTokenDataset(input1, self.args.separator_token) src_tokens = ConcatSentencesDataset(input0, input1) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_tokens)) src_tokens = maybe_shorten_dataset( src_tokens, split, self.args.shorten_data_split_list, self.args.shorten_method, self.args.max_positions, self.args.seed, ) dataset = { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), "src_lengths": NumelDataset(src_tokens, reduce=False), }, "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_tokens, reduce=True), } if self.args.add_prev_output_tokens: prev_tokens_dataset = RightPadDataset( RollDataset(src_tokens, 1), pad_idx=self.dictionary.pad(), ) dataset["net_input"].update( prev_output_tokens=prev_tokens_dataset, ) label_path = "{0}.npz".format(get_path("label", split)) if os.path.exists(label_path): csr_matrix = load_npz(label_path) dataset.update(target=CSRLabelDataset(csr_matrix)) nested_dataset = NestedDictionaryDataset( dataset, sizes=[src_tokens.sizes], ) if self.args.no_shuffle: dataset = nested_dataset else: dataset = SortDataset( nested_dataset, # shuffle sort_order=[shuffle], ) logger.info("Loaded {0} with #samples: {1}".format( split, len(dataset))) self.datasets[split] = dataset return self.datasets[split]
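# Hedged sketch of the sparse multi-label target loading above: labels are
# stored as a scipy CSR matrix in an .npz file and read back with load_npz.
# The file name here is illustrative.

import numpy as np
from scipy.sparse import csr_matrix, load_npz, save_npz

# 3 examples x 5 labels; example 0 has labels {1, 3}, example 2 has {0}
labels = csr_matrix(
    (np.ones(3), ([0, 0, 2], [1, 3, 0])), shape=(3, 5)
)
save_npz("train.label.npz", labels)

loaded = load_npz("train.label.npz")
assert (loaded != labels).nnz == 0  # the roundtrip preserves the matrix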