Example #1
0
    def learning(cls, total_epoch, n_train, n_valid, n_test, batch_size, window_size, noise_rate, model_file, features_vector, labels_vector,
                 n_hidden1,
                 learning_rate,
                 dropout_keep_rate, early_stop_cost=0.001):
        """
        Create the spelling-correction datasets if needed, train the DAE model and save it.

        The train/valid/test dataset files are (re)created when any of them is missing.
        Each sentence is cut into sliding windows of `window_size` characters. For every window,
        one sample is emitted per character position:
        - feature: the window with that position blanked (via StringUtil.replace_with_index);
        - label: the original window.

        :param total_epoch: maximum number of training epochs
        :param n_train: number of train sentences to read (also part of the dataset file name)
        :param n_valid: number of valid sentences to read (also part of the dataset file name)
        :param n_test: number of test sentences to read (also part of the dataset file name)
        :param batch_size: mini-batch size passed to train.next_batch()
        :param window_size: number of characters per sample
        :param noise_rate: logged and passed to SpellingErrorCorrection.build_DAE(); not used when creating datasets
        :param model_file: checkpoint path passed to tf.train.Saver.save()
        :param features_vector: characters list for features (len(features_vector) * window_size = input size)
        :param labels_vector: characters list for labels
        :param n_hidden1: hidden layer size passed to build_DAE()
        :param learning_rate: learning rate passed to build_DAE()
        :param dropout_keep_rate: dropout keep probability fed during training steps
        :param early_stop_cost: training stops once the validation cost drops below this value
        """
        n_features = len(features_vector) * window_size  # number of features = 17,382 * 10

        log.info('load characters list...')
        log.info('load characters list OK. len: %s' % NumUtil.comma_str(len(features_vector)))
        watch = WatchUtil()

        dataset_dir = os.path.join(KO_WIKIPEDIA_ORG_DIR, 'datasets', 'spelling_error_correction')
        dataset_name_format = 'ko.wikipedia.org.dataset.sentences=%s.window_size=%d.%s.gz'
        train_file = os.path.join(dataset_dir, dataset_name_format % (n_train, window_size, 'train'))
        valid_file = os.path.join(dataset_dir, dataset_name_format % (n_valid, window_size, 'valid'))
        test_file = os.path.join(dataset_dir, dataset_name_format % (n_test, window_size, 'test'))

        log.info('train_file: %s' % train_file)
        log.info('valid_file: %s' % valid_file)
        log.info('test_file: %s' % test_file)
        if not all(os.path.exists(path) for path in (train_file, valid_file, test_file)):
            os.makedirs(dataset_dir, exist_ok=True)

            watch.start('create dataset')  # FIXME: out of memory (1M sentences)
            log.info('create dataset...')

            data_files = (('train', KO_WIKIPEDIA_ORG_TRAIN_SENTENCES_FILE, n_train, train_file),
                          ('valid', KO_WIKIPEDIA_ORG_VALID_SENTENCES_FILE, n_valid, valid_file),
                          ('test', KO_WIKIPEDIA_ORG_TEST_SENTENCES_FILE, n_test, test_file))

            check_interval = 10000
            for name, data_file, total, dataset_file in data_files:
                log.info('check_interval: %s' % check_interval)
                log.info('%s %s total: %s' % (name, os.path.basename(data_file), NumUtil.comma_str(total)))
                log.info('noise_rate: %s' % noise_rate)

                features, labels = [], []
                with gzip.open(data_file, 'rt') as f:
                    for i, line in enumerate(f, 1):
                        if total < i:  # read at most `total` sentences
                            break

                        if i % check_interval == 0:
                            time.sleep(0.01)  # prevent cpu overload
                            percent = i / total * 100
                            log.info('create dataset... %.1f%% readed. data len: %s. %s' % (percent, NumUtil.comma_str(len(features)), data_file))

                        sentence = line.strip()
                        # Character-level blank noise: one sample per position of every window.
                        for start in range(0, len(sentence) - window_size + 1):
                            chars = sentence[start: start + window_size]
                            for idx in range(len(chars)):
                                # presumably replaces chars[idx] with ' ' (StringUtil is opaque here)
                                noised_chars = StringUtil.replace_with_index(chars, ' ', idx)
                                features.append(noised_chars)
                                labels.append(chars)
                                # Lazy %-args: this runs once per sample, so only format when DEBUG is enabled.
                                log.debug('create dataset... %s "%s" -> "%s"', name, noised_chars, chars)

                # NOTE: jamo-level noise (initial/medial/final consonant) via
                # SpellingErrorCorrection.encode_noise(chars, noise_rate=noise_rate, noise_with_blank=True)
                # is currently disabled; only the blank noise above is generated.
                dataset = DataSet(features=features, labels=labels, features_vector=features_vector, labels_vector=labels_vector, name=name)
                log.info('dataset save... %s' % dataset_file)
                dataset.save(dataset_file, gzip_format=True, verbose=True)
                log.info('dataset save OK. %s' % dataset_file)
                log.info('dataset: %s' % dataset)

            log.info('create dataset OK.')
            log.info('')
            watch.stop('create dataset')

        watch.start('dataset load')
        log.info('dataset load...')
        train = DataSet.load(train_file, gzip_format=True, verbose=True)

        # Small runs (fewer than 100,000 train sentences) validate on the train dataset itself.
        valid_source_file = valid_file if n_train >= 100000 else train_file
        valid = DataSet.load(valid_source_file, gzip_format=True, verbose=True)
        log.info('valid.convert_to_one_hot_vector()...')
        valid = valid.convert_to_one_hot_vector(verbose=True)
        log.info('valid.convert_to_one_hot_vector() OK.')

        log.info('train dataset: %s' % train)
        log.info('valid dataset: %s' % valid)
        log.info('dataset load OK.')
        log.info('')
        watch.stop('dataset load')

        X, Y, dropout_keep_prob, train_step, cost, y_hat, accuracy = SpellingErrorCorrection.build_DAE(n_features, window_size, noise_rate, n_hidden1,
                                                                                                       learning_rate, watch)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            check_interval = max(1, min(1000, n_train // 10))
            nth_train, nth_input, total_input = 0, 0, total_epoch * train.size

            log.info('')
            log.info('learn...')
            log.info('total_epoch: %s' % total_epoch)
            log.info('train.size (total features): %s' % NumUtil.comma_str(train.size))
            log.info('check_interval: %s' % check_interval)
            log.info('batch_size: %s' % batch_size)
            log.info('total_input: %s (total_epoch * train.size)' % total_input)
            log.info('')
            watch.start('learn')
            valid_cost = sys.float_info.max
            for epoch in range(1, total_epoch + 1):
                if valid_cost < early_stop_cost:
                    log.info('valid_cost: %s, early_stop_cost: %s, early stopped.' % (valid_cost, early_stop_cost))
                    break
                for features_batch, labels_batch in train.next_batch(batch_size=batch_size, to_one_hot_vector=True):
                    if valid_cost < early_stop_cost:
                        break

                    nth_train += 1
                    nth_input += features_batch.shape[0]
                    sess.run(train_step, feed_dict={X: features_batch, Y: labels_batch, dropout_keep_prob: dropout_keep_rate})

                    # NOTE(review): the full validation set is evaluated after EVERY step; the former
                    # `nth_train % check_interval` gate is disabled, so check_interval is only logged.
                    percent = nth_input / total_input * 100
                    valid_cost = sess.run(cost, feed_dict={X: valid.features, Y: valid.labels, dropout_keep_prob: 1.0})
                    log.info('[epoch=%s][%.1f%%] %s cost: %.8f' % (epoch, percent, valid.name, valid_cost))

            watch.stop('learn')
            log.info('learn OK.')
            log.info('')

            log.info('model save... %s' % model_file)
            watch.start('model save...')
            model_dir = os.path.dirname(model_file)
            if model_dir:  # a bare file name has no directory to create (os.makedirs('') would raise)
                os.makedirs(model_dir, exist_ok=True)
            saver = tf.train.Saver()
            saver.save(sess, model_file)
            watch.stop('model save...')
            log.info('model save OK. %s' % model_file)

        log.info('')
        log.info('total_epoch: %s' % total_epoch)
        log.info('batch_size: %s' % batch_size)
        log.info('total_input: %s (total_epoch * train.size)' % total_input)
        log.info('')
        log.info(watch.summary())
        log.info('')
Example #2
0
    def dump_corpus(mongo_url, db_name, collection_name, sentences_file, characters_file, info_file, urls_file,
                    train_sentences_file, valid_sentences_file, test_sentences_file,
                    mongo_query=None, limit=None):
        """
        Read documents from MongoDB and save them sentence by sentence.
        (Sentences with only one word, or with no Hangul at all, are not extracted.)

        Every extracted sentence is written to `sentences_file`.
        A document with at least 10 sentences is split as follows:
        - the first len // 10 sentences go to the test file;
        - the next len // 10 go to the valid file;
        - the rest go to the train file.
        A document with fewer than 10 sentences goes entirely to the train file.

        :param mongo_url: mongodb://~~~
        :param db_name: database name of mongodb
        :param collection_name: collection name of mongodb
        :param sentences_file: *.sentence file (gzip; all extracted sentences)
        :param characters_file: output file with every distinct character, sorted, one per line
        :param info_file: output file with corpus statistics
        :param urls_file: output file with the url of each document that yielded at least one sentence
        :param train_sentences_file: gzip output file for train sentences
        :param valid_sentences_file: gzip output file for valid sentences
        :param test_sentences_file: gzip output file for test sentences
        :param mongo_query: default={}
        :param limit: maximum number of documents to read; a falsy value means no limit
        :return: None
        """
        if mongo_query is None:
            mongo_query = {}

        corpus_mongo = MongodbUtil(mongo_url, db_name=db_name, collection_name=collection_name)
        total_docs = corpus_mongo.count()
        log.info('%s total: %s' % (corpus_mongo, NumUtil.comma_str(total_docs)))

        # Create the parent directory of every output file. This must use os.path.dirname, not
        # basename: basename would create a directory named after the file itself.
        # A bare file name has no directory part ('') and needs nothing.
        output_files = (sentences_file, characters_file, info_file, urls_file,
                        train_sentences_file, valid_sentences_file, test_sentences_file)
        for output_dir in {os.path.dirname(path) for path in output_files}:
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

        def write_lines(f, lines):
            """Write each line of `lines` plus a newline to `f`; return the number of lines written."""
            for line in lines:
                f.write(line)
                f.write('\n')
            return len(lines)

        with gzip.open(sentences_file, 'wt') as out_f, \
                gzip.open(train_sentences_file, 'wt') as train_f, \
                gzip.open(valid_sentences_file, 'wt') as valid_f, \
                gzip.open(test_sentences_file, 'wt') as test_f, \
                open(info_file, 'wt') as info_f, \
                open(urls_file, 'wt') as urls_f:

            char_set = set()
            n_docs = n_total = n_train = n_valid = n_test = 0
            if limit:
                cursor = corpus_mongo.find(mongo_query, limit=limit)
            else:
                cursor = corpus_mongo.find(mongo_query)

            for i, row in enumerate(cursor, 1):
                if i % 1000 == 0:
                    # NOTE(review): total_docs comes from count(), which is called without
                    # mongo_query/limit, so this percentage is approximate. Guard against 0.
                    percent = i / total_docs * 100 if total_docs else 0.0
                    log.info('%s %.1f%% writed.' % (os.path.basename(sentences_file), percent))

                sentences = []
                for c in row['content']:
                    sentences.extend(HangulUtil.text2sentences(c['sentences'], remove_only_one_word=True, has_hangul=True))

                log.debug('url: %s, len: %s', row['url'], len(sentences))
                if not sentences:
                    continue

                urls_f.write(row['url'])
                urls_f.write('\n')
                n_docs += 1

                for s in sentences:
                    char_set.update(s)
                n_total += write_lines(out_f, sentences)

                if len(sentences) >= 10:  # can split
                    test_len = valid_len = len(sentences) // 10
                    n_test += write_lines(test_f, sentences[:test_len])
                    n_valid += write_lines(valid_f, sentences[test_len:test_len + valid_len])
                    n_train += write_lines(train_f, sentences[test_len + valid_len:])
                else:  # can't split
                    n_train += write_lines(train_f, sentences)

            char_list = sorted(char_set)
            log.info('writed to %s...' % characters_file)
            with open(characters_file, 'w') as f:
                write_lines(f, char_list)
            log.info('writed to %s OK.' % characters_file)

            log.info('total docs: %s', NumUtil.comma_str(total_docs))
            log.info('total docs: %s (has hangul sentence)', NumUtil.comma_str(n_docs))
            log.info('total sentences: %s (has hangul sentence)', NumUtil.comma_str(n_total))
            log.info('train: %s', NumUtil.comma_str(n_train))
            log.info('valid: %s', NumUtil.comma_str(n_valid))
            log.info('test: %s', NumUtil.comma_str(n_test))
            log.info('total characters: %s', NumUtil.comma_str(len(char_list)))

            info_f.write('total docs: %s\n' % NumUtil.comma_str(total_docs))
            info_f.write('total docs: %s (has hangul sentence)\n' % NumUtil.comma_str(n_docs))
            info_f.write('total sentences: %s (has hangul sentence)\n' % NumUtil.comma_str(n_total))
            info_f.write('train: %s\n' % NumUtil.comma_str(n_train))
            info_f.write('valid: %s\n' % NumUtil.comma_str(n_valid))
            info_f.write('test: %s\n' % NumUtil.comma_str(n_test))
            info_f.write('total characters: %s\n' % NumUtil.comma_str(len(char_list)))
Example #3
0
                                watch = WatchUtil()
                                watch.start()

                                _test_cost, _W1, _b1, _x_batch, _y_batch, _y_hat_batch = sess.run(
                                    [cost, W1, b1, x, y, y_hat],
                                    feed_dict={
                                        learning_rate: _learning_rate,
                                        use_first_pipeline: True
                                    })

                                log.info('')
                                log.info('W1: %s' % ['%.4f' % i for i in _W1])
                                log.info('b1: %.4f' % _b1)
                                for (x1, x2), _y, _y_hat in zip(
                                        _x_batch, _y_batch, _y_hat_batch):
                                    log.debug(
                                        '%3d + %3d = %4d (y_hat: %4.1f)' %
                                        (x1, x2, _y, _y_hat))
                                log.info('')
                                log.info(
                                    '"%s" test: test_cost: %.8f, %.2f secs (batch_size: %s)'
                                    % (model_name, _test_cost, watch.elapsed(),
                                       batch_size))
                                log.info('')
                            except:
                                log.info(traceback.format_exc())
                            finally:
                                coordinator.request_stop()
                                coordinator.join(
                                    threads)  # Wait for threads to finish.