Example 1
    def main(self):
        """
        This script generates the data used for decoding. You can change the
        input string in `data` below, and it will automatically produce input
        data the model can consume.
        """
        data = self.gen_json(
            context=
            '台灣買的到美國牛、澳洲牛,但就是買不到日本牛。 想請問是什麼原因禁止進口日本牛肉? 還有是從什麼時候開始禁止進口日本牛的?',
            discription='為什麼台灣不能進口日本牛肉?')

        out_list = []
        for item in data:
            out_dict = {}
            context = self.go_through_processes_for_context(item['context'])
            discription = self.go_through_processes_for_discription(
                item['discription'])

            print(context)
            print(discription)
            if context and discription:
                out_dict['context'] = context
                out_dict['discription'] = discription
                out_list.append(out_dict)
            else:
                print('Unable to process this item')
        data = self.convert_UNK(out_list)  # convert to UNK
        # pprint(data)

        # Generate data in the format data_convert_example.py expects
        self.gen_input_format(data,
                              '../yahoo_knowledge_data/decode/data_ready')
        # Generate the binary format the model input pipeline expects
        text_to_binary('../yahoo_knowledge_data/decode/data_ready',
                       '../yahoo_knowledge_data/decode/decode_data')
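text_to_binary itself is not shown in these snippets; in the original TextSum repo, data_convert_example.py does the equivalent conversion by packing each tab-separated key=value line into a tf.Example proto and writing it as a length-prefixed record. A minimal sketch of that conversion, assuming the same readable line format (e.g. article=...<TAB>abstract=...):

import struct

from tensorflow.core.example import example_pb2


def text_to_binary(in_file, out_file):
    """Convert readable examples (one per line, tab-separated key=value pairs)
    into length-prefixed tf.Example records, the format the TextSum batch
    reader consumes."""
    with open(in_file) as reader, open(out_file, 'wb') as writer:
        for line in reader:
            if not line.strip():
                continue  # skip blank lines
            tf_example = example_pb2.Example()
            for feature in line.strip().split('\t'):
                key, value = feature.split('=', 1)
                tf_example.features.feature[key].bytes_list.value.extend(
                    [value.encode('utf8')])
            serialized = tf_example.SerializeToString()
            # 8-byte length header followed by the serialized proto.
            writer.write(struct.pack('q', len(serialized)))
            writer.write(struct.pack('%ds' % len(serialized), serialized))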
Example 2
def split_decode_data():
    # Split each line of readable_data_ready into its own file, then
    # convert every per-line file into the binary input format.
    file_num = 1
    with open(
            '../yahoo_knowledge_data/decode/ver_5/readable_data_ready') as rf:
        for line in rf:
            with open(
                    '../yahoo_knowledge_data/decode/ver_5/dataset_ready/data_ready_'
                    + str(file_num), 'w') as wf:
                wf.write(line.replace('\n', ''))
            file_num += 1
    file_num = 1
    for item in glob('../yahoo_knowledge_data/decode/ver_5/dataset_ready/*'):
        text_to_binary(
            item, '../yahoo_knowledge_data/decode/ver_5/dataset_input/data_' +
            str(file_num))
        file_num += 1
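This function feeds the thesis decode loop in Example 6, which reads dataset_ready/data_ready_<N> one file per iteration. Note that the second pass above binarizes files in glob order, so dataset_input/data_<N> is not guaranteed to line up with data_ready_<N>. A sketch that keeps the two numberings aligned by binarizing inside the same loop (paths and the one-sample-per-line assumption are carried over from the snippet above; text_to_binary is assumed to be defined or imported in the same module):

import os


def split_decode_data(base='../yahoo_knowledge_data/decode/ver_5'):
    ready_dir = os.path.join(base, 'dataset_ready')
    input_dir = os.path.join(base, 'dataset_input')
    os.makedirs(ready_dir, exist_ok=True)
    os.makedirs(input_dir, exist_ok=True)

    # One readable sample per line -> one file per sample, binarized
    # immediately so data_<N> always matches data_ready_<N>.
    with open(os.path.join(base, 'readable_data_ready')) as rf:
        for num, line in enumerate(rf, start=1):
            ready_path = os.path.join(ready_dir, 'data_ready_%d' % num)
            with open(ready_path, 'w') as wf:
                wf.write(line.rstrip('\n'))
            text_to_binary(ready_path, os.path.join(input_dir, 'data_%d' % num))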
Example 3
    def get_data(self, description, context):
        data_dict = self.gen_data_dict(description=description, context=context)
        
        data = []
        context = self.go_through_processes_for_context(data_dict['context'])
        description = self.go_through_processes_for_description(data_dict['description'])
        if context and description:
            data_dict['context'] = context
            data_dict['description'] = description
            data.append(data_dict)
        else:
            print('Unable to process this item')
            return None

        data = self.convert_UNK(data)  # convert to UNK
        
        # Generate data in the format data_convert_example.py expects
        self.gen_input_format(data, './yahoo_knowledge_data/decode/decode_temp')
        # Generate the binary format the model input pipeline expects
        text_to_binary('./yahoo_knowledge_data/decode/decode_temp', './yahoo_knowledge_data/decode/decode_data')
        return data[0]
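For context, the commented-out interactive loop in Example 6 drives this method roughly as follows (the class name preprocessing and its constructor argument come from that example; the input strings below are placeholders):

from pprint import pprint

# Hypothetical driver mirroring the interactive loop in Example 6.
p = preprocessing('path/to/vocab')                              # placeholder vocab path
input_data = p.get_data(description='...question title...',    # placeholder text
                        context='...question body...')         # placeholder text
if input_data is not None:
    pprint(input_data)  # decode_data has also been written to disk by this point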
Example 4
    def main(self, in_file_path, out_json_name, out_file_name):
        corpus_json = self.get_json(in_file_path)
        with open(out_json_name, 'w') as wf:  # keep a copy of the JSON
            json.dump(corpus_json, wf)
        self.gen_input_data(corpus_json, out_file_name)
        text_to_binary('decoding_corpus/test', '../data/test2')
Example 5
    def convert_to_bin(self):
        text_to_binary(self.train_path, './train/data')
        text_to_binary(self.valid_path, './valid/data')
Example 6
def main(unused_argv):
    vocab = data.Vocab(FLAGS.vocab_path, 1000000)
    # Check for presence of required special tokens.
    assert vocab.CheckVocab(data.PAD_TOKEN) > 0
    assert vocab.CheckVocab(data.UNKNOWN_TOKEN) >= 0
    assert vocab.CheckVocab(data.SENTENCE_START) > 0
    assert vocab.CheckVocab(data.SENTENCE_END) > 0

    batch_size = 4
    if FLAGS.mode == 'decode':
        batch_size = FLAGS.beam_size

    hps = seq2seq_attention_model.HParams(
        mode=FLAGS.mode,  # train, eval, decode
        min_lr=0.01,  # min learning rate.
        lr=0.15,  # learning rate
        batch_size=batch_size,
        enc_layers=1,
        enc_timesteps=120,
        dec_timesteps=30,
        min_input_len=2,  # discard articles/summaries < than this
        num_hidden=128,  # for rnn cell
        emb_dim=128,  # If 0, don't use embedding
        max_grad_norm=2,
        num_softmax_samples=4096)  # If 0, no sampled softmax.

    batcher = batch_reader.Batcher(FLAGS.data_path,
                                   vocab,
                                   hps,
                                   FLAGS.article_key,
                                   FLAGS.abstract_key,
                                   FLAGS.max_article_sentences,
                                   FLAGS.max_abstract_sentences,
                                   bucketing=FLAGS.use_bucketing,
                                   truncate_input=FLAGS.truncate_input)
    tf.set_random_seed(FLAGS.random_seed)

    if hps.mode == 'train':
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            hps, vocab, num_gpus=FLAGS.num_gpus)
        _Train(model, batcher)
    elif hps.mode == 'eval':
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            hps, vocab, num_gpus=FLAGS.num_gpus)
        _Eval(model, batcher, vocab=vocab)
    elif hps.mode == 'decode':
        decode_mdl_hps = hps
        # Only need to restore the 1st step and reuse it since
        # we keep and feed in state for each step's output.
        decode_mdl_hps = hps._replace(dec_timesteps=1)
        model = seq2seq_attention_model.Seq2SeqAttentionModel(
            decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)

        to_build_grapth = True
        p = preprocessing(FLAGS.vocab_path)

        # Old decode loop
        # while True:
        #     kb_input = input('> ')
        #     if kb_input == 'c':
        #         description_str = input('Enter description > ')
        #         context_str = input('Enter context > ')
        #         input_data = p.get_data(description=description_str, context=context_str)
        #         print('Input data:')
        #         pprint(input_data)
        #     elif kb_input == 'q':
        #         break
        #     else:
        #         try:
        #             text_to_binary('yahoo_knowledge_data/decode/ver_5/dataset_ready/data_ready_' + kb_input,
        #                     'yahoo_knowledge_data/decode/decode_data')
        #         except:
        #             print('Error with the default testing data')
        #     decoder = seq2seq_attention_decode.BSDecoder(model, hps, vocab, to_build_grapth)
        #     to_build_grapth = False
        #     decoder.DecodeLoop()

        # Decode loop used in the thesis
        file_num = 1
        while True:
            if file_num > 60:  # stop after decoding 60 samples
                print('Printed 60 samples')
                break
            try:
                text_to_binary(
                    'yahoo_knowledge_data/decode/ver_5/dataset_ready/data_ready_'
                    + str(file_num), 'yahoo_knowledge_data/decode/decode_data')
            except Exception:
                print('Error with the default testing data')
                break
            decoder = seq2seq_attention_decode.BSDecoder(
                model, hps, vocab, to_build_grapth)
            to_build_grapth = False
            decoder.DecodeLoop()
            print('==================', file_num, '==================')
            file_num += 1
Example 7
def main():
    # with open('../yahoo_knowledge_data/corpus/init_data.json') as rf:
    #     data = json.load(rf)

    # err = 0
    # out_list = []
    # for item in tqdm(data):
    #     out_dict = {}
    #     description = description_processes(item['description'])
    #     context = context_processes(item['context'])
    #     if context and description:
    #         out_dict['context'] = context
    #         out_dict['description'] = description
    #         out_list.append(out_dict)
    #     else:
    #         err += 1

    # with open('../yahoo_knowledge_data/corpus/ver_6/preprocessed_data.json', 'w') as wf:
    #     json.dump(out_list, wf)

    # print('Total records:', len(data))
    # print('Records preprocessing could not handle:', err)
    # print('Clean records in total:', (len(data) - err))

    """
    以上做完前處理,為了加速所以先存檔,接著下來用讀檔的比較快,之後也可以串起來一次做完。
    """

    with open('../yahoo_knowledge_data/corpus/ver_6/preprocessed_data.json', 'r') as rf:
        data = json.load(rf)
    
    
    # Filter out items containing specific strings
    data = filter_specific_word(data)
    # Remove items with duplicate contexts
    data = preprocessing.remove_duplicate(data)

    # # Sample a few items for inspection
    # data = preprocessing._sample_data_to_see(data, 100)
    # pprint(data)

    gen = gen_vocab()
    word_count = gen.get_word_count_with_threshold(data, 100)  # counter used for UNK conversion

    print('==== Start converting <UNK> ====')
    data = preprocessing.convert_UNK(word_count, data)  # convert to UNK

    word_count = gen.get_word_count_with_threshold(data, 0)  # this word_count now includes UNK
    print('Final vocab size:', len(word_count), 'words')

    # Generate the vocab files
    gen.gen_final_vocab_and_vocab_tsv(word_count, '../yahoo_knowledge_data/vocab/ver_6/vocab')
    
    print('==== Start splitting train / valid ====')
    train, valid, test = preprocessing.split_train_valid(data, test_size=.0005, valid_size=.1)  # returns (train, valid, test)
    print('train size', len(train))
    print('valid size', len(valid))
    print('test size', len(test))

    # Generate data in the format data_convert_example.py expects
    print('==== Start generating input data ====')
    preprocessing.gen_input_format(train, '../yahoo_knowledge_data/train/ver_6/readable_data_ready')
    preprocessing.gen_input_format(valid, '../yahoo_knowledge_data/valid/ver_6/readable_data_ready')
    preprocessing.gen_input_format(test, '../yahoo_knowledge_data/decode/ver_6/readable_data_ready')
    text_to_binary('../yahoo_knowledge_data/train/ver_6/readable_data_ready', '../yahoo_knowledge_data/train/ver_6/data')
    text_to_binary('../yahoo_knowledge_data/valid/ver_6/readable_data_ready', '../yahoo_knowledge_data/valid/ver_6/data')
    text_to_binary('../yahoo_knowledge_data/decode/ver_6/readable_data_ready', '../yahoo_knowledge_data/decode/ver_6/data')
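split_train_valid is not shown in these snippets; one straightforward way to produce the (train, valid, test) tuple with the ratios used above is scikit-learn's train_test_split. A minimal sketch, assuming test_size and valid_size are both fractions of the full dataset:

from sklearn.model_selection import train_test_split


def split_train_valid(data, test_size=.0005, valid_size=.1, seed=42):
    """Return (train, valid, test); both sizes are fractions of the full set."""
    rest, test = train_test_split(data, test_size=test_size, random_state=seed)
    # Re-express valid_size relative to what remains after removing the test set.
    train, valid = train_test_split(
        rest, test_size=valid_size / (1 - test_size), random_state=seed)
    return train, valid, test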