Example #1
    def __init__(self,
                 start_ratio=0.0,
                 end_ratio=0.9,
                 data_size=None,
                 domain='*',
                 sample_rate=1.0):
        # get all data
        if not data_size:
            zh_data, en_data = um_corpus.zh_en(domain)

        # otherwise, sample data_size pairs in proportion to each domain's weight
        else:
            domain_dict = {
                'education': 7,
                'laws': 5,
                'news': 15,
                'science': 6,
                'spoken': 5,
                'subtitles': 6,
                'thesis': 6,
            }
            total = float(sum(domain_dict.values()))

            zh_data = []
            en_data = []
            for domain_name, val in domain_dict.items():
                tmp_zh_data, tmp_en_data = um_corpus.zh_en(domain_name)
                sample_size = int(val / total * int(data_size))
                zh_data += tmp_zh_data[:sample_size]
                en_data += tmp_en_data[:sample_size]

        # shuffle with a fixed seed to reproduce the shuffle that NMT applies when building its train set
        random.seed(self.RANDOM_STATE)
        data = list(zip(zh_data, en_data))
        random.shuffle(data)

        # keep only the portion that NMT uses as its train set
        data = self.__split_data(data, 0.0, self.NMT_TRAIN_RATIO)

        # take the requested sub-range of that train set
        data = self.__split_data(data, start_ratio, end_ratio)

        if start_ratio == 0.0 or sample_rate < 1.0:
            data = self.sample_data(data, sample_rate)

        self.__src_data, self.__tar_data = list(zip(*data))
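For reference, here is a standalone sketch of the domain-weighted sampling arithmetic used in Example #1. The weights mirror domain_dict above; data_size is an arbitrary illustrative value, not taken from the real corpus.

# Standalone sketch of the per-domain sample-size computation in Example #1.
domain_dict = {
    'education': 7,
    'laws': 5,
    'news': 15,
    'science': 6,
    'spoken': 5,
    'subtitles': 6,
    'thesis': 6,
}
data_size = 10000  # illustrative only
total = float(sum(domain_dict.values()))  # 50.0

sample_sizes = {
    domain: int(val / total * data_size)
    for domain, val in domain_dict.items()
}
# e.g. 'news' contributes int(15 / 50 * 10000) = 3000 sentence pairs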
Example #2
    def __init__(self, _dataset='cdlm'):
        # initialize variables
        self.__processed_dir_path = create_dir(data_dir, 'un_preprocessed',
                                               _dataset)

        zh_data, en_data = um_corpus.zh_en(get_test=False)
        data = list(zip(zh_data, en_data))

        # shuffle the data
        random.seed(self.RANDOM_STATE)
        random.shuffle(data)

        self.gen_data(data, self.BATCH_SIZE_PER_FILE)
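As a side note, seeding before random.shuffle is what makes the shuffle above reproducible across runs. A minimal sketch, assuming an illustrative RANDOM_STATE value (the real class constant is not shown in these examples):

import random

RANDOM_STATE = 42  # assumption: the actual class constant is not shown here

def shuffled_pairs():
    pairs = list(zip(['zh_a', 'zh_b', 'zh_c'], ['en_a', 'en_b', 'en_c']))
    random.seed(RANDOM_STATE)
    random.shuffle(pairs)
    return pairs

# re-seeding with the same value reproduces the identical ordering
assert shuffled_pairs() == shuffled_pairs()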
Example #3
    def __init__(self, _is_train, _dataset='cdlm'):
        # initialize variables
        self.__processed_dir_path = create_dir(data_dir, 'un_preprocessed', _dataset)

        # initialize wmt news loader
        start_ratio = 0.0 if _is_train else zh_en_wmt_news.Loader.PRETRAIN_TRAIN_RATIO
        end_ratio = zh_en_wmt_news.Loader.PRETRAIN_TRAIN_RATIO if _is_train else 1.0
        zh_en_wmt_loader = zh_en_wmt_news.Loader(start_ratio, end_ratio)

        # initialize news commentary loader
        start_ratio = 0.0 if _is_train else zh_en_news_commentary.Loader.PRETRAIN_TRAIN_RATIO
        end_ratio = zh_en_news_commentary.Loader.PRETRAIN_TRAIN_RATIO if _is_train else 1.0
        zh_en_news_commentary_loader = zh_en_news_commentary.Loader(start_ratio, end_ratio)

        # load the data
        zh_data, en_data = zh_en_wmt_loader.data()
        zh_data_2, en_data_2 = zh_en_news_commentary_loader.data()

        # um corpus data is only for training
        if _is_train:
            zh_data_3, en_data_3 = um_corpus.zh_en(get_test=False)

            # append the UM-Corpus data
            zh_data += tuple(zh_data_3)
            en_data += tuple(en_data_3)

        # append the news commentary data
        zh_data += zh_data_2
        en_data += en_data_2

        # word segmentation for zh_data
        zh_data = utils.pipeline(seg_zh_by_jieba_pipeline, zh_data)

        data = list(zip(zh_data, en_data))

        # shuffle the data
        random.seed(self.RANDOM_STATE)
        random.shuffle(data)

        self.gen_data(data, self.BATCH_SIZE_PER_FILE)
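For context, a minimal sketch of the (start_ratio, end_ratio) convention that drives the loaders above. The loaders' own __split_data is not shown in these examples, so the slicing behavior here is an assumption:

def split_by_ratio(data, start_ratio, end_ratio):
    # hypothetical stand-in for the loaders' __split_data: keep the slice
    # of `data` between the two relative positions
    start = int(len(data) * start_ratio)
    end = int(len(data) * end_ratio)
    return data[start:end]

PRETRAIN_TRAIN_RATIO = 0.98  # illustrative value only
data = list(range(100))

train_split = split_by_ratio(data, 0.0, PRETRAIN_TRAIN_RATIO)  # _is_train=True
val_split = split_by_ratio(data, PRETRAIN_TRAIN_RATIO, 1.0)    # _is_train=False
assert len(train_split) == 98 and len(val_split) == 2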
Example #4
    def __load_from_um_corpus():
        zh_data, en_data = um_corpus.zh_en(get_test=False)
        return list(zip(zh_data, en_data))
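Example #4 appears to be a private helper stripped of its class context. A minimal sketch of how it would typically sit inside a loader class; the class name and the @staticmethod decorator are assumptions:

from nmt.preprocess.corpus import um_corpus

class Loader:  # hypothetical enclosing class
    @staticmethod
    def __load_from_um_corpus():  # double underscore: name-mangled, private to Loader
        zh_data, en_data = um_corpus.zh_en(get_test=False)
        return list(zip(zh_data, en_data))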
Example #5
    },
    {
        'name': 'join_list_token_2_string_with_space_for_src_lan',
        'func': utils.join_list_token_2_string,
        'input_keys': ['input_1', ' '],
        'output_keys': 'input_1',
        'show_dict': {'src_lan': 'input_1'},
    },
]

if __name__ == '__main__':
    from nmt.preprocess.corpus import um_corpus
    from nmt.preprocess.inputs import noise_pl, tfds_share_pl

    # origin_zh_data, origin_en_data = wmt_news.zh_en()
    origin_zh_data, origin_en_data = um_corpus.zh_en()
    params = {
        'vocab_size': 45000,
        'max_src_seq_len': 79,
        'max_tar_seq_len': 98,
    }

    seg_pipeline = seg_zh_by_jieba_pipeline

    print('\n------------------- Encoding -------------------------')
    zh_data, en_data, zh_tokenizer, en_tokenizer = utils.pipeline(
        preprocess_pipeline=seg_pipeline + noise_pl.remove_noise + tfds_share_pl.train_tokenizer + tfds_share_pl.encode_pipeline,
        lan_data_1=origin_zh_data, lan_data_2=origin_en_data, params=params)

    print('\n----------------------------------------------')
    print(zh_data.shape)
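Finally, a toy interpreter for the step-dict schema visible at the top of Example #5. The semantics are inferred from the snippet rather than from the real utils.pipeline, so treat the whole sketch as an assumption: input_keys entries that name fields in a shared dict are looked up, literals such as ' ' pass through as constant arguments, the result lands under output_keys, and show_dict maps display labels to fields worth printing.

def run_step(step, data):
    # hypothetical runner; not the real utils.pipeline
    args = [data.get(k, k) for k in step['input_keys']]
    data[step['output_keys']] = step['func'](*args)
    for label, key in step.get('show_dict', {}).items():
        print(f"{label}: {data[key]}")
    return data

step = {
    'name': 'join_list_token_2_string_with_space_for_src_lan',
    'func': lambda tokens, sep: sep.join(tokens),  # stand-in for utils.join_list_token_2_string
    'input_keys': ['input_1', ' '],
    'output_keys': 'input_1',
    'show_dict': {'src_lan': 'input_1'},
}
run_step(step, {'input_1': ['你', '好']})  # prints: src_lan: 你 好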