Example #1
    def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
        src_dataset = PadDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                self.args['task']['tokens_per_sample'] - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode='eos',
            ),
            pad_idx=self.source_dictionary.pad(),
            left_pad=False,
        )
        src_dataset = PrependTokenDataset(src_dataset,
                                          self.source_dictionary.bos())
        src_dataset = NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': src_dataset,
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
            },
            sizes=src_lengths,
        )
        if sort:
            src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
        return src_dataset
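
The same wrapper chain can be tried in isolation. Below is a minimal, self-contained toy sketch (not taken from the repository) that assumes a stock fairseq installation and a tiny hand-built Dictionary; it shows how PadDataset, PrependTokenDataset and NestedDictionaryDataset compose, and that padding only happens when a batch is collated.

import torch
from fairseq.data import (Dictionary, IdDataset, NestedDictionaryDataset,
                          NumelDataset, PadDataset, PrependTokenDataset,
                          TokenBlockDataset)

vocab = Dictionary()
for sym in ('foo', 'bar', 'baz'):
    vocab.add_symbol(sym)

# two toy "sentences", each already terminated with </s>
src_tokens = [
    torch.tensor([vocab.index('foo'), vocab.index('bar'), vocab.eos()]),
    torch.tensor([vocab.index('baz'), vocab.eos()]),
]
src_lengths = [t.numel() for t in src_tokens]

dataset = PadDataset(
    TokenBlockDataset(src_tokens, src_lengths, block_size=None,
                      pad=vocab.pad(), eos=vocab.eos(), break_mode='eos'),
    pad_idx=vocab.pad(),
    left_pad=False,
)
dataset = PrependTokenDataset(dataset, vocab.bos())
dataset = NestedDictionaryDataset(
    {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': dataset,
            'src_lengths': NumelDataset(dataset, reduce=False),
        },
    },
    sizes=src_lengths,
)

# collate two items: padding only happens at batching time
batch = dataset.collater([dataset[0], dataset[1]])
print(batch['net_input']['src_tokens'])  # <s>-prefixed, right-padded LongTensor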
Example #2
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            path=split_path,
            dictionary=self.source_dictionary,
            dataset_impl=self.args['dataset']['dataset_impl'],
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args['task']['tokens_per_sample'] - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args['task']['sample_break_mode'],
        )
        LOGGER.info('loaded {} blocks from: {}'.format(len(dataset),
                                                       split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args['task']['mask_whole_words'] else None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args['common']['seed'],
            mask_prob=self.args['task']['mask_prob'],
            leave_unmasked_prob=self.args['task']['leave_unmasked_prob'],
            random_token_prob=self.args['task']['random_token_prob'],
            freq_weighted_replacement=self.args['task']['freq_weighted_replacement'],
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args['common']['seed'] + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src_tokens': PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths': NumelDataset(src_dataset, reduce=False),
                    },
                    'target': PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences': NumSamplesDataset(),
                    'ntokens': NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
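
To see what MaskTokensDataset.apply_mask actually produces, here is a hedged, self-contained toy sketch using the upstream fairseq versions of these classes (which the snippets above mirror); the <mask> symbol, the exaggerated mask_prob and the toy tokens are illustrative assumptions only.

import torch
from fairseq.data import Dictionary, ListDataset, MaskTokensDataset

vocab = Dictionary()
for sym in ('def', 'foo', 'return'):
    vocab.add_symbol(sym)
mask_idx = vocab.add_symbol('<mask>')  # assumed mask symbol for the toy

tokens = [torch.tensor([vocab.bos(), vocab.index('def'), vocab.index('foo'),
                        vocab.index('return'), vocab.eos()])]
dataset = ListDataset(tokens, sizes=[t.numel() for t in tokens])

src_ds, tgt_ds = MaskTokensDataset.apply_mask(
    dataset,
    vocab,
    pad_idx=vocab.pad(),
    mask_idx=mask_idx,
    seed=1,
    mask_prob=0.5,  # exaggerated so the 5-token toy sequence is visibly masked
    leave_unmasked_prob=0.1,
    random_token_prob=0.1,
)
print(src_ds[0])  # input ids; sampled positions replaced by <mask> (or kept/randomized)
print(tgt_ds[0])  # targets: original ids at masked positions, <pad> elsewhere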
Example #3
def load_masked_code_docstring_dataset_roberta(args,
                                               epoch,
                                               data_path,
                                               split,
                                               src,
                                               src_dict,
                                               tgt,
                                               tgt_dict,
                                               combine,
                                               dataset_impl,
                                               upsample_primary,
                                               left_pad_source,
                                               left_pad_target,
                                               max_source_positions,
                                               max_target_positions,
                                               prepend_bos=False,
                                               load_alignments=False,
                                               truncate_source=False,
                                               append_source_id=False):
    source_path = os.path.join(data_path, '{}.code'.format(split))
    target_path = os.path.join(data_path, '{}.docstring'.format(split))

    # source_dataset
    source_dataset = data_utils.load_indexed_dataset(source_path,
                                                     'text',
                                                     src_dict,
                                                     tokenizer=None,
                                                     dataset_impl=dataset_impl)
    if source_dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, source_path))
    # target_dataset
    target_dataset = data_utils.load_indexed_dataset(target_path,
                                                     'text',
                                                     tgt_dict,
                                                     tokenizer=None,
                                                     dataset_impl=dataset_impl)
    if target_dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, target_path))

    # concatenate code and docstring datasets
    dataset = ConcatSentencesDataset([source_dataset, target_dataset])
    # create continuous blocks of tokens
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        args['task']['tokens_per_sample'] - 1,  # one less for <s>
        pad=src_dict.pad(),
        eos=src_dict.eos(),
        break_mode=args['task']['sample_break_mode'],
    )

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, src_dict.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(args, src_dict) \
        if args['task']['mask_whole_words'] else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        src_dict,
        pad_idx=src_dict.pad(),
        mask_idx=src_dict.index(constants.T_MASK),  # self.mask_idx,
        seed=args['common']['seed'],
        mask_prob=args['task']['mask_prob'],
        leave_unmasked_prob=args['task']['leave_unmasked_prob'],
        random_token_prob=args['task']['random_token_prob'],
        freq_weighted_replacement=args['task']['freq_weighted_replacement'],
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(args['common']['seed'] + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    return SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=src_dict.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=src_dict.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
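
A hypothetical invocation sketch follows, built only from the keys this loader reads and the '{split}.code' / '{split}.docstring' paths it constructs; the directory layout, the dictionaries and every value are assumptions for illustration, not taken from the repository.

# Illustrative args dict covering only the keys read by the loader above;
# all values are assumptions for the sketch.
args = {
    'task': {
        'tokens_per_sample': 512,
        'sample_break_mode': 'complete',
        'mask_whole_words': False,
        'mask_prob': 0.15,
        'leave_unmasked_prob': 0.1,
        'random_token_prob': 0.1,
        'freq_weighted_replacement': False,
    },
    'common': {'seed': 1},
}

# `src_dict` and `tgt_dict` are assumed to be pre-built dictionaries that
# already contain the constants.T_MASK symbol; 'data-bin/python' is an
# assumed directory holding train.code / train.docstring.
dataset = load_masked_code_docstring_dataset_roberta(
    args, epoch=1, data_path='data-bin/python', split='train',
    src='code', src_dict=src_dict, tgt='docstring', tgt_dict=tgt_dict,
    combine=False, dataset_impl='mmap', upsample_primary=1,
    left_pad_source=False, left_pad_target=False,
    max_source_positions=512, max_target_positions=512,
)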
Example #4
def load_masked_traverse_dataset_roberta(
    args,
    epoch,
    data_path,
    split,
    source_dictionary,
    combine,
):
    split_path = os.path.join(data_path, '{}.ast_trav_df'.format(split))
    dataset = data_utils.load_indexed_dataset(
        path=split_path,
        dictionary=source_dictionary,
        dataset_impl=args['dataset']['dataset_impl'],
        combine=combine,
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(
            split, split_path))

    # create continuous blocks of tokens (left disabled for this dataset)
    # dataset = TokenBlockDataset(
    #     dataset,
    #     dataset.sizes,
    #     args['task']['tokens_per_sample'] - 1,  # one less for <s>
    #     pad=source_dictionary.pad(),
    #     eos=source_dictionary.eos(),
    #     break_mode=args['task']['sample_break_mode'],
    # )
    # LOGGER.info('loaded {} blocks from: {}'.format(len(dataset), split_path))

    # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
    dataset = PrependTokenDataset(dataset, source_dictionary.bos())

    # create masked input and targets
    mask_whole_words = get_whole_word_mask(args, source_dictionary) \
        if args['task']['mask_whole_words'] else None

    src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
        dataset,
        source_dictionary,
        pad_idx=source_dictionary.pad(),
        mask_idx=source_dictionary.index(constants.MASK),  # self.mask_idx,
        seed=args['common']['seed'],
        mask_prob=args['task']['mask_prob'],
        leave_unmasked_prob=args['task']['leave_unmasked_prob'],
        random_token_prob=args['task']['random_token_prob'],
        freq_weighted_replacement=args['task']['freq_weighted_replacement'],
        mask_whole_words=mask_whole_words,
    )

    with data_utils.numpy_seed(args['common']['seed'] + epoch):
        shuffle = np.random.permutation(len(src_dataset))

    return SortDataset(
        NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': PadDataset(
                        src_dataset,
                        pad_idx=source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
                'target': PadDataset(
                    tgt_dataset,
                    pad_idx=source_dictionary.pad(),
                    left_pad=False,
                ),
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_dataset, reduce=True),
            },
            sizes=[src_dataset.sizes],
        ),
        sort_order=[
            shuffle,
            src_dataset.sizes,
        ],
    )
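
One subtlety shared by all the loaders above: SortDataset does not reorder items on plain indexing; it only changes ordered_indices(), which the epoch batch iterator consults, so batches end up grouped by length with ties broken by the random shuffle. A toy sketch (fairseq assumed installed, values made up):

import numpy as np
from fairseq.data import ListDataset, SortDataset

sizes = np.array([5, 2, 9, 2])
toy = ListDataset(list(range(4)), sizes=sizes)

shuffle = np.random.RandomState(0).permutation(len(toy))
sorted_ds = SortDataset(toy, sort_order=[shuffle, sizes])

print(sorted_ds[0])                  # still item 0: indexing is unchanged
print(sorted_ds.ordered_indices())   # lexsorted order (last key is primary):
                                     # grouped by size, ties broken by shuffle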