def make_dataset(type, dictionary):
    """Load the indexed dataset for ``type`` in the current ``split``.

    Relies on ``get_path``, ``split``, ``combine`` and ``self`` from the
    enclosing scope (this is a nested helper of a task's ``load_dataset``).

    Args:
        type: key passed to ``get_path`` to build the dataset path (shadows
            the builtin; name kept for caller compatibility).
        dictionary: vocabulary used to load the dataset. Previously this
            argument was ignored and ``self.source_dictionary`` was always used.

    Returns:
        The dataset returned by ``data_utils.load_indexed_dataset``.

    Raises:
        FileNotFoundError: if no dataset exists at the computed path.
    """
    split_path = get_path(type, split)

    dataset = data_utils.load_indexed_dataset(
        split_path,
        dictionary,
        self.args.dataset_impl,
        combine=combine,
    )
    # load_indexed_dataset returns None for a missing path; fail loudly here
    # rather than later in whichever wrapper first touches the dataset.
    if dataset is None:
        raise FileNotFoundError(
            "Dataset not found: {} ({})".format(split, split_path)
        )
    return dataset
Beispiel #2
0
def load_decode_data(path, mydict):
    """Load an mmap indexed dataset at ``path`` and prepend BOS to each item.

    Args:
        path: path prefix of the mmap indexed dataset.
        mydict: dictionary used to load the data; its ``bos()`` index is
            prepended to every item.

    Returns:
        A ``PrependTokenDataset`` wrapping the loaded dataset.

    Raises:
        FileNotFoundError: if no dataset exists at ``path``.
    """
    dataset = data_utils.load_indexed_dataset(
        path,
        mydict,
        'mmap',
        combine=False,
    )
    # Without this check a missing file only surfaces later as an
    # AttributeError inside PrependTokenDataset.
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {}'.format(path))
    return PrependTokenDataset(dataset, mydict.bos())
Beispiel #3
0
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Build the denoising dataset for ``split`` and store it in ``self.datasets``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): selects one of the ``:``-separated data directories,
                cycling through them (0-based).
            combine (bool): forwarded to ``load_indexed_dataset``.

        Raises:
            FileNotFoundError: if the split is missing from the chosen directory.
        """
        data_dirs = self.args.data.split(':')
        assert len(data_dirs) > 0
        chosen_dir = data_dirs[epoch % len(data_dirs)]
        split_path = os.path.join(chosen_dir, split)

        ds = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine
        )
        if ds is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        eos_idx = self.dictionary.eos()
        ds = StripTokenDataset(ds, eos_idx)

        # Contiguous token blocks; two positions are reserved for the <s> and
        # </s> tokens added right after.
        ds = TokenBlockDataset(
            ds,
            ds.sizes,
            self.args.tokens_per_sample - 2,
            pad=self.dictionary.pad(),
            eos=eos_idx,
            break_mode=self.args.sample_break_mode,
            document_sep_len=0,
        )

        # <s> (equivalent to [CLS] in BERT) in front, </s> at the end.
        ds = AppendTokenDataset(
            PrependTokenDataset(ds, self.source_dictionary.bos()),
            self.source_dictionary.eos(),
        )

        if self.args.mask_length == 'subword':
            mask_whole_words = None
        else:
            mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary)

        self.datasets[split] = DenoisingDataset(
            ds, ds.sizes, self.dictionary, self.mask_idx, mask_whole_words,
            shuffle=self.args.shuffle_instance,
            seed=self.seed,
            args=self.args,
        )
        print("| Split: {0}, Loaded {1} samples of denoising_dataset".format(
            split, len(self.datasets[split])))
Beispiel #4
0
def load_text_annotations(path, prefix):
    """Load a text indexed dataset together with its numpy annotation array.

    Args:
        path: directory containing the files.
        prefix: shared file prefix. ``<prefix>.text`` is loaded as an mmap
            indexed dataset (no dictionary) and ``<prefix>.annotations.npy``
            is loaded with ``np.load``.

    Returns:
        ``(text_data, annotation_data)``.

    Raises:
        FileNotFoundError: if the text dataset is missing. ``np.load`` raises
            it itself for a missing ``.npy`` file.
    """
    text_path = os.path.join(path, prefix + '.text')
    text_data = load_indexed_dataset(
        text_path,
        None,
        dataset_impl='mmap',
    )
    # Explicit raise instead of assert so the check survives ``python -O``.
    if text_data is None:
        raise FileNotFoundError('Dataset not found: {}'.format(text_path))

    # np.load raises on a missing file rather than returning None, so the
    # former ``assert annotation_data is not None`` could never fire.
    annotation_data = np.load(os.path.join(path, prefix + '.annotations.npy'))
    return text_data, annotation_data
Beispiel #5
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load ``split`` as language-modeling token blocks into ``self.datasets``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch; picks the data directory round-robin.
            combine (bool): forwarded to ``load_indexed_dataset``.

        Raises:
            FileNotFoundError: if the split cannot be found.
        """
        data_dirs = utils.split_paths(self.args.data)
        assert len(data_dirs) > 0

        # Epochs are 1-based here, hence the shift before the modulus.
        split_path = os.path.join(data_dirs[(epoch - 1) % len(data_dirs)], split)

        raw = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine
        )
        if raw is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        args = self.args
        shortened = maybe_shorten_dataset(
            raw,
            split,
            args.shorten_data_split_list,
            args.shorten_method,
            args.tokens_per_sample,
            args.seed,
        )
        blocks = TokenBlockDataset(
            shortened,
            shortened.sizes,
            args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=args.sample_break_mode,
            include_targets=True,
        )

        # EOS for other targets is enabled unless the break mode is unset or "none".
        mode = args.sample_break_mode
        add_eos = mode is not None and mode != "none"

        self.datasets[split] = self._initialize_dataset(
            dataset=blocks,
            sizes=blocks.sizes,
            src_vocab=self.dictionary,
            tgt_vocab=self.output_dictionary,
            add_eos_for_other_targets=add_eos,
            shuffle=True,
            targets=self.targets,
            add_bos_token=args.add_bos_token,
        )
Beispiel #6
0
    def load_dataset(self, split, shuffle=True, **kwargs):
        """Load a dataset split.

        Locates ``{split}.{src}-{tgt}.{src}`` (trying both language orders)
        under ``self.args.data``, loads only the source side, and stores a
        ``SentsPermAndPredictMaskDataset`` in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test).
            shuffle (bool): forwarded to ``SentsPermAndPredictMaskDataset``.

        Raises:
            FileNotFoundError: if the split files or the source dataset are
                missing.
            ValueError, SyntaxError: if ``args.mask_other_sents`` is not a
                Python literal such as ``"True"``/``"False"``.
        """
        import ast

        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(
                data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            return indexed_dataset.dataset_exists(filename,
                                                  impl=self.args.dataset_impl)

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang
        if split_exists(split, src, tgt, src, self.args.data):
            prefix = os.path.join(self.args.data,
                                  '{}.{}-{}.'.format(split, src, tgt))
        elif split_exists(split, tgt, src, src, self.args.data):
            prefix = os.path.join(self.args.data,
                                  '{}.{}-{}.'.format(split, tgt, src))
        else:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, self.args.data))
        src_dataset = data_utils.load_indexed_dataset(prefix + src,
                                                      self.src_dict,
                                                      self.args.dataset_impl)
        # Fail with a clear message instead of an AttributeError on
        # ``src_dataset.sizes`` below.
        if src_dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, prefix + src))

        rng = np.random.RandomState(self.args.seed)

        # literal_eval accepts the same "True"/"False" strings that eval did,
        # but cannot execute arbitrary expressions passed on the command line.
        mask_other_sents = ast.literal_eval(self.args.mask_other_sents)

        # Only the source side is loaded; the two positional slots after
        # ``self.src_dict`` are passed as None.
        # need to be updated with extractive summarization dataset
        self.datasets[split] = SentsPermAndPredictMaskDataset(
            src_dataset,
            src_dataset.sizes,
            self.src_dict,
            None,
            None,
            self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            shuffle=shuffle,
            # Keyword spelling ("poniter") must match the dataset's signature.
            is_poniter_net=(self.args.predict_arch == 'pointer_net'),
            max_sent_len=self.args.max_sent_length,
            max_doc_len=self.args.max_doc_length,
            masked_sent_prob=self.args.masked_sent_prob,
            max_predictions_per_doc=self.args.max_predictions_per_doc,
            rng=rng,
            shuffle_prob=self.args.shuffle_prob,
            doc_sizes=None,
            mask_other_sents=mask_other_sents,
            max_tokens_len=self.args.max_roberta_position,
            fix_ratio=self.args.fix_ratio,
            bert_model=self.args.roberta_model,
        )
Beispiel #7
0
        def make_dataset(type, dictionary):
            """Load the ``type`` dataset of the current split with ``dictionary``.

            Uses ``get_path``, ``split``, ``combine`` and ``self`` from the
            enclosing scope.

            Raises:
                FileNotFoundError: if no dataset exists at the computed path.
            """
            split_path = get_path(type, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            # An explicit raise (not assert) still fires under ``python -O``.
            # The already-computed path is reused instead of calling get_path again.
            if dataset is None:
                raise FileNotFoundError(
                    "could not find dataset: {}".format(split_path))
            return dataset
Beispiel #8
0
    def __init__(
        self,
        data_dir,
        split,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        """Audio dataset whose file names are stored as a binarized indexed dataset.

        Files read from ``data_dir``:
          * ``dict.txt`` -- dictionary used to load the binarized file names;
          * ``{split}.root`` (optional) -- its first line becomes ``self.root_dir``;
          * ``{split}`` -- indexed dataset of file names;
          * ``{split}.lengths`` -- one integer sample size per line.

        Args:
            data_dir: directory containing the files listed above.
            split: split name used to build the file names.
            sample_rate, max_sample_size, min_sample_size, shuffle, pad,
            normalize, compute_mask_indices, **mask_compute_kwargs: forwarded
                to the parent ``__init__``.
            num_buckets: passed to ``self.set_bucket_info``.

        Raises:
            AssertionError: if any size in ``{split}.lengths`` is below
                ``min_sample_size``.
        """
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        # Imported at call time rather than at module level.
        from fairseq.data import data_utils, Dictionary

        # Dictionary passed to load_indexed_dataset for the file-name dataset.
        self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))

        # Optional "<split>.root": its first line (stripped) becomes root_dir.
        # NOTE(review): an empty .root file makes next(f) raise StopIteration.
        root_path = os.path.join(data_dir, f"{split}.root")
        if os.path.exists(root_path):
            with open(root_path, "r") as f:
                self.root_dir = next(f).strip()
        else:
            self.root_dir = None

        # NOTE(review): load_indexed_dataset can return None (other loaders in
        # this file check for that); it is not checked here.
        fnames_path = os.path.join(data_dir, split)
        self.fnames = data_utils.load_indexed_dataset(fnames_path,
                                                      self.fnames_dict)
        lengths_path = os.path.join(data_dir, f"{split}.lengths")

        # One integer size per line, appended to self.sizes.
        # NOTE(review): assumes the parent __init__ initialised self.sizes as
        # a list -- confirm against the base class.
        with open(lengths_path, "r") as f:
            for line in f:
                sz = int(line.rstrip())
                assert (
                    sz >= min_sample_size
                ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
                self.sizes.append(sz)

        self.sizes = np.array(self.sizes, dtype=np.int64)

        self.set_bucket_info(num_buckets)
        logger.info(f"loaded {len(self.fnames)} samples")
Beispiel #9
0
def load_mask_data(path, mydict, block_size=512):
    """Load an mmap indexed dataset as BOS-prefixed token blocks.

    The data is grouped by ``TokenBlockDataset`` in ``'complete_doc'`` break
    mode and ``mydict.bos()`` is prepended to every block.

    Args:
        path: path prefix of the mmap indexed dataset.
        mydict: dictionary used to load the data and supplying the pad, eos
            and bos indices.
        block_size: block size including the prepended BOS token. The default
            of 512 is the previously hard-coded value.

    Returns:
        A ``PrependTokenDataset`` wrapping the token blocks.

    Raises:
        FileNotFoundError: if no dataset exists at ``path``.
    """
    dataset = data_utils.load_indexed_dataset(path, mydict, 'mmap', combine=False)
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {}'.format(path))
    # One position is reserved for the BOS token prepended below.
    dataset = TokenBlockDataset(
        dataset,
        dataset.sizes,
        block_size - 1,
        pad=mydict.pad(),
        eos=mydict.eos(),
        break_mode='complete_doc',
    )
    return PrependTokenDataset(dataset, mydict.bos())
    def load_dataset(self, split, shuffle=True):
        """Load a dataset split.

        Finds ``{split}.{src}-{tgt}.*`` (trying both language orders) under
        ``self.args.data``, loads the source and target indexed datasets, and
        stores an ``ExtractSumRobertaLongDataset`` in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test).
            shuffle (bool): forwarded to the wrapping dataset.

        Raises:
            FileNotFoundError: if the split files or the source dataset are
                missing.
            ValueError, SyntaxError: if ``args.mask_other_sents`` is not a
                Python literal such as ``"True"``/``"False"``.
        """
        import ast

        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang
        if split_exists(split, src, tgt, src, self.args.data):
            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
        elif split_exists(split, tgt, src, src, self.args.data):
            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
        else:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, self.src_dict, self.args.dataset_impl)
        if src_dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, prefix + src))

        # The target side may be absent; its sizes are passed only if it loaded.
        # ``is not None`` avoids treating an empty (len 0) dataset as missing.
        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, self.tgt_dict, self.args.dataset_impl)
        tgt_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

        # literal_eval accepts the same "True"/"False" strings that eval did,
        # but cannot execute arbitrary expressions passed on the command line.
        mask_other_sents = ast.literal_eval(self.args.mask_other_sents)

        # need to be updated with extractive summarization dataset
        self.datasets[split] = ExtractSumRobertaLongDataset(
            src_dataset, src_dataset.sizes, self.src_dict,
            tgt_dataset, tgt_sizes, self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            shuffle=shuffle,
            max_sent_len=self.args.max_sent_length,
            max_doc_len=self.args.max_doc_length,
            mask_other_sents=mask_other_sents,
        )
Beispiel #11
0
        def make_dataset(type, dictionary):
            """Load the ``type`` dataset of the current split with ``dictionary``.

            Relies on ``get_path``, ``split``, ``combine`` and ``self`` from
            the enclosing scope.

            Raises:
                FileNotFoundError: when nothing exists at the computed path.
            """
            path = get_path(type, split)
            loaded = data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl, combine=combine
            )
            if loaded is not None:
                return loaded
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, path))
Beispiel #12
0
    def load_text_object_dataset(self, split, **kwargs):
        """Pair the binarized text of ``split`` with its object features.

        Stores a ``TextObjectDataset`` in ``self.datasets[split]``. The shuffle
        flag is set only for the ``"train"`` split.
        """
        objects_dataset = ObjectDataset(self.args.data_dir, split, max_obj=self.args.max_obj)
        span_idxs = self.item2span_idxs(sent_num=objects_dataset.sent_num,
                                        max_src_sent=self.args.max_src_sent)

        text_file = text_bin_file(self.args.data_dir, split)
        text_dataset = data_utils.load_indexed_dataset(text_file, self.vocab_dict)

        self.datasets[split] = TextObjectDataset(text_dataset=text_dataset,
                                                 image_dataset=objects_dataset,
                                                 vocab_dict=self.vocab_dict,
                                                 span_idxs=span_idxs,
                                                 shuffle=(split == "train"))
Beispiel #13
0
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load ``split`` as masked token blocks and store it in ``self.datasets``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 0-based; picks the data directory round-robin.
            combine (bool): forwarded to ``load_indexed_dataset``.

        Raises:
            FileNotFoundError: if the split cannot be found.
        """
        data_dirs = utils.split_paths(self.args.data)
        assert len(data_dirs) > 0
        split_path = os.path.join(data_dirs[epoch % len(data_dirs)], split)

        tokens = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine,
        )
        if tokens is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        opts = self.args
        tokens = maybe_shorten_dataset(
            tokens,
            split,
            opts.shorten_data_split_list,
            opts.shorten_method,
            opts.tokens_per_sample,
            opts.seed,
        )

        # Contiguous blocks sized by tokens_per_sample (the original author
        # noted 511 or 512 in practice; this depends on the config).
        src_vocab = self.source_dictionary
        blocks = TokenBlockDataset(
            tokens,
            tokens.sizes,
            opts.tokens_per_sample,
            pad=src_vocab.pad(),
            eos=src_vocab.eos(),
            break_mode=opts.sample_break_mode,
        )
        logger.info('loaded {} blocks from: {}'.format(len(blocks), split_path))

        self.datasets[split] = MaskedLanguagePairDataset.apply_mask(
            blocks,
            blocks.sizes,
            src_vocab,
            shuffle=True,
            mask_prob=opts.mask_prob,
            leave_unmasked_prob=opts.leave_unmasked_prob,
            random_token_prob=opts.random_token_prob,
        )
    def load_denoise_dataset(self, data_path: str,
                             lang: str) -> FairseqDataset:
        """Classic denoising dataset.

        The same indexed data at ``data_path`` is loaded twice. One copy is
        passed through ``NoisingDataset`` and prefixed with the ``lang`` token
        to form the source. The other, untouched copy is the target. The pair
        then gets the language BOS prepended on the target side.
        """
        impl = self.args.dataset_impl
        noisy = NoisingDataset(
            data_utils.load_indexed_dataset(data_path, self.common_dict, impl),
            self.dictionary,
            seed=1,
            max_word_shuffle_distance=self.args.max_word_shuffle_distance,
            word_dropout_prob=self.args.word_dropout_prob,
            word_blanking_prob=self.args.word_blanking_prob,
        )
        # Tag the noisy side with the language token.
        noisy = PrependTokenDataset(noisy, _lang_token_index(self.dictionary, lang))

        # A second, independent load of the same data serves as the clean side.
        clean = data_utils.load_indexed_dataset(data_path, self.common_dict, impl)
        pair = self._langpair_dataset(noisy, clean)
        return self._prepend_lang_bos_to_target(pair, lang)
Beispiel #15
0
    def _load_single_lang_dataset(self, split, epoch):
        """Load ``split`` plus any numbered shards (``split1``, ``split2``, ...).

        Each shard is chunked into ``TokenBlockDataset`` blocks of
        ``tokens_per_sample - 1`` tokens. Loading stops at the first missing
        numbered shard.

        Returns:
            ``(dataset, sizes)``: a single block dataset or a ``ConcatDataset``
            of them, plus the matching size array.

        Raises:
            FileNotFoundError: if the unnumbered ``split`` itself is missing.
        """
        data_dirs = self.args.data.split(":")
        assert len(data_dirs) > 0
        data_path = data_dirs[epoch % len(data_dirs)]

        shards = []
        shard_idx = 0
        while True:
            shard_name = split if shard_idx == 0 else split + str(shard_idx)
            ds = data_utils.load_indexed_dataset(
                os.path.join(data_path, shard_name),
                self.dictionary,
                self.args.dataset_impl,
            )
            if ds is None:
                if shard_idx == 0:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )
                break

            # A classification token is appended to every block later, so
            # each block holds tokens_per_sample - 1 tokens.
            block_ds = TokenBlockDataset(
                ds,
                ds.sizes,
                self.args.tokens_per_sample - 1,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
            )
            shards.append(block_ds)
            print("| {} {} {} examples".format(data_path, shard_name, len(block_ds)))
            shard_idx += 1

        if len(shards) == 1:
            only = shards[0]
            return only, only.sizes
        return ConcatDataset(shards), np.concatenate([s.sizes for s in shards])
def load_binarized_bin(dict, input):
    """Decode every item of a binarized indexed dataset into text.

    Args:
        dict: path to a ``Dictionary`` file, or ``None`` to emit raw token ids
            (the name shadows the builtin; kept for keyword callers).
        input: path prefix of the indexed dataset (also shadows a builtin).

    Returns:
        The decoded lines joined with newlines, or an empty string for an
        empty dataset.

    Raises:
        FileNotFoundError: if no dataset exists at ``input``.
    """
    dictionary = Dictionary.load(dict) if dict is not None else None
    dataset = data_utils.load_indexed_dataset(
        input,
        dictionary,
        default='lazy',
    )
    if dataset is None:
        raise FileNotFoundError('Dataset not found: {}'.format(input))

    # Collect every line. The old loop overwrote ``line`` on each iteration,
    # returned only the last one, and hit UnboundLocalError on empty input.
    lines = []
    for tensor_line in dataset:
        if dictionary is None:
            lines.append(' '.join(str(int(x)) for x in tensor_line))
        else:
            lines.append(dictionary.string(tensor_line))
    return '\n'.join(lines)
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load ``split`` as a ``MonolingualDataset`` of token blocks.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 0-based; picks one of the ``os.pathsep``-separated
                data directories round-robin.
            combine (bool): forwarded to ``load_indexed_dataset``.

        Raises:
            FileNotFoundError: if the split cannot be found.
        """
        data_dirs = self.args.data.split(os.pathsep)
        assert len(data_dirs) > 0
        split_path = os.path.join(data_dirs[epoch % len(data_dirs)], split)

        ds = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine
        )
        if ds is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        opts = self.args
        if opts.truncate_sequence:
            ds = TruncateDataset(ds, opts.tokens_per_sample)

        ds = TokenBlockDataset(
            ds,
            ds.sizes,
            opts.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=opts.sample_break_mode,
            include_targets=True,
        )

        mode = opts.sample_break_mode
        # Shuffling is off only when an ``lm_eval`` attribute exists and is truthy.
        lm_eval = getattr(opts, 'lm_eval', False)

        self.datasets[split] = MonolingualDataset(
            ds,
            ds.sizes,
            self.dictionary,
            self.output_dictionary,
            add_eos_for_other_targets=(mode is not None and mode != "none"),
            shuffle=not lm_eval,
            targets=self.targets,
            add_bos_token=opts.add_bos_token,
        )
Beispiel #18
0
        def make_dataset(key, dictionary):
            """Load the ``key`` dataset of the current split, or ``None`` if absent.

            Uses ``get_path``, ``split``, ``combine`` and ``logger`` from the
            enclosing scope. An exception whose text contains
            ``StorageException: [404] Path not found`` is treated as a missing
            dataset; any other exception propagates unchanged.
            """
            split_path = get_path(key, split)

            try:
                dataset = data_utils.load_indexed_dataset(
                    split_path,
                    dictionary,
                    combine=combine,
                )
            except Exception as e:
                if "StorageException: [404] Path not found" not in str(e):
                    # Bare ``raise`` re-raises with the original traceback intact.
                    raise
                # Log which dataset was missing; the old message put the
                # exception text where the dataset name belongs.
                logger.warning(f"dataset {split_path} not found: {e}")
                dataset = None
            return dataset
Beispiel #19
0
 def desc_dataset(type, dictionary, relation_desc=None):
     """Load, optionally extend, truncate and right-pad the dataset for ``type``.

     Steps: load from ``get_path(type)``, then optionally prepend
     ``args.init_token``, then optionally concatenate ``relation_desc``, then
     truncate to ``args.tokens_per_sample``, then right-pad with the source
     pad index. Relies on ``get_path``, ``combine`` and ``self`` from the
     enclosing scope.

     Raises:
         FileNotFoundError: if no dataset exists at ``get_path(type)``.
     """
     now_path = get_path(type)
     dataset = data_utils.load_indexed_dataset(
         now_path,
         dictionary,
         self.args.dataset_impl,
         combine=combine,
     )
     if dataset is None:
         raise FileNotFoundError('Dataset not found: {}'.format(now_path))
     if self.args.init_token is not None:
         dataset = PrependTokenDataset(dataset, self.args.init_token)
     if relation_desc is not None:
         dataset = ConcatSentencesDataset(dataset, relation_desc)
     # Truncation runs after the concatenation, so the combined sequence
     # (not each part separately) is capped at tokens_per_sample.
     dataset = TruncateDataset(dataset, self.args.tokens_per_sample)
     dataset = RightPadDataset(dataset, pad_idx=self.source_dictionary.pad())
     return dataset
Beispiel #20
0
    def load_feature_dataset(self, split, **kwargs):
        """Pair the binarized text of ``split`` with its image features.

        Stores an ``MMITextImageDataset`` in ``self.datasets[split]``. The
        shuffle flag is set only for the ``"train"`` split.
        """
        features_dataset = FeatureDataset(self.args.data_dir, split)
        span_idxs = self.get_span_info(sent_num=features_dataset.sent_num,
                                       split=split)

        text_file = text_bin_file(self.args.data_dir, split)
        text_dataset = data_utils.load_indexed_dataset(text_file,
                                                       self.vocab_dict)

        self.datasets[split] = MMITextImageDataset(
            text_dataset=text_dataset,
            image_dataset=features_dataset,
            vocab_dict=self.vocab_dict,
            span_idxs=span_idxs,
            shuffle=(split == "train"))
def get_datasets_from_indexed_filterbanks(data_path, tgt_lang, tgt_dict, split,
                                          dataset_impl, skip_norm,
                                          legacy_audio_fix_lua_indexing):
    """
    Creates a dataset reading precomputed filterbanks and the corresponding target saved as indexed datasets.

    Source filterbanks are read from ``{data_path}/{split}.npz``. The target is
    read from the indexed dataset ``{data_path}/{split}.{tgt_lang}``.

    Args:
        data_path: directory containing the split files.
        tgt_lang: target language suffix; must not be ``None``.
        tgt_dict: dictionary for the target indexed dataset.
        split: split name used as the file prefix.
        dataset_impl: indexed dataset implementation. ``dataset_impl == "cached"``
            is also passed as a flag to ``FilterBanksDataset``.
        skip_norm: forwarded as ``skip_normalization``.
        legacy_audio_fix_lua_indexing: forwarded to ``FilterBanksDataset``.

    Returns:
        A ``FilterBankToTextDataset``.

    Raises:
        FileNotFoundError: if the target indexed dataset is missing.
    """
    assert tgt_lang is not None
    prefix = os.path.join(data_path, split)

    src_dataset = FilterBanksDataset(prefix + ".npz", dataset_impl == "cached",
                                     legacy_audio_fix_lua_indexing)
    tgt_path = prefix + "." + tgt_lang
    tgt_dataset = data_utils.load_indexed_dataset(tgt_path, tgt_dict,
                                                  dataset_impl)
    if tgt_dataset is None:
        raise FileNotFoundError("Dataset not found: {}".format(tgt_path))
    return FilterBankToTextDataset(src_dataset,
                                   tgt_dataset,
                                   tgt_dict,
                                   skip_normalization=skip_norm)
Beispiel #22
0
def get_xlco_dataset(args, dataset_path, vocab, mask_idx, combine=False):
    """Load ``dataset_path``, apply token masking and wrap it in ``XlcoDataset``.

    Args:
        args: options providing ``dataset_impl``, ``seed`` and ``mask_prob``.
        dataset_path: path prefix of the indexed dataset.
        vocab: dictionary used for loading, padding and masking.
        mask_idx: index of the mask token.
        combine: forwarded to ``load_indexed_dataset``.

    Returns:
        An ``XlcoDataset`` over the first dataset returned by ``apply_mask``.
    """
    loaded = data_utils.load_indexed_dataset(
        dataset_path, vocab, args.dataset_impl, combine=combine)

    # apply_mask returns a pair; only its first element is used here.
    masked, _unused = MaskTokensDataset.apply_mask(
        loaded,
        vocab=vocab,
        pad_idx=vocab.pad(),
        mask_idx=mask_idx,
        seed=args.seed,
        mask_prob=args.mask_prob,
        mask_whole_words=None,
    )
    return XlcoDataset(masked, vocab)
Beispiel #23
0
def load_langpair_dataset(data_path,
                          split,
                          src,
                          src_feat_roots,
                          tgt,
                          tgt_dict,
                          dataset_impl,
                          left_pad_source,
                          left_pad_target,
                          max_source_positions,
                          max_target_positions,
                          multilv_args,
                          prepend_bos=False,
                          load_alignments=False,
                          truncate_source=False,
                          use_bucketing=True):
    """Load a sign-language source / text target pair for ``split``.

    Files live at ``{data_path}/{split}.{src}-{tgt}.{lang}``. The source side
    is read with ``load_sign_dataset`` using ``src_feat_roots``. The target
    side is an indexed dataset loaded with ``tgt_dict``.

    ``prepend_bos``, ``load_alignments`` and ``truncate_source`` are accepted
    for call compatibility but are currently ignored, and no alignment
    dataset is built.

    Returns:
        A ``SignLanguagePairDataset``.

    Raises:
        FileNotFoundError: if the target indexed dataset is missing.
        ValueError: if source and target differ in number of examples.
    """
    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))

    src_dataset = load_sign_dataset(prefix + src, src_feat_roots)
    tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict,
                                                  dataset_impl)
    if tgt_dataset is None:
        raise FileNotFoundError('Dataset not found: {}'.format(prefix + tgt))

    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if len(src_dataset) != len(tgt_dataset):
        raise ValueError(
            'Source/target size mismatch for {}: {} vs {} examples'.format(
                prefix, len(src_dataset), len(tgt_dataset)))

    logger.info('{} {} {}-{} {} examples'.format(data_path, split, src, tgt,
                                                 len(src_dataset)))

    return SignLanguagePairDataset(src_dataset,
                                   src_dataset.sizes,
                                   tgt_dataset,
                                   tgt_dataset.sizes,
                                   tgt_dict,
                                   left_pad_source=left_pad_source,
                                   left_pad_target=left_pad_target,
                                   max_source_positions=max_source_positions,
                                   max_target_positions=max_target_positions,
                                   align_dataset=None,
                                   use_bucketing=use_bucketing,
                                   multilv_args=multilv_args)
Beispiel #24
0
    def _load_dataset_split(self, split, epoch, combine):
        """Load ``split`` as token blocks wrapped in ``<s> ... </s>``.

        Args:
            split: split name.
            epoch: 1-based epoch; chooses the data directory round-robin.
            combine: forwarded to ``load_indexed_dataset``.

        Returns:
            The token-block dataset with BOS prepended and EOS appended.

        Raises:
            FileNotFoundError: if the split cannot be found.
        """
        data_dirs = utils.split_paths(self.cfg.data)
        assert len(data_dirs) > 0
        split_path = os.path.join(data_dirs[(epoch - 1) % len(data_dirs)], split)

        cfg = self.cfg
        ds = data_utils.load_indexed_dataset(
            split_path, self.dictionary, cfg.dataset_impl, combine=combine,
        )
        if ds is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        eos_idx = self.dictionary.eos()
        ds = StripTokenDataset(ds, eos_idx)
        ds = maybe_shorten_dataset(
            ds,
            split,
            cfg.shorten_data_split_list,
            cfg.shorten_method,
            cfg.tokens_per_sample,
            cfg.seed,
        )

        # Two positions per block are reserved for the <s> and </s> added below.
        ds = TokenBlockDataset(
            ds,
            ds.sizes,
            cfg.tokens_per_sample - 2,
            pad=self.dictionary.pad(),
            eos=eos_idx,
            break_mode=cfg.sample_break_mode,
            document_sep_len=0,
        )
        logger.info("loaded {} blocks from: {}".format(len(ds), split_path))

        # <s> is equivalent to [CLS] in BERT.
        src_vocab = self.source_dictionary
        return AppendTokenDataset(
            PrependTokenDataset(ds, src_vocab.bos()), src_vocab.eos()
        )
Beispiel #25
0
def main():
    """Print every item of an indexed dataset, decoded with ``--dict`` if given."""
    args = get_parser().parse_args()

    dictionary = None if args.dict is None else Dictionary.load(args.dict)
    dataset = data_utils.load_indexed_dataset(
        args.input,
        dictionary,
        dataset_impl=args.dataset_impl,
        default="lazy",
    )

    for tensor_line in dataset:
        if dictionary is not None:
            print(dictionary.string(tensor_line))
        else:
            # Without a dictionary, emit the raw integer token ids.
            print(" ".join(str(int(tok)) for tok in tensor_line))
Beispiel #26
0
    def load_dataset(self, split, **kwargs):
        """Build the image-caption dataset for ``split`` into ``self.datasets``.

        Captions come from a tokenized JSON file in SCST mode, otherwise from
        an indexed dataset in ``captions_lang``. Image features are grid or
        object features, selected by ``--features``.

        Raises:
            ValueError: if ``--features`` is neither ``'grid'`` nor ``'obj'``.
        """
        opts = self.args
        features_dir = os.path.join(opts.features_dir,
                                    f'{split}-features-{opts.features}')

        image_ids = data.read_image_ids(
            os.path.join(opts.captions_dir, f'{split}-ids.txt'),
            non_redundant=self.scst)

        if self.scst and split == 'valid':
            # SCST validation only uses a prefix of the image ids.
            image_ids = image_ids[:opts.scst_validation_set_size]

        if not self.scst:
            captions_ds = data_utils.load_indexed_dataset(
                os.path.join(opts.captions_dir,
                             f'{split}-captions.{opts.captions_lang}'),
                self.captions_dict)
        else:
            captions_ds = data.CaptionsDataset(
                os.path.join(opts.captions_dir, f'{split}-captions.tok.json'),
                image_ids)

        kind = opts.features
        if kind == 'grid':
            image_ds = data.GridFeaturesDataset(features_dir, image_ids,
                                                grid_shape=(8, 8))
        elif kind == 'obj':
            metadata = data.read_image_metadata(
                os.path.join(features_dir, 'metadata.csv'))
            image_ds = data.ObjectFeaturesDataset(features_dir, image_ids,
                                                  metadata)
        else:
            raise ValueError(f'Invalid --features option: {kind}')

        self.datasets[split] = data.ImageCaptionDataset(
            img_ds=image_ds,
            cap_ds=captions_ds,
            cap_dict=self.captions_dict,
            scst=self.scst,
            shuffle=True)
Beispiel #27
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Loads ``split`` from every directory in ``self.paths``. While training,
        each dataset is wrapped in ``ContextAwareDataset`` with a context
        dataset: filterbanks when ``context_type == 'src'``, otherwise an
        indexed target-language dataset. Directories where loading fails are
        skipped with a warning that includes the underlying exception.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            combine (bool): accepted for interface compatibility; unused.

        Raises:
            NotImplementedError: if ``args.dataset_from_json`` is set.
            FileNotFoundError: if the split could not be loaded from any path.
        """
        if self.args.dataset_from_json:
            raise NotImplementedError
        datasets = []
        for path in self.paths:
            try:
                ds = get_datasets_from_indexed_filterbanks(
                    path,
                    self.args.target_lang,
                    self.tgt_dict,
                    split,
                    self.args.dataset_impl,
                    self.args.skip_normalization,
                    self.args.legacy_audio_fix_lua_indexing)
                if self.training:
                    if self.args.context_type == 'src':
                        context_ds = FilterBanksDataset(
                            os.path.join(path, split) + ".context.npz",
                            self.args.dataset_impl == "cached",
                            self.args.legacy_audio_fix_lua_indexing)
                    else:
                        context_ds = data_utils.load_indexed_dataset(
                            os.path.join(path, split) + ".context." + self.args.target_lang,
                            self.tgt_dict,
                            self.args.dataset_impl)
                    ds = ContextAwareDataset(
                        ds, context_ds, self.tgt_dict, self.args.context_type == 'src')
            except Exception:
                # Deliberately best-effort: any failure skips this path. exc_info
                # keeps the real cause visible instead of reporting every error
                # as a missing split.
                logger.warning(
                    "Split {} not found in {}. Skipping...".format(split, path),
                    exc_info=True)
                continue
            datasets.append(ds)
        # Explicit raise instead of assert so the check survives ``python -O``.
        if not datasets:
            raise FileNotFoundError(
                "Split {} not found in any of {}".format(split, self.paths))
        if len(datasets) > 1:
            self.datasets[split] = ConcatDataset(datasets)
        else:
            self.datasets[split] = datasets[0]
Beispiel #28
0
    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load ``split`` and convert it with ``self.build_s2s_dataset``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 0-based; picks one of the ``:``-separated data
                directories round-robin.
            combine (bool): forwarded to ``load_indexed_dataset``.

        Raises:
            FileNotFoundError: if the split cannot be found.
        """
        data_dirs = self.args.data.split(':')
        assert len(data_dirs) > 0
        split_path = os.path.join(data_dirs[epoch % len(data_dirs)], split)

        ds = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine,
        )
        if ds is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        self.datasets[split] = self.build_s2s_dataset(ds)
Beispiel #29
0
    def load_dataset(lang,
                     lang_dict,
                     prefix,
                     dataset_length,
                     sample_ratios=None):
        """
        Function to load additional dataset and deal with all parameters.
        Easier than copying redundant code for each dataset.
        Requires src_dataset to provide the length and sample_ratios.

        Loads the indexed dataset at ``prefix + lang`` with ``lang_dict``. It
        then optionally prepends BOS and optionally appends the ``[lang]``
        token. Relies on ``dataset_impl``, ``prepend_bos``, ``src_dict`` and
        ``append_source_id`` from the enclosing scope.

        Returns:
            ``(lang_dataset, lang_dataset_sizes)``. Both are ``None`` when the
            dataset does not exist.
        """
        lang_datasets = []
        lang_dataset = data_utils.load_indexed_dataset(prefix + lang,
                                                       lang_dict, dataset_impl)
        if lang_dataset is not None:
            lang_datasets.append(lang_dataset)
        # At most one dataset is loaded above, so when it exists this only
        # passes for dataset_length == 1.
        assert dataset_length == len(lang_datasets) or len(lang_datasets) == 0
        if dataset_length == 1:
            lang_dataset = lang_datasets[0] if len(lang_datasets) > 0 else None
        else:
            assert sample_ratios is not None
            if len(lang_datasets) > 0:
                lang_dataset = ConcatDataset(lang_datasets, sample_ratios)
            else:
                lang_dataset = None
        if prepend_bos:
            assert hasattr(src_dict, "bos_index") and hasattr(
                lang_dict, "bos_index")
            if lang_dataset is not None:
                lang_dataset = PrependTokenDataset(lang_dataset,
                                                   lang_dict.bos())
        if append_source_id:
            if lang_dataset is not None:
                # Append the "[<lang>]" tag token.
                lang_dataset = AppendTokenDataset(
                    lang_dataset, lang_dict.index('[{}]'.format(lang)))

        lang_dataset_sizes = lang_dataset.sizes if lang_dataset is not None else None
        return lang_dataset, lang_dataset_sizes
Beispiel #30
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load ``split`` as a ``CodeCompletionDataset`` into ``self.datasets``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based; picks the data directory round-robin.
            combine (bool): forwarded to ``load_indexed_dataset``.

        Raises:
            FileNotFoundError: if the split cannot be found.
        """
        data_dirs = utils.split_paths(self.args.data)
        assert len(data_dirs) > 0
        split_path = os.path.join(data_dirs[(epoch - 1) % len(data_dirs)], split)

        tokens = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine,
        )
        if tokens is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # The shuffle flag is off only for the test split.
        self.datasets[split] = CodeCompletionDataset(
            tokens,
            tokens.sizes,
            self.dictionary,
            split_fn=self.split_fn,
            shuffle=split != 'test',
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            append_eos_to_source=True,
            append_eos_to_target=True,
        )
        logger.info(
            "Split: {0}, Loaded {1} samples of CodeCompletionDataset".format(
                split, len(self.datasets[split])))