Ejemplo n.º 1
0
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    dataset_impl,

    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    truncate_target=False,
    append_eos_to_target=False,
    portion=None,
):
    """Load ``{split}.{src}`` / ``{split}.{tgt}`` into a ``GraphLanguagePairDataset``.

    Args:
        data_path: directory containing the ``{split}.{src}`` and
            ``{split}.{tgt}`` files.
        split: split name (e.g. ``train``, ``valid``, ``test``).
        src, tgt: source / target file suffixes; ``tgt`` also names the
            ``[{tgt}]`` token used by ``append_source_id``.
        src_dict, tgt_dict: dictionaries used for loading and passed on to
            the pair dataset.
        dataset_impl: forwarded to ``_load_dataset`` as ``impl``.
        left_pad_source, left_pad_target: forwarded unchanged.
        max_source_positions, max_target_positions: forwarded unchanged and
            used as truncation lengths when truncation is requested.
        prepend_bos: prepend ``bos`` to source and target samples.
        load_alignments: accepted for signature compatibility; unused.
        truncate_source: truncate source samples to ``max_source_positions``.
        append_source_id: append ``[{tgt}]`` to target samples and pass its
            index as ``eos`` to the pair dataset.
        truncate_target: truncate target samples to ``max_target_positions``.
        append_eos_to_target: forwarded unchanged.
        portion: if not None and ``split == 'train'``, wrap both sides in
            ``PortionDataset(..., portion)``.

    Returns:
        GraphLanguagePairDataset: shuffled pair dataset built with
        ``remove_eos_from_source=True`` and no alignments.
    """
    use_portion = portion is not None and split == 'train'

    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(path=src_path, impl=dataset_impl, dict=src_dict)
    if truncate_source:
        # This flag used to be accepted but silently ignored.
        LOGGER.info('truncate {}.{} to {}'.format(split, src, max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    if use_portion:
        LOGGER.info('set {}.{} portion to {}'.format(split, src, portion))
        src_dataset = PortionDataset(src_dataset, portion)

    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict)
    # The target side is treated as optional below, so guard every use.
    if truncate_target and tgt_dataset is not None:
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt, max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    if use_portion and tgt_dataset is not None:
        LOGGER.info('set {}.{} portion to {}'.format(split, tgt, portion))
        tgt_dataset = PortionDataset(tgt_dataset, portion)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset), src_path))
    if tgt_dataset is not None:
        LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))
    return GraphLanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None, eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,
    )
Ejemplo n.º 2
0
 def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
     """Wrap raw token sequences into a dataset ready for inference.

     Args:
         src_tokens: token sequences to split into blocks, pad and wrap.
         src_lengths: per-sequence lengths; used for the blocks, as the
             nested dataset's ``sizes`` and as the sort key.
         sort: when True, wrap the result in ``SortDataset`` keyed on
             ``src_lengths``.

     Returns:
         A ``NestedDictionaryDataset`` with ``id`` and ``net_input``
         (``src_tokens`` / ``src_lengths``) entries, wrapped in a
         ``SortDataset`` when ``sort`` is True.
     """
     vocab = self.source_dictionary
     # one less for <s>, which is prepended after padding
     block_size = self.args['task']['tokens_per_sample'] - 1
     blocks = TokenBlockDataset(
         src_tokens,
         src_lengths,
         block_size,
         pad=vocab.pad(),
         eos=vocab.eos(),
         break_mode='eos',
     )
     padded = PadDataset(blocks, pad_idx=vocab.pad(), left_pad=False)
     tokens = PrependTokenDataset(padded, vocab.bos())

     inference_dataset = NestedDictionaryDataset(
         {
             'id': IdDataset(),
             'net_input': {
                 'src_tokens': tokens,
                 'src_lengths': NumelDataset(tokens, reduce=False),
             },
         },
         sizes=src_lengths,
     )
     if not sort:
         return inference_dataset
     return SortDataset(inference_dataset, sort_order=[src_lengths])
Ejemplo n.º 3
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Chooses one of the ``utils.split_paths(self.args.data)`` directories
        by epoch and loads ``<data_path>/<split>``. It then strips eos,
        re-blocks the tokens and frames each block with <s> ... </s>. The
        result is stored as a ``DenoisingDataset`` in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch; cycles through the data paths.
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

        Raises:
            FileNotFoundError: if no dataset could be loaded for the split.
        """
        args = self.args
        data_paths = utils.split_paths(args.data)
        assert len(data_paths) > 0
        chosen_path = data_paths[(epoch - 1) % len(data_paths)]
        split_path = os.path.join(chosen_path, split)

        raw = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            args.dataset_impl,
            combine=combine,
        )
        if raw is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        stripped = StripTokenDataset(raw, self.dictionary.eos())

        # Continuous blocks; two slots are kept free for the <s> and </s>
        # that are added right after.
        blocks = TokenBlockDataset(
            stripped,
            stripped.sizes,
            args.tokens_per_sample - 2,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=args.sample_break_mode,
            document_sep_len=0,
        )

        # <s> (equiv. to [CLS] in BERT) in front, </s> at the end.
        framed = AppendTokenDataset(
            PrependTokenDataset(blocks, self.source_dictionary.bos()),
            self.source_dictionary.eos(),
        )

        if args.mask_length != 'subword':
            mask_whole_words = get_whole_word_mask(args, self.source_dictionary)
        else:
            mask_whole_words = None

        self.datasets[split] = DenoisingDataset(
            framed,
            framed.sizes,
            self.dictionary,
            self.mask_idx,
            mask_whole_words,
            shuffle=args.shuffle_instance,
            seed=self.seed,
            args=args,
        )
        LOGGER.info(
            "Split: {0}, Loaded {1} samples of denoising_dataset".format(
                split, len(self.datasets[split])))
Ejemplo n.º 4
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        For each language it loads ``<data_path>/<language>/<split>.code.spm``,
        or ``<split>.docstring.spm`` for the ``docstring`` language. Each
        language dataset is truncated, re-blocked, framed with <bos> ... end
        token and wrapped in a ``DenoisingDataset``. The language datasets are
        then concatenated; for the training subset each language is first
        resampled using ``self._get_sample_prob``. The result is stored in
        ``self.datasets[split]`` behind a ``SortDataset`` seeded per epoch.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch; selects the data path and is added to
                the seed for resampling/shuffling.
            combine (bool): accepted for interface compatibility; unused here.

        Raises:
            FileNotFoundError: if a language's split could not be loaded.
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        if self.langs is None:
            # Every sub-directory of ``data_path`` is treated as a language.
            languages = sorted([
                name for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ])
        else:
            languages = self.langs

        LOGGER.info("| Training on {0} languages: {1}".format(
            len(languages), languages))
        # The mapping is formatted into the message itself. It used to be
        # passed as a stray logging argument with no placeholder, so it never
        # showed up in the log output.
        LOGGER.info("| Language to id mapping: {}".format(
            {lang: lang_id for lang_id, lang in enumerate(languages)}))

        mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
        lang_datasets = []
        for language in languages:
            # Docstrings and code live in differently named files.
            if language == 'docstring':
                split_path = os.path.join(data_path, language,
                                          f"{split}.docstring.spm")
            else:
                split_path = os.path.join(data_path, language,
                                          f"{split}.code.spm")
            dataset = load_lang_dataset_denoising(
                path=split_path,
                impl=self.args['dataset']['dataset_impl'],
                dict=self.source_dictionary)

            if dataset is None:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, split_path))

            # Strip eos, truncate to leave room for <lang>, <bos>, <eos>,
            # then re-append eos.
            dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(dataset, self.source_dictionary.eos()),
                    self.args['task']['max_source_positions'] -
                    3),  # <lang>, <bos>, <eos>
                token=self.source_dictionary.eos(),
            )

            # With ``add_lang_token`` the ``[<language>]`` token takes the place
            # of eos both in TokenBlockDataset and as the appended end token.
            end_token = self.source_dictionary.index('[{}]'.format(language)) \
                if self.args['task']['add_lang_token'] else self.source_dictionary.eos()

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args['task']['tokens_per_sample'] -
                2,  # one less for <s> and one for </s>
                pad=self.source_dictionary.pad(),
                eos=end_token,
                break_mode=self.args['task']['sample_break_mode'],
                document_sep_len=0,
            )
            LOGGER.info('| loaded {} blocks from: {}'.format(
                len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())
            dataset = AppendTokenDataset(dataset, end_token)

            lang_dataset = DenoisingDataset(
                dataset,
                dataset.sizes,
                self.dictionary,
                self.mask_idx,
                mask_whole_words,
                shuffle=self.args['dataset']['shuffle_instance'],
                seed=self.seed,
                args=self.args,
                eos=None if not self.args['task']['add_lang_token'] else
                self.source_dictionary.index('[{}]'.format(language)),
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        LOGGER.info('| loaded total {} blocks for all languages'.format(
            dataset_lengths.sum(), ))
        if split == self.args['dataset']['train_subset']:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            LOGGER.info("| Sample probability by language: {}".format({
                lang: "{0:.4f}".format(sample_probs[lang_id])
                for lang_id, lang in enumerate(languages)
            }))
            # Target share of the total divided by the current size; ratios
            # >= 1.0 are resampled with ``replace=True``.
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            LOGGER.info("| Up/Down Sampling ratio by language: {}".format({
                lang: "{0:.2f}".format(size_ratio[lang_id])
                for lang_id, lang in enumerate(languages)
            }))

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_dataset,
                    size_ratio=size_ratio[i],
                    seed=self.args['common']['seed'],
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                ) for i, lang_dataset in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets, )
        else:
            dataset = ConcatDataset(lang_datasets)
            # Per-language evaluation splits are not registered, so the split
            # keeps its own name inside ``valid_subset``.
            lang_splits = [split]

            if split in self.args['dataset']['valid_subset']:
                self.args['dataset']['valid_subset'] = self.args['dataset'][
                    'valid_subset'].replace(split, ','.join(lang_splits))

        # Deterministic per-epoch permutation passed to SortDataset together
        # with the sizes.
        with data_utils.numpy_seed(self.args['common']['seed'] + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Ejemplo n.º 5
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
    append_eos_to_target=False,
):
    """Load ``{split}.{src}`` / ``{split}.{tgt}`` into a ``LanguagePairDataset``.

    Args:
        data_path: directory containing the split files.
        split: split name (e.g. ``train``, ``valid``, ``test``).
        src, tgt: source / target file suffixes; also used to build the
            ``[{src}]`` / ``[{tgt}]`` tokens for ``append_source_id``.
        src_dict, tgt_dict: dictionaries for loading and for the pair dataset.
        dataset_impl: forwarded to ``_load_dataset`` as ``impl``.
        left_pad_source, left_pad_target: forwarded unchanged.
        max_source_positions, max_target_positions: forwarded unchanged and
            used as truncation lengths.
        prepend_bos: prepend ``bos`` to source and target samples.
        load_alignments: accepted for signature compatibility; unused.
        truncate_source: truncate source samples to ``max_source_positions``.
        append_source_id: append ``[{src}]`` / ``[{tgt}]`` to the samples and
            pass the ``[{tgt}]`` index as ``eos``.
        truncate_target: truncate target samples to ``max_target_positions``.
        append_eos_to_target: forwarded unchanged.

    Returns:
        LanguagePairDataset: shuffled, with ``remove_eos_from_source=True``.
    """
    # load source dataset
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(path=src_path,
                                impl=dataset_impl,
                                dict=src_dict)

    if truncate_source:
        # sntn => sntn[:max_source_positions]
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # load target dataset (treated as optional below, so guard every use)
    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(path=tgt_path,
                                impl=dataset_impl,
                                dict=tgt_dict)
    if truncate_target and tgt_dataset is not None:
        # sntn => sntn[:max_target_positions]
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset),
                                                     src_path))
    if tgt_dataset is not None:
        LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset),
                                                         tgt_path))
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,
    )
Ejemplo n.º 6
0
def load_masked_traverse_dataset_roberta(
    args,
    epoch,
    data_path,
    split,
    source_dictionary,
    combine,
):
    """Build a masked-LM dataset over the ``{split}.ast_trav_df`` file.

    Each sample gets a leading ``bos``. Samples are then masked through
    ``MaskTokensDataset.apply_mask``, and inputs/targets are padded into a
    ``NestedDictionaryDataset``. That dataset is wrapped in a ``SortDataset``
    whose sort keys are an epoch-seeded permutation and the sample sizes.

    Args:
        args: nested config dict; reads ``dataset.dataset_impl``,
            ``common.seed`` and the ``task`` masking options.
        epoch: added to the seed for the per-epoch permutation.
        data_path: directory containing ``{split}.ast_trav_df``.
        split: split name.
        source_dictionary: vocabulary used for loading, padding and masking.
        combine: forwarded to ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if the split could not be loaded.
    """
    split_path = os.path.join(data_path, '{}.ast_trav_df'.format(split))
    dataset = data_utils.load_indexed_dataset(
        path=split_path,
        dictionary=source_dictionary,
        dataset_impl=args['dataset']['dataset_impl'],
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    task_args = args['task']
    seed = args['common']['seed']
    pad_idx = source_dictionary.pad()

    # <s> in front of every sample (equiv. to [CLS] in BERT)
    with_bos = PrependTokenDataset(dataset, source_dictionary.bos())

    if task_args['mask_whole_words']:
        mask_whole_words = get_whole_word_mask(args, source_dictionary)
    else:
        mask_whole_words = None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        with_bos,
        source_dictionary,
        pad_idx=pad_idx,
        mask_idx=source_dictionary.index(constants.MASK),
        seed=seed,
        mask_prob=task_args['mask_prob'],
        leave_unmasked_prob=task_args['leave_unmasked_prob'],
        random_token_prob=task_args['random_token_prob'],
        freq_weighted_replacement=task_args['freq_weighted_replacement'],
        mask_whole_words=mask_whole_words,
    )

    # Permutation is reproducible for a given (seed, epoch).
    with data_utils.numpy_seed(seed + epoch):
        permutation = np.random.permutation(len(src_dataset))

    fields = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': PadDataset(src_dataset, pad_idx=pad_idx, left_pad=False),
            'src_lengths': NumelDataset(src_dataset, reduce=False),
        },
        'target': PadDataset(tgt_dataset, pad_idx=pad_idx, left_pad=False),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_dataset, reduce=True),
    }
    nested = NestedDictionaryDataset(fields, sizes=[src_dataset.sizes])
    return SortDataset(nested, sort_order=[permutation, src_dataset.sizes])
Ejemplo n.º 7
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=True,
    append_eos=True,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
):
    """Load a src/tgt split plus ``{split}.id`` into a ``BELanguagePairDataset``.

    Target samples are optionally truncated and then framed with <bos> /
    <eos>. ``max_target_positions`` is reduced beforehand by the number of
    framing tokens, and the reduced value is both the truncation length and
    the value handed to the dataset.

    Args:
        data_path: directory containing ``{split}.{src}``, ``{split}.{tgt}``
            and ``{split}.id``.
        split: split name.
        src, tgt: source / target file suffixes.
        src_dict, tgt_dict: dictionaries for loading; ``src_dict`` supplies
            the dataset's ``bos``/``eos``.
        dataset_impl: forwarded to ``_load_dataset`` as ``impl``.
        left_pad_source, left_pad_target: forwarded unchanged.
        max_source_positions: source truncation length, forwarded unchanged.
        max_target_positions: see above.
        prepend_bos, append_eos: frame target samples with bos / eos.
        load_alignments, append_source_id: accepted but unused.
        truncate_source, truncate_target: enable truncation per side.

    Returns:
        BELanguagePairDataset (not shuffled).
    """
    # Reserve room for the <bos>/<eos> that may be added to targets.
    max_target_positions -= int(prepend_bos) + int(append_eos)

    def _split_file(suffix):
        return os.path.join(data_path, '{}.{}'.format(split, suffix))

    src_path = _split_file(src)
    tgt_path = _split_file(tgt)

    src_dataset = _load_dataset(path=src_path, impl=dataset_impl, dict=src_dict)
    if truncate_source:
        LOGGER.info('truncate {}.{} to {}'.format(split, src,
                                                  max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    tgt_dataset = _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict)
    if truncate_target:
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt,
                                                  max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
    # sntn[:max_target_positions] => <bos> sntn[:max_target_positions] <eos>
    if prepend_bos:
        tgt_dataset = PrependTokenDataset(tgt_dataset, token=tgt_dict.bos())
    if append_eos:
        tgt_dataset = AppendTokenDataset(tgt_dataset, token=tgt_dict.eos())
    tgt_dataset_sizes = None if tgt_dataset is None else tgt_dataset.sizes

    tgt_ids = _load_ids(_split_file('id'))

    for loaded, path in ((src_dataset, src_path), (tgt_dataset, tgt_path)):
        LOGGER.info('loaded {} examples from: {}'.format(len(loaded), path))

    return BELanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        tgt_ids=tgt_ids,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        bos=src_dict.bos(),
        eos=src_dict.eos(),
        shuffle=False,  # NOTE(review): marked "debug" upstream; confirm before enabling
    )
Ejemplo n.º 8
0
def load_langpair_dataset(
    data_path,
    split,
    domains,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    append_eos_to_target=False,
):
    """Load a src/tgt split from several domain sub-directories.

    For each domain the shards ``split``, ``split1``, ``split2``, ... are
    loaded. With ``combine=False`` only the first shard is read. Everything
    is concatenated into one ``LanguagePairDataset``, and the first shard is
    weighted by ``upsample_primary`` when more than one shard was loaded.

    Args:
        data_path: root directory; each domain is a sub-directory.
        split: base split name.
        domains: domain sub-directory names to load.
        src, tgt: source / target file suffixes (also passed as the second
            positional argument of ``data_utils.load_indexed_dataset``).
        src_dict, tgt_dict: dictionaries for loading and the pair dataset.
        combine: keep reading numbered shards after the first one.
        dataset_impl: storage implementation for existence checks/loading.
        upsample_primary: sample ratio for the first shard when concatenating.
        left_pad_source, left_pad_target: forwarded unchanged.
        max_source_positions, max_target_positions: forwarded unchanged;
            sources are truncated to ``max_source_positions - 1`` tokens plus
            eos when ``truncate_source`` is set.
        prepend_bos: prepend ``bos`` to source and target samples.
        load_alignments: accepted but unused.
        append_source_id: append ``[{src}]`` / ``[{tgt}]`` tokens and pass
            the ``[{tgt}]`` index as ``eos``.
        append_eos_to_target: forwarded unchanged.

    Raises:
        FileNotFoundError: if the first shard of a domain is missing.
    """
    def split_exists(split, src, data_path, domain):
        filename = os.path.join(data_path, domain, '{}.{}'.format(split, src))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for domain in domains:
        for shard in itertools.count():
            shard_split = split if shard == 0 else split + str(shard)

            # Either side being present is enough; both yield the same prefix.
            present = (split_exists(shard_split, src, data_path, domain)
                       or split_exists(shard_split, tgt, data_path, domain))
            if not present:
                if shard > 0:
                    break
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, data_path))
            prefix = os.path.join(data_path, domain, '{}.'.format(shard_split))

            shard_src = data_utils.load_indexed_dataset(
                prefix + src, src, src_dict, dataset_impl)
            if truncate_source:
                # drop eos, truncate, re-append eos
                shard_src = AppendTokenDataset(
                    TruncateDataset(
                        StripTokenDataset(shard_src, src_dict.eos()),
                        max_source_positions - 1,
                    ),
                    src_dict.eos(),
                )
            src_datasets.append(shard_src)

            shard_tgt = data_utils.load_indexed_dataset(
                prefix + tgt, tgt, tgt_dict, dataset_impl)
            if shard_tgt is not None:
                tgt_datasets.append(shard_tgt)

            if not combine:
                break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if tgt_datasets else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = (ConcatDataset(tgt_datasets, sample_ratios)
                       if tgt_datasets else None)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(src_dataset,
                                         src_dict.index('[{}]'.format(src)))
        eos = tgt_dict.index('[{}]'.format(tgt))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, eos)

    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        None if tgt_dataset is None else tgt_dataset.sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,  # TODO debug: shuffle=False
    )
Ejemplo n.º 9
0
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    dataset_impl,
    # combine, dataset_impl, upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=True,
    append_eos=True,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    truncate_target=False,
    # lifelong learning
    prev_tasks=None,
    cur_task=None,
    sample_portion=None,
):
    """Load a src/tgt split for ``cur_task``, optionally replaying earlier tasks.

    The current task's ``{split}.{src}`` / ``{split}.{tgt}`` files are read
    from ``data_path/cur_task``. If ``prev_tasks`` is non-empty and both
    ``cur_task`` and ``sample_portion`` are set, each previous task adds
    ``SliceDataset(..., end=n)`` to both sides, where
    ``n = int(len(current_src) * sample_portion // len(prev_tasks))``.

    Args:
        data_path: root directory; tasks are sub-directories.
        split: split name.
        src, tgt: source / target file suffixes.
        src_dict, tgt_dict: dictionaries for loading; ``src_dict`` supplies
            the dataset's ``bos``/``eos``.
        dataset_impl: forwarded to ``_load_dataset``.
        left_pad_source, left_pad_target: forwarded unchanged.
        max_source_positions: source truncation length, forwarded unchanged.
        max_target_positions: reduced by the number of framing tokens
            (``prepend_bos``/``append_eos``), then used for truncation and
            forwarded.
        prepend_bos, append_eos: frame target samples with bos / eos.
        load_alignments, append_source_id: accepted but unused.
        truncate_source, truncate_target: enable truncation per side.
        prev_tasks: names of previously learned tasks (sub-directories of
            ``data_path``). ``None`` means no previous tasks.
        cur_task: sub-directory of the current task. When None, the split
            files are read from ``data_path`` itself and nothing is replayed.
        sample_portion: fraction of the current task's size to replay, split
            evenly across ``prev_tasks``.

    Returns:
        BELanguagePairDataset, shuffled only for the ``train`` split.
    """
    # ``None`` instead of a shared mutable ``[]`` default.
    if prev_tasks is None:
        prev_tasks = []
    # truncate sentence for prepend <bos> and append <eos>
    max_target_positions -= int(prepend_bos) + int(append_eos)
    # ``os.path.join(data_path, None, ...)`` used to raise TypeError here.
    task_path = data_path if cur_task is None else os.path.join(data_path, cur_task)

    # load source dataset of the current task
    src_path = os.path.join(task_path, '{}.{}'.format(split, src))
    cur_src_dataset = _load_dataset(path=src_path, impl=dataset_impl, dict=src_dict)
    # number of samples replayed from every previous task (-1: no replay)
    if len(prev_tasks) > 0 and cur_task is not None and sample_portion is not None:
        sample_size_per_task = int(
            len(cur_src_dataset) * sample_portion // len(prev_tasks))
    else:
        sample_size_per_task = -1

    def _with_previous_tasks(cur_dataset, lang, dictionary):
        # Concatenate ``cur_dataset`` with a slice of each previous task's data.
        parts = [cur_dataset]
        if sample_size_per_task > 0:
            for p_task in prev_tasks:
                p_path = os.path.join(data_path, p_task,
                                      '{}.{}'.format(split, lang))
                p_dataset = _load_dataset(p_path, dataset_impl, dictionary)
                parts.append(SliceDataset(p_dataset, end=sample_size_per_task))
        return ConcatDataset(parts)

    src_dataset = _with_previous_tasks(cur_src_dataset, src, src_dict)
    if truncate_source:
        # sntn => sntn[:max_source_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, src,
                                                  max_source_positions))
        src_dataset = TruncateDataset(src_dataset, max_source_positions)

    # load target dataset
    tgt_path = os.path.join(task_path, '{}.{}'.format(split, tgt))
    cur_tgt_dataset = _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict)
    tgt_dataset = _with_previous_tasks(cur_tgt_dataset, tgt, tgt_dict)
    if truncate_target:
        # sntn => sntn[:max_target_positions]
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt,
                                                  max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
    # sntn[:max_target_positions] => <bos> sntn[:max_target_positions] <eos>
    if prepend_bos:
        tgt_dataset = PrependTokenDataset(tgt_dataset, token=tgt_dict.bos())
    if append_eos:
        tgt_dataset = AppendTokenDataset(tgt_dataset, token=tgt_dict.eos())
    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    assert len(src_dataset) == len(tgt_dataset), (len(src_dataset),
                                                  len(tgt_dataset))
    LOGGER.info('loaded {} examples from: [{}](current task) + {}(previous tasks)'.
                format(len(src_dataset), cur_task, prev_tasks))
    return BELanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None,
        bos=src_dict.bos(),
        eos=src_dict.eos(),
        shuffle=(split == 'train'),
    )
Ejemplo n.º 10
0
def load_langpair_dataset(
    args,
    programming_langs,
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    is_distill=False,
):
    """Load one src/tgt split per programming language into a ``UniversalDataset``.

    For each language in ``programming_langs`` the files
    ``<data_path>/<lang>/<split>.<src>`` and ``.<tgt>`` are loaded and then
    concatenated; the first language gets sample ratio ``upsample_primary``.
    For the ``train`` split with ``is_distill`` set, the teacher top-k
    index/probability datasets and each language's expert BLEU score
    (``expert_bleu_<lang>_<src>_<tgt>.json``) are loaded too.

    Args:
        args: forwarded to ``UniversalDataset``.
        programming_langs: language sub-directories of ``data_path``; also
            passed as ``dataset_names``.
        data_path: root data directory.
        split: split name; only ``train`` loads distillation data.
        src, tgt: source / target file suffixes.
        src_dict, tgt_dict: dictionaries for loading and the dataset.
        dataset_impl: storage implementation for existence checks/loading.
        upsample_primary: sample ratio for the first language.
        left_pad_source, left_pad_target: forwarded unchanged.
        max_source_positions, max_target_positions: forwarded unchanged;
            sources are truncated to ``max_source_positions - 1`` tokens plus
            eos when ``truncate_source`` is set.
        prepend_bos: prepend ``bos`` to source and target samples.
        combine, load_alignments, append_source_id: accepted for signature
            compatibility; unused here.
        is_distill: load teacher outputs and expert scores (train only).

    Returns:
        UniversalDataset

    Raises:
        NotImplementedError: if neither ``<split>.<src>`` nor
            ``<split>.<tgt>`` exists for a language.
    """
    def split_exists(split, src, data_path):
        filename = os.path.join(data_path, '{}.{}'.format(split, src))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    topk_idxs = []
    topk_probs = []
    expert_scores = []
    dataset_ids = []
    # Cumulative per-language sample counts, taken before any
    # ``upsample_primary`` weighting is applied.
    lng_borders = [0]
    is_train = split == 'train'

    for ds_idx, program_lang in enumerate(programming_langs):
        lang_data_path = os.path.join(data_path, program_lang)

        # Either side being present is enough; both yield the same prefix.
        if (split_exists(split, src, lang_data_path)
                or split_exists(split, tgt, lang_data_path)):
            prefix = os.path.join(lang_data_path, '{}.'.format(split))
        else:
            raise NotImplementedError('No data in {}'.format(lang_data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src,
                                                      src_dict, dataset_impl)
        if truncate_source:
            # drop eos, truncate, re-append eos
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)
        length = len(src_dataset)
        lng_borders.append(lng_borders[-1] + length)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt,
                                                      tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        # one language index per source sample
        dataset_ids.extend([ds_idx] * length)

        if is_distill and is_train:  # distill only for train
            path = '{}_{}_{}_topk_idx'.format(lang_data_path, src, tgt)
            topk_idxs.append(TeacherOutputDataset(path))
            path = '{}_{}_{}_topk_prob'.format(lang_data_path, src, tgt)
            topk_probs.append(TeacherOutputDataset(path))
            expert_bleu_path = os.path.join(
                data_path,
                'expert_bleu_{}_{}_{}.json'.format(program_lang, src, tgt))
            # Close the file deterministically (it used to be leaked).
            with open(expert_bleu_path, encoding='utf-8') as bleu_file:
                expert_bleu = json.load(bleu_file)
            expert_scores.append(expert_bleu[f"bleu_{program_lang}"])

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    sample_ratios = [1] * len(src_datasets)
    sample_ratios[0] = upsample_primary
    src_dataset = ConcatDataset(src_datasets, sample_ratios)
    if len(tgt_datasets) > 0:
        tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
    else:
        tgt_dataset = None

    # ``len(None)`` used to crash here when no target files were loaded.
    LOGGER.info('src data: {}, tgt data: {}'.format(
        len(src_dataset),
        len(tgt_dataset) if tgt_dataset is not None else None))

    if is_distill and is_train:  # distill only for train
        topk_idx_dataset = ConcatDataset(topk_idxs)
        topk_probs_dataset = ConcatDataset(topk_probs)
        assert len(topk_probs_dataset) == len(src_dataset), (
            len(topk_probs_dataset), len(src_dataset))
        assert len(topk_idx_dataset) == len(src_dataset)
    else:
        topk_idx_dataset = None
        topk_probs_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(
            tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    return UniversalDataset(
        args,
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        dataset_ids=dataset_ids,
        lng_borders=lng_borders,
        dataset_names=programming_langs,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        topk_idxs=topk_idx_dataset,
        topk_probs=topk_probs_dataset,
        expert_scores=expert_scores,
        is_train=is_train,
    )
def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    load_alignments=False,
    truncate_source=False,
    truncate_target=False,
):
    """Load a source/target dataset pair for ``split`` as a ``LanguagePairDataset``.

    Files are looked up as ``<data_path>/<split>.<src>`` and
    ``<data_path>/<split>.<tgt>``. When ``combine`` is true, further shards
    named ``<split>1``, ``<split>2``, ... are loaded until one is missing and
    all shards are concatenated, with the first shard given the sample ratio
    ``upsample_primary``.

    Args:
        data_path: directory containing the binarized split files.
        split: split name used as the file-name prefix (e.g. ``'train'``).
        src: source-side file suffix.
        src_dict: dictionary passed to the indexed-dataset loader for ``src``.
        tgt: target-side file suffix.
        tgt_dict: dictionary passed to the indexed-dataset loader for ``tgt``;
            its ``eos()``/``bos()`` are used when ``truncate_target`` is set.
        combine: if true, also load the numbered ``<split>k`` shards.
        dataset_impl: indexed-dataset implementation name.
        upsample_primary: sample ratio of the first shard when concatenating.
        left_pad_source, left_pad_target: forwarded to ``LanguagePairDataset``.
        max_source_positions, max_target_positions: forwarded to
            ``LanguagePairDataset`` and used as truncation limits.
        load_alignments: accepted for interface compatibility; not used here.
        truncate_source: truncate source sequences to ``max_source_positions``.
        truncate_target: truncate target sequences to
            ``max_target_positions - 2`` and re-wrap them with EOS and BOS.

    Returns:
        A ``LanguagePairDataset``. Its target side is ``None`` when the loader
        returned ``None`` for every target shard.

    Raises:
        FileNotFoundError: if neither ``<split>.<src>`` nor ``<split>.<tgt>``
            exists in ``data_path``.
    """

    def split_exists(split, lang, data_path):
        filename = os.path.join(data_path, '{}.{}'.format(split, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # Either the source or the target file existing is enough to treat
        # this shard as present; both sides share the same file prefix.
        if split_exists(split_k, src, data_path) or split_exists(
                split_k, tgt, data_path):
            prefix = os.path.join(data_path, '{}.'.format(split_k))
        elif k > 0:
            break  # ran out of numbered shards
        else:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, 'text',
                                                      src_dict, dataset_impl)
        if truncate_source and max_source_positions:
            src_dataset = TruncateDataset(src_dataset, max_source_positions)
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, 'text',
                                                      tgt_dict, dataset_impl)
        # The target side is optional: only wrap/collect it when it was
        # actually loaded (wrapping ``None`` in TruncateDataset is presumably
        # unsupported).
        if tgt_dataset is not None:
            if truncate_target and max_target_positions:
                # Reserve two positions so the re-added EOS and BOS still fit
                # within max_target_positions.
                tgt_dataset = PrependTokenDataset(
                    AppendTokenDataset(
                        TruncateDataset(tgt_dataset, max_target_positions - 2),
                        token=tgt_dict.eos(),
                    ),
                    tgt_dict.bos(),
                )
            tgt_datasets.append(tgt_dataset)

        if not combine:
            break

    # Either every shard has a target side or none does.
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        shuffle=True,  # TODO debug: shuffle=False
    )