Example #1
    def get_batch_iterator(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1,
        seed=1, num_shards=1, shard_id=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch.
                Default: ``None``
            max_sentences (int, optional): max number of sentences in each
                batch. Default: ``None``
            max_positions (optional): max sentence length supported by the
                model. Default: ``None``
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long. Default: ``False``
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N. Default: ``1``
            seed (int, optional): seed for random number generator for
                reproducibility. Default: ``1``
            num_shards (int, optional): shard the data iterator into N
                shards. Default: ``1``
            shard_id (int, optional): which shard of the data iterator to
                return. Default: ``0``

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        assert isinstance(dataset, FairseqDataset)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        indices = data_utils.filter_by_size(
            indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
        )

        # create mini-batches with given size constraints
        batch_sampler = data_utils.batch_by_size(
            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # return a reusable, sharded iterator
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
        )
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""
        def get_path(type, split):
            return os.path.join(self.args.data, type, split)

        def make_dataset(type, dictionary):
            split_path = get_path(type, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            return dataset

        input0 = make_dataset('input0', self.source_dictionary)
        assert input0 is not None, 'could not find dataset: {}'.format(
            get_path('input0', split))
        input1 = make_dataset('input1', self.source_dictionary)

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)

        if input1 is None:
            src_tokens = input0
        else:
            if self.args.separator_token is not None:
                input1 = PrependTokenDataset(input1, self.args.separator_token)

            src_tokens = ConcatSentencesDataset(input0, input1)

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        if self.args.truncate_sequence:
            src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

        dataset = {
            'id': IdDataset(),
            'net_input': {
                'src_tokens':
                RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths':
                NumelDataset(src_tokens, reduce=False),
            },
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens, reduce=True),
        }

        if self.args.add_prev_output_tokens:
            prev_tokens_dataset = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )
            dataset['net_input'].update(
                prev_output_tokens=prev_tokens_dataset,
            )

        if not self.args.regression_target:
            label_dataset = make_dataset('label', self.target_dictionary)
            if label_dataset is not None:
                dataset.update(target=OffsetTokensDataset(
                    StripTokenDataset(
                        label_dataset,
                        id_to_strip=self.target_dictionary.eos(),
                    ),
                    offset=-self.target_dictionary.nspecial,
                ))
        else:
            label_path = "{0}.label".format(get_path('label', split))
            if os.path.exists(label_path):
                with open(label_path) as h:
                    dataset.update(target=RawLabelDataset(
                        [float(x.strip()) for x in h]))

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
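
Both methods in this example wrap their random ordering in data_utils.numpy_seed(...) so the result is reproducible. A minimal sketch of that pattern, assuming plain NumPy and not claiming to match fairseq's actual helper, saves the global RNG state, seeds it for the duration of the block, and restores it afterwards:

import contextlib
import numpy as np


@contextlib.contextmanager
def numpy_seed(*seeds):
    """Temporarily seed NumPy's global RNG; restore the previous state on exit."""
    if not seeds or seeds[0] is None:
        yield
        return
    state = np.random.get_state()
    np.random.seed(hash(seeds) % (2 ** 32))  # fold several values (e.g., seed, epoch) into one seed
    try:
        yield
    finally:
        np.random.set_state(state)


with numpy_seed(42):
    shuffle = np.random.permutation(10)  # identical permutation on every run
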
Example #3
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample; periodically print out
        randomly sampled predictions if model is in training mode, otherwise
        aggregate word error stats for validation.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        dict = self.scorer.dict
        if model.training:
            if ((len(self.args.scheduled_sampling_probs) > 1
                 or self.args.scheduled_sampling_probs[0] < 1.0) and
                    self.epoch >= self.args.start_scheduled_sampling_epoch):
                # scheduled sampling
                ss_prob = self.args.scheduled_sampling_probs[min(
                    self.epoch - self.args.start_scheduled_sampling_epoch,
                    len(self.args.scheduled_sampling_probs) - 1)]
                assert isinstance(model.decoder, FairseqIncrementalDecoder)
                incremental_states = {}
                encoder_input = {
                    k: v
                    for k, v in sample['net_input'].items()
                    if k != 'prev_output_tokens'
                }
                encoder_out = model.encoder(**encoder_input)
                target = sample['target']
                tokens = sample['net_input']['prev_output_tokens']
                lprobs = []
                pred = None
                for step in range(target.size(1)):
                    if step > 0:
                        sampling_mask = torch.rand(
                            [target.size(0), 1],
                            device=target.device,
                        ).lt(ss_prob)
                        feed_tokens = torch.where(
                            sampling_mask,
                            tokens[:, step:step + 1],
                            pred,
                        )
                    else:
                        feed_tokens = tokens[:, step:step + 1]
                    log_probs, _ = self._decode(
                        feed_tokens,
                        model,
                        encoder_out,
                        incremental_states,
                    )
                    pred = log_probs.argmax(-1, keepdim=True)
                    lprobs.append(log_probs)
                lprobs = torch.stack(lprobs, dim=1)
            else:
                # normal training
                net_output = model(**sample['net_input'])
                lprobs = model.get_normalized_probs(net_output, log_probs=True)
                target = model.get_targets(sample, net_output)
        else:
            assert isinstance(model.decoder, FairseqIncrementalDecoder)
            incremental_states = {}
            encoder_input = {
                k: v
                for k, v in sample['net_input'].items()
                if k != 'prev_output_tokens'
            }
            encoder_out = model.encoder(**encoder_input)
            target = sample['target']
            # make the maximum decoding length at least the target length,
            # extending it to the encoder output length when that is longer
            maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1))
            tokens = target.new_full([target.size(0), maxlen + 2],
                                     self.padding_idx)
            tokens[:, 0] = dict.eos()
            lprobs = []
            attn = [] if getattr(model.decoder, 'need_attn', False) else None
            dummy_log_probs = encoder_out['encoder_out'][0].new_full(
                [target.size(0), len(dict)], -np.log(len(dict)))
            for step in range(maxlen + 1):  # one extra step for EOS marker
                is_eos = tokens[:, step].eq(dict.eos())
                # if all predictions are finished (i.e., ended with eos),
                # pad lprobs to target length with dummy log probs,
                # truncate tokens up to this step and break
                if step > 0 and is_eos.sum() == is_eos.size(0):
                    for _ in range(step, target.size(1)):
                        lprobs.append(dummy_log_probs)
                    tokens = tokens[:, :step + 1]
                    break
                log_probs, attn_scores = self._decode(
                    tokens[:, :step + 1],
                    model,
                    encoder_out,
                    incremental_states,
                )
                tokens[:, step + 1] = log_probs.argmax(-1)
                if step > 0:  # deal with finished predictions
                    # make log_probs uniform if the previous output token is EOS
                    # and add consecutive EOS to the end of prediction
                    log_probs[is_eos, :] = -np.log(log_probs.size(1))
                    tokens[is_eos, step + 1] = dict.eos()
                if step < target.size(1):
                    lprobs.append(log_probs)
                if getattr(model.decoder, 'need_attn', False):
                    attn.append(attn_scores)
            # bsz x min(tgtlen, maxlen + 1) x vocab_size
            lprobs = torch.stack(lprobs, dim=1)
            if getattr(model.decoder, 'need_attn', False):
                # bsz x (maxlen + 1) x (length of encoder_out)
                attn = torch.stack(attn, dim=1)
        # word error stats code starts
        if (not model.training
                or (self.num_updates // self.args.print_interval >
                    (self.num_updates - 1) // self.args.print_interval)):
            pred = lprobs.argmax(-1).cpu() if model.training else \
                tokens[:, 1:].data.cpu()  # bsz x len

            if not model.training:  # validation step, compute WER stats with scorer
                assert pred.size(0) == target.size(0)
                self.scorer.reset()
                for i in range(target.size(0)):
                    utt_id = sample['utt_id'][i]
                    id = sample['id'].data[i].item()
                    # ref_tokens = dict.string(target.data[i])
                    # if this is a dummy batch (e.g., a "padding" batch in a
                    # sharded dataset), id might exceed the dataset size; in
                    # that case we just skip it
                    if id < len(self.valid_tgt_dataset):
                        ref_tokens = self.valid_tgt_dataset.get_original_tokens(
                            id)
                        pred_tokens = dict.string(pred.data[i])
                        self.scorer.add_evaluation(
                            utt_id,
                            ref_tokens,
                            pred_tokens,
                            bpe_symbol=self.args.remove_bpe,
                        )
            else:  # print a randomly sampled result every print_interval updates
                assert pred.size() == target.size()
                with data_utils.numpy_seed(self.num_updates):
                    i = np.random.randint(0, len(sample['id']))
                id = sample['id'].data[i].item()
                length = utils.strip_pad(target.data[i],
                                         self.padding_idx).size(0)
                # ref_one = dict.tokens_to_sentence(dict.string(target.data[i]))
                ref_one = self.train_tgt_dataset.get_original_text(
                    id,
                    dict,
                    bpe_symbol=self.args.remove_bpe,
                )
                pred_one = dict.tokens_to_sentence(
                    dict.string(pred.data[i][:length]),
                    bpe_symbol=self.args.remove_bpe,
                )
                print('| sample REF: ' + ref_one)
                print('| sample PRD: ' + pred_one)
        # word error stats code ends
        lprobs = lprobs.view(-1, lprobs.size(-1))
        loss = F.nll_loss(
            lprobs,
            target.view(-1),
            ignore_index=self.padding_idx,
            reduction='sum' if reduce else 'none',
        )
        sample_size = sample['target'].size(
            0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        if not model.training:  # do not compute word error in training mode
            logging_output['word_error'] = self.scorer.tot_word_error()
            logging_output['word_count'] = self.scorer.tot_word_count()
            logging_output['char_error'] = self.scorer.tot_char_error()
            logging_output['char_count'] = self.scorer.tot_char_count()
        return loss, sample_size, logging_output
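
The scheduled-sampling branch above mixes ground-truth and previously predicted tokens with a per-sentence Bernoulli mask. A stripped-down sketch of just that mixing step, on toy tensors and with no model involved (ss_prob is the probability of feeding the reference token):

import torch

ss_prob = 0.7
bsz, step = 4, 3
tokens = torch.randint(5, 100, (bsz, 10))   # reference prev_output_tokens
pred = torch.randint(5, 100, (bsz, 1))      # argmax prediction from the previous step

# keep the reference token with probability ss_prob, otherwise feed the model's prediction
sampling_mask = torch.rand(bsz, 1).lt(ss_prob)
feed_tokens = torch.where(sampling_mask, tokens[:, step:step + 1], pred)
print(feed_tokens.shape)  # torch.Size([4, 1])
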
Example #4
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None). No longer used in this fork; the argument is
                kept only for interface compatibility.
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None). No longer used in this fork; the
                argument is kept only for interface compatibility.
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 0).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        assert isinstance(dataset, FairseqDataset)

        # get indices ordered by example size
        # (the ordering itself is assumed to already be handled by the dataset)
        with data_utils.numpy_seed(seed):
            # get our ordered indices; size-based filtering is handled there as well
            indices = dataset.ordered_indices()

        # filter_by_size was removed because we believe it is not needed here (Christine, 18-12-2019)

        # create mini-batches with the given size constraints; this simply groups
        # indices by batch size (using the batch_by_size implemented above) and
        # should still be tested against the framework (Christine, 18-12-2019)
        batch_sampler = batch_by_size(indices, max_sentences=max_sentences)

        # TODO: verify that the returned mini-batches are correct
        # return a reusable, sharded iterator
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

def shuffle_batches(batches, seed):
    with data_utils.numpy_seed(seed):
        np.random.shuffle(batches)
    return batches
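
This fork batches by sentence count only, via a custom batch_by_size(indices, max_sentences=...). If no token budget is involved, such a helper reduces to simple chunking; the following sketch is an assumption about what this fork's helper does, not fairseq's data_utils.batch_by_size:

import numpy as np


def batch_by_size(indices, max_sentences):
    """Group ordered indices into batches of at most max_sentences examples each."""
    return [
        indices[start:start + max_sentences]
        for start in range(0, len(indices), max_sentences)
    ]


batches = batch_by_size(np.arange(10), max_sentences=4)
print([len(b) for b in batches])  # [4, 4, 2]
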
Example #6
    def load_dataset(self,
                     split,
                     epoch=0,
                     combine=False,
                     data_path=None,
                     return_only=False,
                     **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))

        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []
        labels = []

        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(
                data_path):
            prefix = sentence[:pronoun_span.start].text
            suffix = sentence[pronoun_span.end:].text_with_ws

            # spaCy spans include trailing spaces, but we need to know about
            # leading spaces for the GPT-2 BPE
            leading_space = (
                ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ')
                else ''
            )
            trailing_space = ' ' if pronoun_span.text_with_ws.endswith(' ') else ''

            # get noun phrases, excluding pronouns and anything overlapping with the query
            cand_spans = wsc_utils.filter_noun_chunks(
                wsc_utils.extended_noun_chunks(sentence),
                exclude_pronouns=True,
                exclude_query=query,
                exact_match=False,
            )

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query, prefix, suffix, leading_space, trailing_space)
                query_len = len(query_toks)
            else:
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)

            cand_toks, cand_masks = [], []
            for cand_span in cand_spans:
                toks, mask = self.binarize_with_mask(
                    cand_span.text,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                cand_toks.append(toks)
                cand_masks.append(mask)

            # collate candidates
            cand_toks = data_utils.collate_tokens(cand_toks,
                                                  pad_idx=self.vocab.pad())
            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
            assert cand_toks.size() == cand_masks.size()

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_masks)
            candidate_lengths.append(cand_toks.size(1))

            labels.append(label)

        query_lengths = np.array(query_lengths)
        query_tokens = ListDataset(query_tokens, query_lengths)
        query_masks = ListDataset(query_masks, query_lengths)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
        candidate_masks = ListDataset(candidate_masks, candidate_lengths)

        labels = ListDataset(labels, [1] * len(labels))

        dataset = {
            'id': IdDataset(),
            'query_tokens': query_tokens,
            'query_masks': query_masks,
            'candidate_tokens': candidate_tokens,
            'candidate_masks': candidate_masks,
            'labels': labels,
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]
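
Each WSC example collates its candidate spans with data_utils.collate_tokens, which right-pads a list of 1-D tensors into one 2-D batch. A minimal sketch of that padding step (fairseq's real implementation also supports left-padding and EOS handling):

import torch


def collate_tokens(values, pad_idx):
    """Right-pad a list of 1-D LongTensors into a (len(values), max_len) tensor."""
    max_len = max(v.size(0) for v in values)
    out = values[0].new_full((len(values), max_len), pad_idx)
    for i, v in enumerate(values):
        out[i, :v.size(0)].copy_(v)
    return out


cands = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
print(collate_tokens(cands, pad_idx=1))
# tensor([[5, 6, 7],
#         [8, 9, 1]])
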
Example #7
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    # TODO(SS): do we need this if no_save_optimizer_state?
    trainer.consolidate_optimizer()

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(
        f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(
        epoch, suffix)] = (end_of_epoch and not cfg.no_epoch_checkpoints
                           and epoch % cfg.save_interval == 0)
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(
        epoch, updates,
        suffix)] = (not end_of_epoch and cfg.save_interval_updates > 0
                    and updates % cfg.save_interval_updates == 0)
    checkpoint_conds["checkpoint_best{}.pt".format(
        suffix)] = val_loss is not None and (
            not hasattr(save_checkpoint, "best")
            or is_better(val_loss, save_checkpoint.best))
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(
                p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds["checkpoint.best_{}_{:.3f}{}{}.pt".format(
            cfg.best_checkpoint_metric, val_loss, rand_sfx,
            suffix)] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds["checkpoint_last{}.pt".format(
        suffix)] = not cfg.no_last_checkpoints

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss
    }
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on.")
            else:
                assert PathManager.copy(
                    checkpoints[0], cp,
                    overwrite=True), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss, write_timer.sum))

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix))
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0] for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix))
        for old_chk in checkpoints[cfg.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)
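
The retention logic above relies on checkpoint_paths(save_dir, pattern=...), which gathers files matching a regex and sorts them by the captured group in descending order. A rough approximation of that helper (fairseq's version additionally supports the keep_match flag used above):

import os
import re


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
    """Return checkpoint files under path, newest first by the captured number."""
    pt_regexp = re.compile(pattern)
    entries = []
    for fname in os.listdir(path):
        m = pt_regexp.fullmatch(fname)
        if m is not None:
            entries.append((float(m.group(1)), fname))
    return [os.path.join(path, fname) for _, fname in sorted(entries, reverse=True)]
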
Example #8
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                (default: False).
        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        can_reuse_epoch_itr = (not disable_iterator_cache
                               and self.can_reuse_epoch_itr(dataset))
        if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
            logger.debug(
                'reusing EpochBatchIterator for epoch {}'.format(epoch))
            return self.dataset_to_epoch_iter[dataset]

        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            indices = self.filter_indices_by_size(indices, dataset,
                                                  max_positions,
                                                  ignore_invalid_inputs)

        # create mini-batches with given size constraints
        batch_sampler = dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
        )

        if can_reuse_epoch_itr:
            self.dataset_to_epoch_iter[dataset] = epoch_iter

        return epoch_iter
    def load_dataset(self, split, combine=False):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
                tokens = [t for l in ds.tokens_list for t in l]
            elif not self.args.raw_text and IndexedInMemoryDataset.exists(
                    path):
                ds = IndexedInMemoryDataset(path, fix_lua_indexing=False)
                tokens = ds.buffer
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))
            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    ModifiedBlockPairDataset(
                        tokens,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        pad=self.dictionary.pad(),
                        class_positive=self.dictionary.class_positive(),
                        class_negative=self.dictionary.class_negative(),
                        sep=self.dictionary.sep(),
                        vocab=self.dictionary,
                        break_mode=self.args.break_mode,
                        short_seq_prob=self.args.short_seq_prob,
                    ))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = ModifiedBertDataset(
            dataset,
            sizes,
            self.dictionary,
            shuffle=self.args.shuffle_instance,
            seed=self.seed,
            mask_ratio=self.args.mask_ratio,
            lower=self.args.span_lower,
            upper=self.args.span_upper,
            geometric_p=self.args.geometric_p)
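
The load_dataset above uses the common fairseq idiom of probing split, split1, split2, ... until a file is missing, so several binarized shards of one split can be combined. A bare-bones sketch of just that probing loop (collecting shard paths only; the real code loads each shard as it goes):

import itertools
import os


def find_split_shards(data_dir, split, combine=True):
    """Collect data_dir/split, data_dir/split1, ... until the next shard is missing."""
    shards = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_dir, split_k)
        if not os.path.exists(path):
            if k == 0:
                raise FileNotFoundError('Dataset not found: {}'.format(path))
            break
        shards.append(path)
        if not combine:
            break
    return shards
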
Example #10
    def __getitem__(self, idx):
        with data_utils.numpy_seed(43211, self.epoch, idx):
            return self.mmdataset[idx]
Example #11
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.tokens_per_sample,
            self.args.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample,
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

        # remove tail
        dataset = RemoveTailDataset(dataset)

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src_tokens': RightPadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        'src_lengths': NumelDataset(src_dataset, reduce=False),
                    },
                    'target': RightPadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'nsentences': NumSamplesDataset(),
                    'ntokens': NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        def get_path(type, split):
            return os.path.join(self.args.data, type, split)

        def make_dataset(type, dictionary):
            split_path = get_path(type, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            return dataset

        input0 = make_dataset('input0', self.source_dictionary)
        input_options = [
            make_dataset(
                'input{idx}'.format(idx=idx + 1),
                self.source_dictionary
            )
            for idx in range(self.args.num_classes)
        ]

        if self.args.separator_token is not None:
            input0 = PrependTokenDataset(input0, self.args.separator_token)

        src_tokens = []
        for input_option in input_options:
            if self.args.init_token is not None:
                input_option = PrependTokenDataset(input_option, self.args.init_token)
            if self.args.max_option_length is not None:
                input_option = TruncateDataset(input_option, self.args.max_option_length)
            src_token = ConcatSentencesDataset(input_option, input0)
            if self.args.truncate_sequence:
                src_token = TruncateDataset(src_token, self.args.max_positions)
            src_tokens.append(src_token)

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens[0]))

        dataset = {
            'id': IdDataset(),
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens[0], reduce=True),
        }

        for src_token_idx in range(len(src_tokens)):
            dataset.update(
                {
                    'net_input{idx}'.format(idx=src_token_idx+1): {
                        'src_tokens': RightPadDataset(
                            src_tokens[src_token_idx],
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False),
                    }
                }
            )

        label_path = '{}.label'.format(get_path('label', split))
        if os.path.exists(label_path):
            with open(label_path) as h:
                dataset.update(
                    target=RawLabelDataset([
                        int(x.strip()) for x in h.readlines()
                    ])
                )

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Example #13
        def lang_dataset(lang):
            input0 = make_dataset('input0', lang, self.source_dictionary)
            assert input0 is not None, 'could not find dataset: {}'.format(
                get_path('input0', lang, split))
            input1 = make_dataset('input1', lang, self.source_dictionary)

            if self.args.init_token is not None:
                input0 = PrependTokenDataset(input0, self.args.init_token)

            if input1 is None:
                src_tokens = input0
            else:
                if self.args.separator_token is not None:
                    input1 = PrependTokenDataset(input1,
                                                 self.args.separator_token)

                src_tokens = ConcatSentencesDataset(input0, input1)

            with data_utils.numpy_seed(self.args.seed):
                shuffle = np.random.permutation(len(src_tokens))

            if self.args.truncate_sequence:
                src_tokens = TruncateDataset(src_tokens,
                                             self.args.max_positions)

            dataset = {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens':
                    RightPadDataset(
                        src_tokens,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths':
                    NumelDataset(src_tokens, reduce=False),
                },
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_tokens, reduce=True),
            }

            if not self.args.regression_target:
                label_dataset = make_dataset('label', lang,
                                             self.target_dictionary)
                if label_dataset is not None:
                    dataset.update(target=OffsetTokensDataset(
                        StripTokenDataset(
                            label_dataset,
                            id_to_strip=self.target_dictionary.eos(),
                        ),
                        offset=-self.target_dictionary.nspecial,
                    ))
            else:
                label_path = "{0}.label".format(get_path('label', lang, split))
                if os.path.exists(label_path):
                    with open(label_path) as h:
                        dataset.update(target=RawLabelDataset([
                            float(x.strip()) for x in h
                        ]))

            nested_dataset = NestedDictionaryDataset(
                dataset,
                sizes=[src_tokens.sizes],
            )

            if self.args.no_shuffle:
                dataset = nested_dataset
            else:
                dataset = SortDataset(
                    nested_dataset,
                    # shuffle
                    sort_order=[shuffle],
                )

            print("| Loaded {0} with #samples: {1}".format(
                split, len(dataset)))
            return dataset
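
For classification targets, the recipe above strips the EOS appended during binarization and shifts the label IDs down by nspecial so they become 0-based class indices. A toy illustration of that mapping (the concrete ID values here are hypothetical; real ones come from the fairseq Dictionary):

import torch

nspecial = 4                          # e.g., <s>, <pad>, </s>, <unk> occupy IDs 0..3
eos = 2
labels = torch.tensor([4, 5, 4, 2])   # binarized labels with EOS appended

labels = labels[labels != eos]        # StripTokenDataset: drop the EOS
targets = labels - nspecial           # OffsetTokensDataset: 0-based class indices
print(targets)                        # tensor([0, 1, 0])
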
Example #14
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        def get_path(type, split):
            return os.path.join(self.args.data, type, split)

        def make_dataset(type, dictionary):
            split_path = get_path(type, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            assert dataset is not None, "could not find dataset: {}".format(
                get_path(type, split))
            return dataset

        src_tokens = make_dataset("input0", self.source_dictionary)
        pos_tokens = make_dataset("input1", self.pos_dictionary)

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        label0_dataset = make_dataset("label0", self.label0_dictionary)
        label1_dataset = make_dataset("label1", self.label1_dictionary)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                    pad_to_length=self._max_positions,
                ),
                "src_lengths": NumelDataset(src_tokens, reduce=False),
            },
            "segments": {
                "seg_tokens": RightPadDataset(
                    pos_tokens,
                    pad_idx=self.pos_dictionary.pad(),
                    pad_to_length=self._max_positions,
                ),
                "seg_lengths": NumelDataset(pos_tokens, reduce=False),
            },
            "target0": RightPadDataset(  # use 1 as padding, will be used to mask out padding when calculating loss
                ReplaceDataset(  # replace eos and existing padding (used when some tokens should not be predicted) with -1
                    OffsetTokensDataset(  # offset tokens to get the targets to the correct range (0,1,2,...)
                        label0_dataset,
                        offset=-self.label0_dictionary.nspecial,
                    ),
                    replace_map={
                        self.label0_dictionary.eos()
                        - self.label0_dictionary.nspecial: -1,
                        self.label0_dictionary.pad()
                        - self.label0_dictionary.nspecial: -1,
                    },
                    offsets=np.zeros(len(label0_dataset), dtype=np.int64),
                ),
                pad_idx=-1,
                pad_to_length=self._max_positions,
            ),
            "target1": RightPadDataset(  # use 1 as padding, will be used to mask out padding when calculating loss
                ReplaceDataset(  # replace eos and existing padding (used when some tokens should not be predicted) with -1
                    OffsetTokensDataset(  # offset tokens to get the targets to the correct range (0,1,2,...)
                        label1_dataset,
                        offset=-self.label1_dictionary.nspecial,
                    ),
                    replace_map={
                        self.label1_dictionary.eos()
                        - self.label1_dictionary.nspecial: -1,
                        self.label1_dictionary.pad()
                        - self.label1_dictionary.nspecial: -1,
                    },
                    offsets=np.zeros(len(label1_dataset), dtype=np.int64),
                ),
                pad_idx=-1,
                pad_to_length=self._max_positions,
            ),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )
        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
        self.datasets[split] = dataset
        return self.datasets[split]
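
The -1 pad index on target0/target1 above exists so padded (and EOS) positions can be ignored when the loss is computed. A short sketch of how such targets are typically consumed downstream; this is an assumption about the criterion, not code from this task:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 3)                   # bsz x seq_len x num_labels
targets = torch.tensor([[0, 2, 1, -1, -1],
                        [1, 1, -1, -1, -1]])    # -1 marks padded / ignored positions
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    targets.view(-1),
    ignore_index=-1,                            # skip the -1 positions entirely
)
print(loss.item())
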
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(os.pathsep)
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]

        languages = [
            name for name in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, name))
        ]
        print("| Training on {0} languages: {1}".format(
            len(languages), languages))
        print("| Language to id mapping: ",
              {lang: id
               for id, lang in enumerate(languages)})

        mask_whole_words = self._get_whole_word_mask()
        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, split_path))

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode=self.args.sample_break_mode,
            )
            print('| loaded {} blocks from: {}'.format(len(dataset),
                                                       split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())

            src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
                dataset,
                self.source_dictionary,
                pad_idx=self.source_dictionary.pad(),
                mask_idx=self.mask_idx,
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
                mask_whole_words=mask_whole_words,
            )

            lang_dataset = NestedDictionaryDataset(
                {
                    'net_input': {
                        'src_tokens':
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths':
                        NumelDataset(src_dataset, reduce=False),
                    },
                    'target':
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences':
                    NumSamplesDataset(),
                    'ntokens':
                    NumelDataset(src_dataset, reduce=True),
                    'lang_id':
                    RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
                },
                sizes=[src_dataset.sizes],
            )
            lang_datasets.append(lang_dataset)

        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            dataset_lengths = np.array(
                [len(d) for d in lang_datasets],
                dtype=float,
            )
            sample_probs = self._get_sample_prob(dataset_lengths)
            print(
                "| Sample probability by language: ", {
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                })
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            print(
                "| Up/Down Sampling ratio by language: ", {
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                })

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                ) for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + '_' + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ','.join(lang_splits))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
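
_get_sample_prob above implements the usual temperature-based resampling for multilingual pretraining: language l is drawn with probability proportional to n_l ** alpha, which up-weights low-resource languages when alpha < 1. A small sketch of that computation and of the size ratios printed above, assuming an alpha hyperparameter along the lines of --multilang-sampling-alpha:

import numpy as np


def sample_probs(dataset_lengths, alpha=0.7):
    """p_l is proportional to n_l ** alpha; alpha < 1 up-samples low-resource languages."""
    probs = dataset_lengths / dataset_lengths.sum()
    smoothed = probs ** alpha
    return smoothed / smoothed.sum()


lengths = np.array([1_000_000.0, 10_000.0])
probs = sample_probs(lengths)
size_ratio = probs * lengths.sum() / lengths    # ratio > 1 means the language is up-sampled
print(probs, size_ratio)
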
Example #16
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_whole_words else None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id':
                    IdDataset(),
                    'net_input': {
                        'src_tokens':
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths':
                        NumelDataset(src_dataset, reduce=False),
                    },
                    'target':
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences':
                    NumSamplesDataset(),
                    'ntokens':
                    NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
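The numpy_seed(self.args.seed + epoch) pattern above makes the shuffle reproducible for a given epoch while still differing across epochs. A pure-numpy sketch of the same idea, using a local RandomState instead of fairseq's numpy_seed context manager:

import numpy as np

def epoch_shuffle(n, seed, epoch):
    # Reproducible, epoch-dependent permutation; mirrors `numpy_seed(seed + epoch)`
    # without touching the global numpy RNG state.
    return np.random.RandomState(seed + epoch).permutation(n)

print(epoch_shuffle(8, seed=1, epoch=2))
print(epoch_shuffle(8, seed=1, epoch=2))  # identical to the call above
print(epoch_shuffle(8, seed=1, epoch=3))  # (almost certainly) a different order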
Example #17
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert (
                self.mask_idx not in item
            ), "Dataset contains mask_idx (={}), this is not expected!".format(
                self.mask_idx, )

            if self.mask_whole_words is not None:
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # make sure this is a value sequence, which must contain '##'
            assert '##' in self.vocab
            # when predicting values (as classification), we predict the actual values of all registers
            real_bytes = torch.where(
                (item != self.vocab.index('##'))
                & (item != self.vocab.bos())
                & (item != self.vocab.eos())
            )[0].cpu().numpy()

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * len(real_bytes) /
                float(self.mask_multiple_length) + np.random.rand())

            # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
            mask_idc = np.random.choice(real_bytes, num_mask, replace=False)
            if self.mask_stdev > 0.0:
                lengths = np.random.normal(self.mask_multiple_length,
                                           self.mask_stdev,
                                           size=num_mask)
                lengths = [max(0, int(round(x))) for x in lengths]
                mask_idc = np.asarray(
                    [
                        mask_idc[j] + offset for j in range(len(mask_idc))
                        for offset in range(lengths[j])
                    ],
                    dtype=np.int64,
                )
            else:
                mask_idc = np.concatenate(
                    [mask_idc + i for i in range(self.mask_multiple_length)])
            mask_idc = mask_idc[mask_idc < len(mask)]
            try:
                mask[mask_idc] = True
            except Exception:  # something went wrong; report and re-raise
                print("Assigning mask indexes {} to mask {} failed!".format(
                    mask_idc, mask))
                raise

            if self.return_masked_tokens:
                # exit early if we're just returning the masked tokens
                # (i.e., the targets for masked LM training)
                if self.mask_whole_words is not None:
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8))
                                      == 1]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) <
                                         rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                mask = mask ^ unmask

            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if self.mask_whole_words is not None:
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)
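The int(... + np.random.rand()) expression above is probabilistic rounding: a value x is rounded up with probability equal to its fractional part, so the expected number of masked positions is exactly mask_prob * len(real_bytes) / mask_multiple_length. A quick numpy check:

import numpy as np

rng = np.random.RandomState(0)
x = 3.3
draws = np.array([int(x + rng.rand()) for _ in range(100_000)])
print(draws.mean())  # ~3.3: rounds up to 4 about 30% of the time, down to 3 otherwise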
Example #18
File: clr.py Project: daoyuan14/trex
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""
        def get_path(type, split):
            return os.path.join(self.args.data, type, split)

        def make_dataset(type, dictionary):
            split_path = get_path(type, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            return dataset

        input0 = make_dataset('input0', self.source_dictionary)
        assert input0 is not None, 'could not find dataset: {}'.format(
            get_path('input0', split))
        input1 = make_dataset('input1', self.source_dictionary)
        assert input1 is not None, 'could not find dataset: {}'.format(
            get_path('input1', split))
        assert len(input0) == len(input1), 'input pair different length'

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)
            input1 = PrependTokenDataset(input1, self.args.init_token)

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(input0))

        if self.args.truncate_sequence:
            input0 = TruncateDataset(input0, self.args.max_positions)
            input1 = TruncateDataset(input1, self.args.max_positions)

        dataset = {
            'id': IdDataset(),
            'net_input0': {
                'src_tokens':
                RightPadDataset(
                    input0,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths':
                NumelDataset(input0, reduce=False),
            },
            'net_input1': {
                'src_tokens':
                RightPadDataset(
                    input1,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths':
                NumelDataset(input1, reduce=False),
            },
            'nsentences': NumSamplesDataset(),
            'ntokens0': NumelDataset(input0, reduce=True),
            'ntokens1': NumelDataset(input1, reduce=True),
        }

        label_path = "{0}.label".format(get_path('label', split))
        if os.path.exists(label_path):
            dataset.update(target=RawLabelDataset(
                [float(x.strip()) for x in open(label_path).readlines()]))

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[np.maximum(input0.sizes, input1.sizes)],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(
            split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
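Both net_input0 and net_input1 above rely on RightPadDataset to pad each batch to its longest item. A minimal stand-in for that collation step (right_pad is a hypothetical helper; the real padding is done by fairseq's collate functions):

import torch

def right_pad(seqs, pad_idx):
    # Pad a list of 1-D LongTensors on the right to the length of the longest one.
    max_len = max(s.numel() for s in seqs)
    out = seqs[0].new_full((len(seqs), max_len), pad_idx)
    for i, s in enumerate(seqs):
        out[i, : s.numel()] = s
    return out

print(right_pad([torch.tensor([5, 6, 7]), torch.tensor([8, 9])], pad_idx=1))
# tensor([[5, 6, 7],
#         [8, 9, 1]])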
Example #19
    def _collate(
            self,
            samples: List[Dict],
            pad_idx: int,
            eos_idx: int
    ):
        """
        Does the heavy lifting for creating a batch from the input list of
        examples. The logic is as follows:
            1. Mask the input blocks. In case has_pair is True then we have 2
               blocks to mask.
            2. Prepend the first masked block tensor with the special token
               used as sentence embedding. Eg: CLS in BERT. This happens
               irrespective of the value of has_pair.
            3. If has_pair is True, then append the first masked block with the
               special separator token (eg: SEP for BERT) and compute segment
               label accordingly. In this case, also append the second masked
               block with this special separator token and compute its segment
               label.
            4. For the targets tensor, prepend and append with padding index
               accordingly.
            5. Concatenate all tensors.
        """
        if len(samples) == 0:
            return {}
        # To ensure determinism, we reset the state of the PRNG after every
        # batch based on the seed and the first id of the batch. This ensures
        # that across epochs we get the same mask for the same example. This
        # is needed for reproducibility and is how BERT does masking.
        # TODO: Can we add determinism without this constraint?
        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
            for s in samples:

                # token range is needed for replacing with random token during
                # masking
                token_range = (self.vocab.nspecial, len(self.vocab))

                # mask according to specified probabilities.
                masked_blk_one, masked_tgt_one = self._mask_block(
                    s["block_one"], self.mask_idx, self.pad_idx, token_range,
                )

                tokens = np.concatenate([
                    [self.classif_token_idx], masked_blk_one
                ])
                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
                segments = np.ones(len(tokens)) * self.segment_id

                # if has_pairs is True then we need to add the SEP token to both
                # the blocks after masking and re-compute segments based on the new
                # lengths.
                if self.has_pairs:
                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                    targets_one = np.concatenate([targets, [self.pad_idx]])

                    masked_blk_two, masked_tgt_two = self._mask_block(
                        s["block_two"], self.mask_idx, self.pad_idx, token_range)
                    tokens_two = np.concatenate(
                        [masked_blk_two, [self.sep_token_idx]])
                    targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                    # block + 1 sep + 1 special (CLS)
                    segments_one = np.zeros(len(tokens_one))
                    # block + 1 sep
                    segments_two = np.ones(len(tokens_two))

                    tokens = np.concatenate([tokens_one, tokens_two])
                    targets = np.concatenate([targets_one, targets_two])
                    segments = np.concatenate([segments_one, segments_two])

                s["source"] = torch.LongTensor(tokens)
                s["segment_labels"] = torch.LongTensor(segments)
                s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
            return data_utils.collate_tokens(
                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
            )
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "ntokens": sum(len(s["source"]) for s in samples),
            "net_input": {
                "src_tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target": merge("lm_target"),
            "sentence_target": torch.LongTensor(
                [s["sentence_target"] for s in samples]
            ) if self.has_pairs else None,
            "nsentences": len(samples),
        }
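A small numpy sketch of the segment layout described in the docstring above for the paired case: `<cls> block_one <sep>` gets segment id 0 and `block_two <sep>` gets segment id 1 (the token ids below are made up for illustration):

import numpy as np

cls_idx, sep_idx = 0, 2                      # hypothetical special-token ids
blk_one, blk_two = np.array([11, 12, 13]), np.array([21, 22])

tokens = np.concatenate([[cls_idx], blk_one, [sep_idx], blk_two, [sep_idx]])
segments = np.concatenate([np.zeros(len(blk_one) + 2), np.ones(len(blk_two) + 1)])
assert len(tokens) == len(segments)
print(tokens, segments)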
Example #20
File: trex.py Project: BwRy/trex
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0

        src_tokens = {}
        tgt_tokens = {}
        tgt_values = {}
        for field in configs.fields:
            split_path = os.path.join(self.args.data, field, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary[field],
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

            dataset = maybe_shorten_dataset(
                dataset,
                split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                self.args.tokens_per_sample,
                self.args.seed,
            )

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary[field].pad(),
                eos=self.source_dictionary[field].eos(),
                break_mode=self.args.sample_break_mode,
            )
            logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset, self.source_dictionary[field].bos())

            if field == configs.static_field:
                src_dataset_code, tgt_dataset_code = MaskTokensDataset.apply_mask(
                    dataset,
                    self.source_dictionary[field],
                    pad_idx=self.source_dictionary[field].pad(),
                    mask_idx=self.mask_idx_dict[field],
                    seed=self.args.seed,
                    mask_prob=self.args.mask_prob,
                    leave_unmasked_prob=self.args.leave_unmasked_prob,
                    random_token_prob=self.args.random_token_prob,
                    freq_weighted_replacement=self.args.freq_weighted_replacement,
                )
                src_tokens[field] = RightPadDataset(
                    src_dataset_code,
                    pad_idx=self.source_dictionary[field].pad()
                )
                tgt_tokens[field] = RightPadDataset(
                    tgt_dataset_code,
                    pad_idx=self.source_dictionary[field].pad()
                )
            elif field in configs.byte_fields:
                src_dataset_value, tgt_dataset_value = MaskValuesDataset.apply_mask(
                    dataset,
                    self.source_dictionary[field],
                    pad_idx=self.source_dictionary[field].pad(),
                    mask_idx=self.mask_idx_dict[field],
                    seed=self.args.seed,
                    mask_prob=self.args.mask_prob,
                    leave_unmasked_prob=self.args.leave_unmasked_prob,
                    random_token_prob=self.args.random_token_prob,
                    freq_weighted_replacement=self.args.freq_weighted_replacement,
                )
                src_tokens[field] = RightPadDataset(
                    src_dataset_value,
                    pad_idx=self.source_dictionary[field].pad()
                )

                # dummy tokens are treated as 1
                # TODO: assert there should not be any dummy tokens here
                tgt_values[field] = BytevalueDataset(tgt_dataset_value, self.source_dictionary[field])
            else:
                src_tokens[field] = RightPadDataset(
                    dataset,
                    pad_idx=self.source_dictionary[field].pad()
                )

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_dataset_code))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    "id": IdDataset(),
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": NumelDataset(src_dataset_code, reduce=False),
                    },
                    "target": {
                        "tgt_tokens": tgt_tokens,
                        "tgt_values": tgt_values
                    },
                    "nsentences": NumSamplesDataset(),
                    "ntokens": NumelDataset(src_dataset_code, reduce=True),
                },
                sizes=[src_dataset_code.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset_code.sizes,
            ],
        )
Example #21
    def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        print('Loading dataset')
        
        data_path = os.path.join(self.args.data)
        dataset_inst = data_utils.load_indexed_dataset(
            os.path.join(data_path, 'insts', split),
            self.instruction_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        
        dataset_state = data_utils.load_indexed_dataset(
            os.path.join(data_path, 'states', split),
            self.state_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        
        if dataset_inst is None or dataset_state is None:
            raise FileNotFoundError('Dataset not found: {}'.format(split))
    
        dataset_inst = SeqOfSeqDataset(dataset_inst, self.instruction_dictionary)
        dataset_state = SeqOfSeqDataset(dataset_state, self.state_dictionary)
        dataset_pos = IRPositionDataset(os.path.join(data_path, 'pos', split))
        dataset = IRDataset(dataset_inst, dataset_state, dataset_pos)
        
        block_size = self.args.function_length
    
        dataset = IRPadDataset(
            dataset,
            inst_pad_idx=self.instruction_dictionary.pad(),
            state_pad_idx=self.state_dictionary.pad(),
            inst_mask_idx=self.inst_mask_idx,
            state_mask_idx=self.state_mask_idx,
            inst_cls_idx=self.instruction_dictionary.bos(),
            state_cls_idx=self.state_dictionary.bos(),
            smallbert_insts_per_input=self.args.smallbert_insts_per_group,
            smallbert_states_per_input=self.args.smallbert_insts_per_group,
            max_length=block_size,
            inst_pad_length=32,
            state_pad_length=16,
            pair=True,
        )
        
        labels_str = list(map(json.loads, open(os.path.join(data_path, 'label', split + '.txt'))))
        labels = torch.tensor([x - 1 if isinstance(x, int) else int(x.strip()) - 1 for x in labels_str])
        #function_indices = [torch.tensor(json.loads(x)) for x in open(os.path.join(data_path, 'funcs', split + '.txt'))]
        
        #dataset = IRMultiFunctionDataset(dataset, function_indices, self.args.max_functions_per_program)
    
        print('| loaded {} batches from: {} and {}'.format(len(dataset),
            os.path.join(data_path, 'insts', split), os.path.join(data_path, 'states', split)))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.labels[split] = SortDataset(RawLabelDataset(labels), sort_order=[shuffle])
        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src': dataset,
                    },
                    'target': RawLabelDataset(labels),
                    'indices': RawLabelDataset(torch.arange(len(dataset))),
                    'subset': ListDataset([split for _ in range(len(dataset))])
                },
                sizes=[dataset.sizes],
            ),
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Example #22
    def load_dataset(self,
                     split,
                     epoch=0,
                     combine=False,
                     data_path=None,
                     return_only=False,
                     **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))

        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []

        itr = wsc_utils.winogrande_jsonl_iterator(data_path,
                                                  eval=(split == 'test'))

        for sample in itr:
            sentence, pronoun_span, query, cand_text = sample
            prefix = sentence[:pronoun_span[0]].rstrip()
            suffix = sentence[pronoun_span[1]:]

            leading_space = ' ' if sentence[:pronoun_span[0]].endswith(
                ' ') else ''
            trailing_space = ''

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                query_len = len(query_toks)
            else:
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)

            cand_toks, cand_mask = self.binarize_with_mask(
                cand_text,
                prefix,
                suffix,
                leading_space,
                trailing_space,
            )

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_mask)
            candidate_lengths.append(cand_toks.size(0))

        query_lengths = np.array(query_lengths)

        def get_pad_dataset_fn(tokens, length, pad_idx):
            return PadDataset(
                ListDataset(tokens, length),
                pad_idx=pad_idx,
                left_pad=False,
            )

        query_tokens = get_pad_dataset_fn(query_tokens, query_lengths,
                                          self.vocab.pad())
        query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = get_pad_dataset_fn(candidate_tokens,
                                              candidate_lengths,
                                              self.vocab.pad())
        candidate_masks = get_pad_dataset_fn(candidate_masks,
                                             candidate_lengths, 0)

        dataset = {
            'id': IdDataset(),
            'query_tokens': query_tokens,
            'query_masks': query_masks,
            'candidate_tokens': candidate_tokens,
            'candidate_masks': candidate_masks,
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]
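The prefix/suffix/leading_space handling above splices a query or candidate string into the sentence in place of the pronoun before binarization. A plain-string sketch of that splice (the sentence and span are illustrative; BPE and binarize_with_mask are not reproduced here):

sentence = "The trophy didn't fit because it was too big."
start = sentence.index(" it ") + 1           # hypothetical pronoun span
pronoun_span = (start, start + 2)
candidate = "the trophy"

prefix = sentence[:pronoun_span[0]].rstrip()
suffix = sentence[pronoun_span[1]:]
leading_space = " " if sentence[:pronoun_span[0]].endswith(" ") else ""
print(prefix + leading_space + candidate + suffix)
# The trophy didn't fit because the trophy was too big.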
Example #23
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        # For the default fairseq task, return the same iterator across epochs,
        # since the datasets are not dynamic; task-specific setups can override
        # this behavior.
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]

        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            print(max_positions)
            indices = self.filter_indices_by_size(
                indices, dataset, max_positions, ignore_invalid_inputs
            )

        # create mini-batches with given size constraints
        batch_sampler = dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=getattr(self.args, 'data_buffer_size', 0),
        )
        self.dataset_to_epoch_iter[dataset] = epoch_iter
        return epoch_iter
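The docstring above describes batching by both max_tokens and max_sentences. A toy, pure-Python version of that grouping over size-ordered indices (greedy packing only; the real batch_by_size also honors required_batch_size_multiple and is considerably more careful):

def toy_batch_by_size(indices, num_tokens, max_tokens=None, max_sentences=None):
    # Greedily pack indices into batches; batch cost is max_item_len * batch_size
    # because shorter items get padded up to the longest one in the batch.
    batches, cur, cur_max = [], [], 0
    for idx in indices:
        nt = num_tokens(idx)
        new_max = max(cur_max, nt)
        over_tokens = max_tokens is not None and new_max * (len(cur) + 1) > max_tokens
        over_sents = max_sentences is not None and len(cur) + 1 > max_sentences
        if cur and (over_tokens or over_sents):
            batches.append(cur)
            cur, cur_max, new_max = [], 0, nt
        cur.append(idx)
        cur_max = new_max
    if cur:
        batches.append(cur)
    return batches

sizes = [3, 4, 5, 9, 2]
print(toy_batch_by_size(range(len(sizes)), lambda i: sizes[i], max_tokens=10))
# [[0, 1], [2], [3], [4]]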
Example #24
def collate(
    samples,
    pad_idx,
    chunk_width,
    chunk_left_context,
    chunk_right_context,
    label_delay,
    seed,
    epoch,
    pad_to_length=None,
    pad_to_multiple=1,
    src_bucketed=False,
    random_chunking=True,
):
    if len(samples) == 0:
        return {}

    def merge(key, pad_to_length=None):
        if key == "source":
            return speech_utils.collate_frames(
                [s[key] for s in samples],
                0.0,
                pad_to_length=pad_to_length,
                pad_to_multiple=pad_to_multiple,
            )
        elif key == "target":
            return data_utils.collate_tokens(
                [s[key] for s in samples],
                pad_idx=pad_idx,
                eos_idx=None,
                left_pad=False,
                move_eos_to_beginning=False,
                pad_to_length=pad_to_length,
                pad_to_multiple=pad_to_multiple,
            )
        else:
            raise ValueError("Invalid key.")

    def chunking(src_item, tgt_item, tgt_start):
        # make a src chunk in the range [begin_src, end_src)
        begin_src = max(0, tgt_start + label_delay - chunk_left_context)
        # ok if end_src past the end of utterance
        end_src = tgt_start + label_delay + chunk_width + chunk_right_context
        # replication pad if necessary
        left_pad = max(0, chunk_left_context - tgt_start - label_delay)
        right_pad = max(0, end_src - src_item.size(0))
        src_item = src_item[begin_src:end_src]
        if left_pad > 0 or right_pad > 0:
            src_item = F.pad(
                src_item.t().unsqueeze(0),
                (left_pad, right_pad),
                mode="replicate",
            ).squeeze(0).t()

        if tgt_item is not None:
            # make a tgt chunk in the range [begin_tgt, end_tgt)
            begin_tgt = tgt_start
            end_tgt = tgt_start + chunk_width  # ok if past the end of utterance
            # replication pad if necessary
            right_pad = max(0, end_tgt - tgt_item.size(0))
            tgt_item = tgt_item[begin_tgt:end_tgt]
            if right_pad > 0:
                tgt_item = torch.cat(
                    (tgt_item, tgt_item.new_full((right_pad, ), pad_idx)), 0)
        return src_item, tgt_item

    if chunk_width is None or random_chunking:
        if chunk_width is not None:  # usually for chunk-wise train data
            # no need to sort as all chunks have exactly the same length
            for s in samples:
                with data_utils.numpy_seed(seed, epoch, s["id"]):
                    # generate a chunk by sampling the index of its first label
                    f = np.random.randint(s["source"].size(0) - chunk_width +
                                          1)
                s["source"], s["target"] = chunking(s["source"], s["target"],
                                                    f)
        elif label_delay != 0:  # shift source according to label_delay
            if label_delay > 0:
                left_pad, right_pad = 0, label_delay
            else:
                left_pad, right_pad = -label_delay, 0
            for s in samples:
                src_item = s["source"]
                src_item = F.pad(
                    src_item.t().unsqueeze(0),
                    (left_pad, right_pad),
                    mode="replicate",
                ).squeeze(0).t()
                if label_delay > 0:
                    s["source"] = src_item[label_delay:]
                else:
                    s["source"] = src_item[:label_delay]

        if pad_to_length is not None or src_bucketed:
            src_lengths = torch.IntTensor(
                [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples])
        else:
            src_lengths = torch.IntTensor(
                [s["source"].size(0) for s in samples])
        id = torch.LongTensor([s["id"] for s in samples])
        utt_id = [s["utt_id"] for s in samples]
        src_frames = merge(
            "source",
            pad_to_length=pad_to_length["source"]
            if pad_to_length is not None else None,
        )

        target = None
        if samples[0].get("target", None) is not None:
            target = merge(
                "target",
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None else None,
            )
            ntokens = sum(s["target"].ne(pad_idx).int().sum().item()
                          for s in samples)
        else:
            ntokens = src_lengths.sum().item()

        text = None
        if samples[0].get("text", None) is not None:
            text = [s["text"] for s in samples]

        if chunk_width is None:  # for whole utterances (i.e., no chunking)
            # sort by descending source length
            src_lengths, sort_order = src_lengths.sort(descending=True)
            id = id.index_select(0, sort_order)
            utt_id = [utt_id[i] for i in sort_order.numpy()]
            src_frames = src_frames.index_select(0, sort_order)
            if target is not None:
                target = target.index_select(0, sort_order)
            if text is not None:
                text = [text[i] for i in sort_order.numpy()]

        batch = {
            "id": id,
            "utt_id": utt_id,
            "nsentences": len(samples),
            "ntokens": ntokens,
            "net_input": {
                "src_tokens": src_frames,
                "src_lengths": src_lengths,
            },
            "target": target,
            "text": text,
        }
        return batch
    else:  # sequential chunking, usually for chunk-wise test data
        if pad_to_length is not None or src_bucketed:
            src_lengths = torch.IntTensor(
                [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples])
        else:
            src_lengths = torch.IntTensor(
                [s["source"].size(0) for s in samples])
        id = torch.LongTensor([s["id"] for s in samples])
        utt_id = [s["utt_id"] for s in samples]
        ori_source = [s["source"] for s in samples]
        ori_target = [s["target"] for s in samples]
        text = None
        if samples[0].get("text", None) is not None:
            text = [s["text"] for s in samples]
        max_length = max(src.size(0) for src in ori_source)
        num_chunks = (max_length + chunk_width - 1) // chunk_width
        batches = []
        for k in range(num_chunks):
            f = k * chunk_width
            for i, s in enumerate(samples):
                if f < src_lengths[i].item():
                    s["source"], s["target"] = chunking(
                        ori_source[i], ori_target[i], f)
                else:
                    s["source"] = ori_source[i].new_zeros(
                        chunk_width + chunk_left_context + chunk_right_context,
                        ori_source[i].size(1))
                    s["target"] = (ori_target[i].new_full(
                        (chunk_width, ), pad_idx)
                                   if ori_target[i] is not None else None)
            src_frames = merge(
                "source",
                pad_to_length=pad_to_length["source"]
                if pad_to_length is not None else None,
            )
            src_chunk_lengths = torch.IntTensor(
                [s["source"].size(0) for s in samples])

            target = None
            if samples[0].get("target", None) is not None:
                target = merge(
                    "target",
                    pad_to_length=pad_to_length["target"]
                    if pad_to_length is not None else None,
                )
                ntokens = sum(s["target"].ne(pad_idx).int().sum().item()
                              for s in samples)
            else:
                ntokens = src_lengths.sum().item()

            batch = {
                "id": id,
                "utt_id": utt_id,
                "nsentences": len(samples) if k == 0 else 0,
                "ntokens": ntokens,
                "net_input": {
                    "src_tokens": src_frames,
                    "src_lengths": src_chunk_lengths,
                },
                "target": target,
                "text": text,
            }
            batches.append(batch)
        return batches
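The chunking helper above replication-pads at the utterance edges by treating the (frames, features) matrix as a 1-channel signal: transpose to (features, frames), pad the last dimension with mode="replicate", then transpose back. A standalone check of that reshaping:

import torch
import torch.nn.functional as F

src = torch.arange(12, dtype=torch.float32).view(4, 3)   # 4 frames, 3 features
left_pad, right_pad = 2, 1
padded = F.pad(
    src.t().unsqueeze(0),          # (1, features, frames): pad along the frame axis
    (left_pad, right_pad),
    mode="replicate",
).squeeze(0).t()
print(padded.shape)                                                       # torch.Size([7, 3])
print(torch.equal(padded[0], src[0]), torch.equal(padded[-1], src[-1]))   # True True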
Example #25
    def __getitem__(self, index: int):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert self.mask_idx not in item, \
                'Dataset contains mask_idx (={}), this is not expected!'.format(
                    self.mask_idx,
                )

            if self.mask_whole_words is not None:
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz + np.random.rand()
            )
            mask[np.random.choice(sz, num_mask, replace=False)] = True

            if self.return_masked_tokens:
                # exit early if we're just returning the masked tokens
                # (i.e., the targets for masked LM training)
                if self.mask_whole_words is not None:
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                mask = mask ^ unmask

            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if self.mask_whole_words is not None:
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)
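The leave_unmasked / random_token branching above implements the BERT-style split of selected positions into <mask>, kept-as-is, and random-replacement groups. A vectorized numpy sketch with the usual 0.15 / 0.1 / 0.1 settings, checking the resulting ~80/10/10 proportions:

import numpy as np

rng = np.random.RandomState(0)
sz = 100_000
mask_prob, leave_unmasked_prob, random_token_prob = 0.15, 0.1, 0.1

mask = np.full(sz, False)
num_mask = int(mask_prob * sz + rng.rand())
mask[rng.choice(sz, num_mask, replace=False)] = True

rand_or_unmask_prob = random_token_prob + leave_unmasked_prob
rand_or_unmask = mask & (rng.rand(sz) < rand_or_unmask_prob)
decision = rng.rand(sz) < leave_unmasked_prob / rand_or_unmask_prob
unmask = rand_or_unmask & decision            # stay unchanged
rand_mask = rand_or_unmask & ~decision        # replaced by a random token
mask = mask ^ unmask                          # still flagged for replacement

masked = mask & ~rand_mask                    # actually become <mask>
print(masked.sum() / num_mask, unmask.sum() / num_mask, rand_mask.sum() / num_mask)
# roughly 0.8, 0.1, 0.1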
Example #26
    def load_dataset(self,
                     split,
                     epoch=0,
                     combine=False,
                     data_path=None,
                     return_only=False,
                     **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def binarize(s, append_bos=False):
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s,
                append_eos=True,
                add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat(
                    [tokens.new([self.args.init_token]), tokens])
            return tokens

        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))

        src_tokens = [[] for i in range(self.args.num_classes)]
        src_lengths = [[] for i in range(self.args.num_classes)]
        labels = []

        with open(data_path) as h:
            for line in h:
                example = json.loads(line.strip())
                if 'answerKey' in example:
                    label = ord(example['answerKey']) - ord('A')
                    labels.append(label)
                question = example['question']['stem']
                assert len(
                    example['question']['choices']) == self.args.num_classes
                # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
                question = 'Q: ' + question
                question_toks = binarize(question, append_bos=True)
                for i, choice in enumerate(example['question']['choices']):
                    src = 'A: ' + choice['text']
                    src_bin = torch.cat([question_toks, binarize(src)])
                    src_tokens[i].append(src_bin)
                    src_lengths[i].append(len(src_bin))
        assert all(
            len(src_tokens[0]) == len(src_tokens[i])
            for i in range(self.args.num_classes))
        assert len(src_tokens[0]) == len(src_lengths[0])
        assert len(labels) == 0 or len(labels) == len(src_tokens[0])

        for i in range(self.args.num_classes):
            src_lengths[i] = np.array(src_lengths[i])
            src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
            src_lengths[i] = ListDataset(src_lengths[i])

        dataset = {
            'id': IdDataset(),
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens[0], reduce=True),
        }

        for i in range(self.args.num_classes):
            dataset.update({
                'net_input{}'.format(i + 1): {
                    'src_tokens':
                    RightPadDataset(
                        src_tokens[i],
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths':
                    src_lengths[i],
                }
            })

        if len(labels) > 0:
            dataset.update({'target': RawLabelDataset(labels)})

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[
                np.maximum.reduce(
                    [src_token.sizes for src_token in src_tokens])
            ],
        )

        with data_utils.numpy_seed(self.args.seed):
            dataset = SortDataset(
                dataset,
                # shuffle
                sort_order=[np.random.permutation(len(dataset))],
            )

        print('| Loaded {} with {} samples'.format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Example #27
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        print('| loaded {} blocks from: {}'.format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        if self.args.mask_whole_words:
            bpe = encoders.build_bpe(self.args)
            assert bpe is not None

            def is_beginning_of_word(i):
                if i < self.source_dictionary.nspecial:
                    # special elements are always considered beginnings
                    return True
                tok = self.source_dictionary[i]
                if tok.startswith('madeupword'):
                    return True
                try:
                    return bpe.is_beginning_of_word(tok)
                except ValueError:
                    return True

            mask_whole_words = torch.ByteTensor(
                list(
                    map(is_beginning_of_word,
                        range(len(self.source_dictionary)))))
        else:
            mask_whole_words = None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id':
                    IdDataset(),
                    'net_input': {
                        'src_tokens':
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths':
                        NumelDataset(src_dataset, reduce=False),
                    },
                    'target':
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences':
                    NumSamplesDataset(),
                    'ntokens':
                    NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
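The is_beginning_of_word helper above turns a per-token "starts a word" decision into a ByteTensor indexed by vocabulary id. A toy version over a hypothetical WordPiece-style vocabulary where "##"-prefixed tokens continue the previous word (the example itself delegates this decision to the configured BPE):

import torch

vocab = ["<s>", "<pad>", "</s>", "<unk>", "new", "##york", "city", "##s"]
nspecial = 4   # special symbols are always treated as word beginnings

def is_beginning_of_word(i):
    return i < nspecial or not vocab[i].startswith("##")

mask_whole_words = torch.ByteTensor(list(map(is_beginning_of_word, range(len(vocab)))))
print(mask_whole_words)  # tensor([1, 1, 1, 1, 1, 0, 1, 0], dtype=torch.uint8)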
Example #28
    def get_batch_iterator(
            self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
            ignore_invalid_inputs=False, required_batch_size_multiple=1,
            seed=1, num_shards=1, shard_id=0, num_workers=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        assert isinstance(dataset, FairseqDataset)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        indices = data_utils.filter_by_size(
            indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
        )

        # create mini-batches with given size constraints
        batch_sampler = data_utils.batch_by_size(
            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # return a reusable, sharded iterator
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
        )
Example #29
    def load_dataset(self, split, epoch=0, combine=False):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        paths = self.args.data.split(os.pathsep)
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        print("| data_path", data_path)

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(
                path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    ))

            print('| {} {} {} examples'.format(data_path, split_k,
                                               len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=self.args.shuffle_dataset,
            seed=self.seed,
        )
Example #30
    def get_batch_iterator(self,
                           dataset,
                           assistant=None,
                           max_tokens=None,
                           max_sentences=None,
                           max_positions=None,
                           ignore_invalid_inputs=False,
                           required_batch_size_multiple=1,
                           seed=1,
                           num_shards=1,
                           shard_id=0,
                           batch_method='sentences'):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch.
                Default: ``None``
            max_sentences (int, optional): max number of sentences in each
                batch. Default: ``None``
            max_positions (optional): max sentence length supported by the
                model. Default: ``None``
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long. Default: ``False``
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N. Default: ``1``
            seed (int, optional): seed for random number generator for
                reproducibility. Default: ``1``
            num_shards (int, optional): shard the data iterator into N
                shards. Default: ``1``
            shard_id (int, optional): which shard of the data iterator to
                return. Default: ``0``

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        assert isinstance(dataset, FairseqDataset)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        indices = data_utils.filter_by_size(
            indices,
            dataset.size,
            max_positions,
            raise_exception=(not ignore_invalid_inputs),
        )

        # create mini-batches with given size constraints
        if assistant is not None:
            assistant.associate_data(dataset, indices)
        else:
            batch_sampler = data_utils.batch_by_size(
                indices,
                dataset.num_tokens,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

        if assistant is not None:
            # return a reusable, sharded iterator
            return iterators.AssistantEpochBatchIterator(
                dataset=dataset,
                collate_fn=dataset.collater,
                assistant=assistant,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
                shard_num=num_shards,
                shard_id=shard_id,
                batch_method=batch_method,
                seed=seed,
            )

        else:
            # return a reusable, sharded iterator
            return iterators.EpochBatchIterator(
                dataset=dataset,
                collate_fn=dataset.collater,
                batch_sampler=batch_sampler,
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
            )
Example #31
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""
        def get_path(key, split):
            return os.path.join(self.args.data, key, split)

        def make_dataset(key, dictionary):
            split_path = get_path(key, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            return dataset

        input0 = make_dataset("input0", self.source_dictionary)
        assert input0 is not None, "could not find dataset: {}".format(
            get_path("input0", split))
        input1 = make_dataset("input1", self.source_dictionary)

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)

        if input1 is None:
            src_tokens = input0
        else:
            if self.args.separator_token is not None:
                input1 = PrependTokenDataset(input1, self.args.separator_token)

            src_tokens = ConcatSentencesDataset(input0, input1)

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        src_tokens = maybe_shorten_dataset(
            src_tokens,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.max_positions,
            self.args.seed,
        )

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens":
                RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths":
                NumelDataset(src_tokens, reduce=False),
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
        }

        if self.args.add_prev_output_tokens:
            prev_tokens_dataset = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )
            dataset["net_input"].update(
                prev_output_tokens=prev_tokens_dataset, )

        label_path = "{0}.npz".format(get_path("label", split))
        if os.path.exists(label_path):
            csr_matrix = load_npz(label_path)
            dataset.update(target=CSRLabelDataset(csr_matrix))

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(
            split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
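The "{split}.npz" label path above is loaded with scipy's load_npz and wrapped in a CSRLabelDataset, i.e. multi-label targets stored as a sparse CSR matrix. A small round-trip sketch of that storage convention (the file name here is illustrative):

import numpy as np
from scipy.sparse import csr_matrix, save_npz, load_npz

labels = csr_matrix(np.array([[0, 1, 0], [1, 0, 1]], dtype=np.int64))  # 2 samples, 3 classes
save_npz("train.npz", labels)
print(load_npz("train.npz").toarray())
# [[0 1 0]
#  [1 0 1]]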