def learning(cls, total_epoch, n_train, n_valid, n_test, batch_size, window_size, noise_rate, model_file, features_vector, labels_vector,
             n_hidden1, learning_rate, dropout_keep_rate, early_stop_cost=0.001):
    n_features = len(features_vector) * window_size  # number of features = 17,382 * 10

    log.info('load characters list...')
    log.info('load characters list OK. len: %s' % NumUtil.comma_str(len(features_vector)))

    watch = WatchUtil()

    train_file = os.path.join(KO_WIKIPEDIA_ORG_DIR, 'datasets', 'spelling_error_correction',
                              'ko.wikipedia.org.dataset.sentences=%s.window_size=%d.train.gz' % (n_train, window_size))
    valid_file = os.path.join(KO_WIKIPEDIA_ORG_DIR, 'datasets', 'spelling_error_correction',
                              'ko.wikipedia.org.dataset.sentences=%s.window_size=%d.valid.gz' % (n_valid, window_size))
    test_file = os.path.join(KO_WIKIPEDIA_ORG_DIR, 'datasets', 'spelling_error_correction',
                             'ko.wikipedia.org.dataset.sentences=%s.window_size=%d.test.gz' % (n_test, window_size))

    log.info('train_file: %s' % train_file)
    log.info('valid_file: %s' % valid_file)
    log.info('test_file: %s' % test_file)

    if not os.path.exists(train_file) or not os.path.exists(valid_file) or not os.path.exists(test_file):
        dataset_dir = os.path.dirname(train_file)
        if not os.path.exists(dataset_dir):
            os.makedirs(dataset_dir)

        watch.start('create dataset')  # FIXME: out of memory (1M sentences)
        log.info('create dataset...')

        data_files = (('train', KO_WIKIPEDIA_ORG_TRAIN_SENTENCES_FILE, n_train, train_file, False),
                      ('valid', KO_WIKIPEDIA_ORG_VALID_SENTENCES_FILE, n_valid, valid_file, False),
                      ('test', KO_WIKIPEDIA_ORG_TEST_SENTENCES_FILE, n_test, test_file, False))

        for (name, data_file, total, dataset_file, to_one_hot_vector) in data_files:
            check_interval = 10000
            log.info('check_interval: %s' % check_interval)
            log.info('%s %s total: %s' % (name, os.path.basename(data_file), NumUtil.comma_str(total)))
            log.info('noise_rate: %s' % noise_rate)

            features, labels = [], []
            with gzip.open(data_file, 'rt') as f:
                for i, line in enumerate(f, 1):
                    if total < i:
                        break

                    if i % check_interval == 0:
                        time.sleep(0.01)  # prevent cpu overload
                        percent = i / total * 100
                        log.info('create dataset... %.1f%% read. data len: %s. %s' % (percent, NumUtil.comma_str(len(features)), data_file))

                    sentence = line.strip()
                    for start in range(0, len(sentence) - window_size + 1):  # generate noise (a blank) at the character level
                        chars = sentence[start: start + window_size]
                        for idx in range(len(chars)):
                            noised_chars = StringUtil.replace_with_index(chars, ' ', idx)
                            features.append(noised_chars)
                            labels.append(chars)
                            log.debug('create dataset... %s "%s" -> "%s"' % (name, noised_chars, chars))

                    # log.info('noise_sampling: %s' % noise_sampling)
                    # for nth_sample in range(noise_sampling):  # generate noise at the jamo (initial/medial/final) level
                    #     for start in range(0, len(sentence) - window_size + 1):
                    #         chars = sentence[start: start + window_size]
                    #         noised_chars = SpellingErrorCorrection.encode_noise(chars, noise_rate=noise_rate, noise_with_blank=True)
                    #         if chars == noised_chars:
                    #             continue
                    #         if i % check_interval == 0 and nth_sample == 0:
                    #             log.info('create dataset... %s "%s" -> "%s"' % (name, noised_chars, chars))
                    #         features.append(noised_chars)
                    #         labels.append(chars)

            # print('dataset features:', features)
            # print('dataset labels:', labels)

            dataset = DataSet(features=features, labels=labels, features_vector=features_vector, labels_vector=labels_vector, name=name)
            log.info('dataset save... %s' % dataset_file)
            dataset.save(dataset_file, gzip_format=True, verbose=True)
            log.info('dataset save OK. %s' % dataset_file)
            log.info('dataset: %s' % dataset)

        log.info('create dataset OK.')
        log.info('')
        watch.stop('create dataset')

    watch.start('dataset load')
    log.info('dataset load...')
    train = DataSet.load(train_file, gzip_format=True, verbose=True)

    if n_train >= 100000:  # 100,000
        valid = DataSet.load(valid_file, gzip_format=True, verbose=True)
    else:  # for small runs, reuse the training data as the validation set
        valid = DataSet.load(train_file, gzip_format=True, verbose=True)

    log.info('valid.convert_to_one_hot_vector()...')
    valid = valid.convert_to_one_hot_vector(verbose=True)
    log.info('valid.convert_to_one_hot_vector() OK.')

    log.info('train dataset: %s' % train)
    log.info('valid dataset: %s' % valid)
    log.info('dataset load OK.')
    log.info('')
    watch.stop('dataset load')

    X, Y, dropout_keep_prob, train_step, cost, y_hat, accuracy = SpellingErrorCorrection.build_DAE(n_features, window_size, noise_rate,
                                                                                                   n_hidden1, learning_rate, watch)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        check_interval = max(1, min(1000, n_train // 10))
        nth_train, nth_input, total_input = 0, 0, total_epoch * train.size

        log.info('')
        log.info('learn...')
        log.info('total_epoch: %s' % total_epoch)
        log.info('train.size (total features): %s' % NumUtil.comma_str(train.size))
        log.info('check_interval: %s' % check_interval)
        log.info('batch_size: %s' % batch_size)
        log.info('total_input: %s (total_epoch * train.size)' % total_input)
        log.info('')

        watch.start('learn')
        valid_cost = sys.float_info.max
        for epoch in range(1, total_epoch + 1):
            if valid_cost < early_stop_cost:
                log.info('valid_cost: %s, early_stop_cost: %s, early stopped.' % (valid_cost, early_stop_cost))
                break
            for step, (features_batch, labels_batch) in enumerate(train.next_batch(batch_size=batch_size, to_one_hot_vector=True), 1):
                if valid_cost < early_stop_cost:
                    break

                nth_train += 1
                nth_input += features_batch.shape[0]
                sess.run(train_step, feed_dict={X: features_batch, Y: labels_batch, dropout_keep_prob: dropout_keep_rate})

                # if nth_train % check_interval == 1:
                percent = nth_input / total_input * 100
                valid_cost = sess.run(cost, feed_dict={X: valid.features, Y: valid.labels, dropout_keep_prob: 1.0})
                log.info('[epoch=%s][%.1f%%] %s cost: %.8f' % (epoch, percent, valid.name, valid_cost))
        watch.stop('learn')
        log.info('learn OK.')
        log.info('')

        log.info('model save... %s' % model_file)
        watch.start('model save...')
        model_dir = os.path.dirname(model_file)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        saver = tf.train.Saver()
        saver.save(sess, model_file)
        watch.stop('model save...')
        log.info('model save OK. %s' % model_file)

        log.info('')
        log.info('total_epoch: %s' % total_epoch)
        log.info('batch_size: %s' % batch_size)
        log.info('total_input: %s (total_epoch * train.size)' % total_input)
        log.info('')
        log.info(watch.summary())
        log.info('')
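
# Illustrative sketch (not part of the pipeline above): learning() builds its dataset by sliding a
# window over each sentence and, for every position in the window, blanking out exactly one character;
# the blanked string is the feature and the original window is the label. The helper below reproduces
# that pairing in isolation. It assumes StringUtil.replace_with_index(s, ' ', idx) simply replaces the
# character at idx with a space, and inlines that assumed behavior so the sketch is self-contained.
def _blank_noise_pairs_sketch(sentence, window_size):
    pairs = []  # (noised_chars, chars) pairs, i.e. (feature, label)
    for start in range(0, len(sentence) - window_size + 1):
        chars = sentence[start: start + window_size]
        for idx in range(len(chars)):
            noised_chars = chars[:idx] + ' ' + chars[idx + 1:]  # blank out one character
            pairs.append((noised_chars, chars))
    return pairs

# Example: _blank_noise_pairs_sketch('대한민국', 2)
# -> [(' 한', '대한'), ('대 ', '대한'), (' 민', '한민'), ('한 ', '한민'), (' 국', '민국'), ('민 ', '민국')]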
def dump_corpus(mongo_url, db_name, collection_name, sentences_file, characters_file, info_file, urls_file,
                train_sentences_file, valid_sentences_file, test_sentences_file, mongo_query=None, limit=None):
    """
    Read documents from MongoDB and write them out sentence by sentence.
    (Sentences that contain only one word, or no Hangul at all, are not extracted.)
    :param characters_file:
    :param urls_file:
    :param info_file:
    :param mongo_url: mongodb://~~~
    :param db_name: database name of mongodb
    :param collection_name: collection name of mongodb
    :param sentences_file: *.sentence file
    :param train_sentences_file:
    :param valid_sentences_file:
    :param test_sentences_file:
    :param mongo_query: default={}
    :param limit:
    :return:
    """
    if mongo_query is None:
        mongo_query = {}

    corpus_mongo = MongodbUtil(mongo_url, db_name=db_name, collection_name=collection_name)
    total_docs = corpus_mongo.count()
    log.info('%s total: %s' % (corpus_mongo, NumUtil.comma_str(total_docs)))

    output_dir = os.path.dirname(sentences_file)  # was os.path.basename(), which returns a file name, not a directory
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with gzip.open(sentences_file, 'wt') as out_f, \
            gzip.open(train_sentences_file, 'wt') as train_f, \
            gzip.open(valid_sentences_file, 'wt') as valid_f, \
            gzip.open(test_sentences_file, 'wt') as test_f, \
            open(info_file, 'wt') as info_f, \
            open(urls_file, 'wt') as urls_f:
        char_set = set()
        n_docs = n_total = n_train = n_valid = n_test = 0

        if limit:
            cursor = corpus_mongo.find(mongo_query, limit=limit)
        else:
            cursor = corpus_mongo.find(mongo_query)

        for i, row in enumerate(cursor, 1):
            if i % 1000 == 0:
                log.info('%s %.1f%% written.' % (os.path.basename(sentences_file), i / total_docs * 100))

            sentences = []
            for c in row['content']:
                sentences.extend(HangulUtil.text2sentences(c['sentences'], remove_only_one_word=True, has_hangul=True))
            # sentences = HangulUtil.text2sentences(row['content'], remove_only_one_word=True, has_hangul=True)

            log.debug('url: %s, len: %s' % (row['url'], len(sentences)))
            if len(sentences) == 0:
                # log.error(row['content'])
                continue

            urls_f.write(row['url'])
            urls_f.write('\n')
            n_docs += 1

            for s in sentences:
                char_set.update(s)  # collect every character that appears

                n_total += 1
                out_f.write(s)
                out_f.write('\n')

            if len(sentences) >= 10:  # can split
                test_len = valid_len = len(sentences) // 10
                # log.info('train: %s, test: %s, valid: %s' % (len(sentences) - test_len - valid_len, test_len, valid_len))
                for s in sentences[:test_len]:
                    n_test += 1
                    test_f.write(s)
                    test_f.write('\n')
                for s in sentences[test_len:test_len + valid_len]:
                    n_valid += 1
                    valid_f.write(s)
                    valid_f.write('\n')
                for s in sentences[test_len + valid_len:]:
                    n_train += 1
                    train_f.write(s)
                    train_f.write('\n')
            else:  # can't split
                for s in sentences:
                    n_train += 1
                    train_f.write(s)
                    train_f.write('\n')

        char_list = sorted(char_set)
        log.info('write to %s...' % characters_file)
        with open(characters_file, 'w') as f:
            for c in char_list:
                f.write(c)
                f.write('\n')
        log.info('write to %s OK.' % characters_file)

        log.info('total docs: %s', NumUtil.comma_str(total_docs))
        log.info('total docs: %s (has hangul sentence)', NumUtil.comma_str(n_docs))
        log.info('total sentences: %s (has hangul sentence)', NumUtil.comma_str(n_total))
        log.info('train: %s', NumUtil.comma_str(n_train))
        log.info('valid: %s', NumUtil.comma_str(n_valid))
        log.info('test: %s', NumUtil.comma_str(n_test))
        log.info('total characters: %s', NumUtil.comma_str(len(char_list)))

        info_f.write('total docs: %s\n' % NumUtil.comma_str(total_docs))
        info_f.write('total docs: %s (has hangul sentence)\n' % NumUtil.comma_str(n_docs))
        info_f.write('total sentences: %s (has hangul sentence)\n' % NumUtil.comma_str(n_total))
        info_f.write('train: %s\n' % NumUtil.comma_str(n_train))
        info_f.write('valid: %s\n' % NumUtil.comma_str(n_valid))
        info_f.write('test: %s\n' % NumUtil.comma_str(n_test))
        info_f.write('total characters: %s\n' % NumUtil.comma_str(len(char_list)))
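
# Illustrative sketch (not part of dump_corpus above): the per-document split used when writing the
# train/valid/test files. Documents with at least 10 sentences contribute their first 10% to test,
# the next 10% to valid, and the remaining ~80% to train; shorter documents go entirely to train,
# mirroring the branches above.
def _split_sentences_sketch(sentences):
    if len(sentences) >= 10:  # can split
        test_len = valid_len = len(sentences) // 10
        test = sentences[:test_len]
        valid = sentences[test_len:test_len + valid_len]
        train = sentences[test_len + valid_len:]
    else:  # can't split
        test, valid, train = [], [], list(sentences)
    return train, valid, test

# Example: for 10 sentences s0..s9, this returns train=[s2..s9], valid=[s1], test=[s0].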
    # (this excerpt continues a try: block opened earlier; the enclosing session and input-pipeline
    # setup are outside this excerpt)
    watch = WatchUtil()
    watch.start()
    _test_cost, _W1, _b1, _x_batch, _y_batch, _y_hat_batch = sess.run(
        [cost, W1, b1, x, y, y_hat],
        feed_dict={learning_rate: _learning_rate, use_first_pipeline: True})

    log.info('')
    log.info('W1: %s' % ['%.4f' % i for i in _W1])
    log.info('b1: %.4f' % _b1)
    for (x1, x2), _y, _y_hat in zip(_x_batch, _y_batch, _y_hat_batch):
        log.debug('%3d + %3d = %4d (y_hat: %4.1f)' % (x1, x2, _y, _y_hat))

    log.info('')
    log.info('"%s" test: test_cost: %.8f, %.2f secs (batch_size: %s)' % (model_name, _test_cost, watch.elapsed(), batch_size))
    log.info('')
except:
    log.info(traceback.format_exc())
finally:
    coordinator.request_stop()
    coordinator.join(threads)  # Wait for threads to finish.
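
# Minimal sketch of the setup that the try/except/finally teardown above assumes (a TF 1.x
# queue-runner input pipeline). The names sess, coordinator and threads match the fragment;
# the step_fn callback and this helper itself are illustrative, not part of the original code.
# tf, traceback and log are assumed to be imported/configured at module level, as elsewhere in this file.
def _run_with_queue_runners_sketch(sess, step_fn):
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
    try:
        while not coordinator.should_stop():
            step_fn(sess)  # one training/evaluation step, e.g. the sess.run() calls above
    except Exception:
        log.info(traceback.format_exc())
    finally:
        coordinator.request_stop()
        coordinator.join(threads)  # wait for the queue-runner threads to finish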