Code example #1
0
def load_raw_text_dataset(path, load_splits, src=None, dst=None):
    """Load the requested splits (e.g. test, train or valid) as raw-text datasets.

    If neither ``src`` nor ``dst`` is given, the language pair is inferred from
    the files found under ``path``. Each split yields a ``LanguagePairDataset``
    built from three raw-text streams: source, a "guess" stream, and target.
    """
    if src is None and dst is None:
        # Infer the language pair from the files present in the folder.
        src, dst = infer_language_pair(path, load_splits)
    assert src is not None and dst is not None, 'Source and target languages should be provided'

    # The guess stream is in the target language, so it reuses dst's
    # language code and dictionary.
    guess_lang = dst
    src_dict, dst_dict = load_dictionaries(path, src, dst)
    dataset = LanguageDatasets(src, guess_lang, dst, src_dict, dst_dict)

    # Build one LanguagePairDataset per requested split from the raw files.
    for split in load_splits:
        src_path = os.path.join(path, '{}.{}'.format(split, src))
        # Guess files sit next to the target files with a ".guess" suffix.
        guess_path = os.path.join(path, '{}.{}.guess'.format(split, guess_lang))
        dst_path = os.path.join(path, '{}.{}'.format(split, dst))

        src_data = IndexedRawTextDataset(src_path, src_dict)
        guess_data = IndexedRawTextDataset(guess_path, dst_dict)
        dst_data = IndexedRawTextDataset(dst_path, dst_dict)
        dataset.splits[split] = LanguagePairDataset(
            src_data,
            guess_data,
            dst_data,
            pad_idx=dataset.src_dict.pad(),
            eos_idx=dataset.src_dict.eos(),
        )
    return dataset
Code example #2
0
def load_raw_text_dataset(path, load_splits, src=None, dst=None, doctopic=None, embed_dim=512):
    """Load the requested splits (e.g. test, train or valid) as raw-text datasets.

    Unlike the plain variant, this loader also attaches a source-lemma stream
    and a document-topic stream to every split, together with the lemma/topic
    dictionary. ``src``, ``dst`` and ``doctopic`` are all mandatory here —
    automatic language-pair inference is not supported.
    """
    assert (src is not None) and (dst is not None) and (doctopic is not None), 'Source language, target language and doc topic should be provided'

    src_dict, dst_dict = load_dictionaries(path, src, dst)
    src_lemma_topic_dict = load_src_lemma_topic_dictionaries(path, src)

    dataset = LanguageDatasets(src, dst, doctopic, src_dict, dst_dict, src_lemma_topic_dict)

    # Build one LanguagePairDataset per requested split from the raw files.
    for split in load_splits:
        # Per-split raw text files: source, target, source lemmas, doc topics.
        src_path = os.path.join(path, '{}.{}'.format(split, src))
        dst_path = os.path.join(path, '{}.{}'.format(split, dst))
        src_lemma_path = os.path.join(path, '{}.{}-lemma'.format(split, src))
        doctopic_path = os.path.join(path, '{}.{}'.format(split, doctopic))

        src_data = IndexedRawTextDataset(src_path, src_dict)
        dst_data = IndexedRawTextDataset(dst_path, dst_dict)
        lemma_data = IndexedRawTextDatasetLEMMA(src_lemma_path)
        topic_data = IndexedRawTextDatasetDOCTOPICS(doctopic_path)
        dataset.splits[split] = LanguagePairDataset(
            src_data,
            dst_data,
            lemma_data,
            topic_data,
            src_lemma_topic_dict,
            pad_idx=dataset.src_dict.pad(),
            eos_idx=dataset.src_dict.eos(),
            embed_dim=embed_dim,
        )
    return dataset
Code example #3
0
File: data.py  Project: shashiongithub/fairseq
def load_raw_text_dataset(path, load_splits, src=None, dst=None):
    """Loads specified data splits (e.g., test, train or valid) from raw text
    files in the specified folder.

    If neither ``src`` nor ``dst`` is given, the language pair is inferred
    from the files found under ``path``. The target-side dataset for each
    split is built after the source side so it can consume the source's
    out-of-vocabulary word list (presumably for copy/alignment-style OOV
    handling — confirm against IndexedRawTextDataset).
    """
    if src is None and dst is None:
        # find language pair automatically
        src, dst = infer_language_pair(path, load_splits)
    assert src is not None and dst is not None, 'Source and target languages should be provided'

    src_dict, dst_dict = load_dictionaries(path, src, dst)
    dataset = LanguageDatasets(src, dst, src_dict, dst_dict)

    # Load dataset from raw text files.
    # Fix: removed stray debug print(split) that leaked every split name to
    # stdout, and deleted dead commented-out code.
    for split in load_splits:
        src_path = os.path.join(path, '{}.{}'.format(split, src))
        src_indexedrawdataset = IndexedRawTextDataset(src_path,
                                                      src_dict,
                                                      is_dst=False)

        dst_path = os.path.join(path, '{}.{}'.format(split, dst))
        # The target side reuses the source's OOV word list so both sides
        # agree on how out-of-vocabulary tokens are indexed.
        dst_indexedrawdataset = IndexedRawTextDataset(
            dst_path,
            dst_dict,
            is_dst=True,
            src_oov_words_list=src_indexedrawdataset.oov_words_list)

        dataset.splits[split] = LanguagePairDataset(
            src_indexedrawdataset,
            dst_indexedrawdataset,
            pad_idx=dataset.src_dict.pad(),
            eos_idx=dataset.src_dict.eos(),
        )
    return dataset