Example #1
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

from bage_utils.base_util import is_server
from nlp4kor_tensorflow.config import MNIST_DATA_DIR, MNIST_DAE_MODEL_DIR, log

if __name__ == '__main__':
    mnist_data = os.path.join(MNIST_DATA_DIR, MNIST_DAE_MODEL_DIR)  # input
    device2use = '/gpu:0' if is_server() else '/cpu:0'

    model_file = os.path.join(MNIST_DAE_MODEL_DIR,
                              'dae_mnist_model/model')  # .%s' % max_sentences
    log.info('model_file: %s' % model_file)

    model_dir = os.path.dirname(model_file)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    image_shape = (28, 28)
    mnist = input_data.read_data_sets(mnist_data, one_hot=True)
    assert (mnist.train.images.shape[1] == mnist.test.images.shape[1])
    n_input_dim = mnist.train.images.shape[1]  # MNIST data input (img shape: 28*28)
    n_output_dim = n_input_dim  # MNIST data output (same shape as the input)
    n_hidden_1 = 256  # 1st layer num features
    n_hidden_2 = 256  # 2nd layer num features
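
The snippet above ends right after the layer sizes are set. Below is a minimal sketch of how a denoising-autoencoder graph could continue from `device2use`, `n_input_dim`, `n_hidden_1`, `n_hidden_2`, and `n_output_dim`, assuming TF 1.x graph mode (as the `tensorflow.examples.tutorials.mnist` import implies); this is not the repository's actual model code, and all names introduced here are hypothetical.

with tf.device(device2use):
    # placeholders for the clean image and its noise-corrupted copy (hypothetical names)
    x = tf.placeholder(tf.float32, [None, n_input_dim])
    x_noised = tf.placeholder(tf.float32, [None, n_input_dim])

    # encoder: input -> hidden_1 -> hidden_2
    h1 = tf.nn.sigmoid(tf.layers.dense(x_noised, n_hidden_1))
    h2 = tf.nn.sigmoid(tf.layers.dense(h1, n_hidden_2))

    # decoder: hidden_2 -> reconstruction of the clean image
    x_hat = tf.nn.sigmoid(tf.layers.dense(h2, n_output_dim))

    cost = tf.reduce_mean(tf.square(x - x_hat))  # reconstruction loss
    train_step = tf.train.AdamOptimizer(0.01).minimize(cost)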
Example #2
import logging
import os
import sys
import warnings

from bage_utils.base_util import is_server, db_hostname, is_pycharm_remote
from bage_utils.log_util import LogUtil

warnings.simplefilter(action='ignore', category=FutureWarning)  # ignore future warnings

log = None
if log is None:
    if is_server():  # by shell script
        log = LogUtil.get_logger(sys.argv[0], level=logging.INFO, console_mode=True, multiprocess=False)  # global log
    elif is_pycharm_remote():  # remote gpu server
        log = LogUtil.get_logger(sys.argv[0], level=logging.DEBUG, console_mode=True, multiprocess=False)  # global log # console_mode=True for jupyter
    else:  # my macbook
        log = LogUtil.get_logger(None, level=logging.DEBUG, console_mode=True)  # global log

#################################################
# DB
#################################################
MONGO_URL = r'mongodb://%s:%s@%s:%s/%s?authMechanism=MONGODB-CR' % ('root', os.getenv('MONGODB_PASSWD'), 'db-local', '27017', 'admin')
MYSQL_URL = {'host': db_hostname(), 'user': '******', 'passwd': os.getenv('MYSQL_PASSWD'), 'db': 'kr_nlp'}

#################################################
# tensorboard log dir
#################################################
TENSORBOARD_LOG_DIR = os.path.join(os.getenv("HOME"), 'tensorboard_log')
# log.info('TENSORBOARD_LOG_DIR: %s' % TENSORBOARD_LOG_DIR)
if not os.path.exists(TENSORBOARD_LOG_DIR):
    os.makedirs(TENSORBOARD_LOG_DIR)
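
Assuming this file is the `nlp4kor_tensorflow.config` module that Example #1 imports from, downstream scripts only need to import the module-level names it defines; a minimal usage sketch:

# Usage sketch: import the pre-configured logger and path constants,
# mirroring the import in Example #1.
from nlp4kor_tensorflow.config import log, TENSORBOARD_LOG_DIR

log.info('TENSORBOARD_LOG_DIR: %s' % TENSORBOARD_LOG_DIR)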
Example #3
        corpus = Word2VecCorpus.load(filepath=args.corpus_file)
        log.info(
            f'load {args.corpus_file} OK. (elapsed: {watch.elapsed_string()})')
        log.info(corpus.vocab)

        if len(corpus.vocab) > 1e5:  # out of memory (11GB GPU memory)
            args.device_no = None

        log.info('')
        log.info(args)
        log.info('')

        embedding_file = Word2VecEmbedding.get_filenpath(args)
        if os.path.exists(embedding_file):
            log.info(f'embedding_file: {embedding_file} exists. skipped')
            if is_server():
                SlackUtil.send_message(
                    f'embedding_file: {embedding_file} exists. skipped')
                exit()

        log.info('')

        log.info(f'Word2VecTrainer() ...')
        watch.start()
        trainer = Word2VecTrainer(vocab=corpus.vocab,
                                  corpus=corpus,
                                  batch=args.batch,
                                  device_no=args.device_no,
                                  embed=args.embed,
                                  neg_sample=args.neg_sample,
                                  neg_weight=args.neg_weight,
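
The truncated `Word2VecTrainer(...)` call above passes `neg_sample` and `neg_weight`. As a general illustration of what such parameters usually control in word2vec-style training (this is not code from the repository), negatives are typically drawn from the unigram distribution raised to a smoothing exponent:

# General word2vec negative-sampling illustration (not from this repository):
# negatives are drawn from the unigram distribution raised to an exponent
# (commonly 0.75), which is what a neg_weight-style parameter typically sets.
import numpy as np

word_counts = np.array([100, 50, 10, 1], dtype=np.float64)  # toy vocabulary counts
neg_weight = 0.75                                            # smoothing exponent
probs = word_counts ** neg_weight
probs /= probs.sum()

neg_sample = 5  # negatives drawn per positive (word, context) pair
negatives = np.random.choice(len(word_counts), size=neg_sample, p=probs)
print(negatives)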
Example #4
    log.info('')

    characters_file = KO_WIKIPEDIA_ORG_CHARACTERS_FILE
    log.info('characters_file: %s' % characters_file)
    try:
        if len(sys.argv) == 4:
            n_train = int(sys.argv[1])
            window_size = int(sys.argv[2])
            noise_rate = float(sys.argv[3])
        else:
            n_train, noise_rate, window_size = None, None, None

        if n_train is None or n_train == 0:
            n_train = int('1,000,000'.replace(',', ''))

        if is_server():  # batch server
            n_train = 10000
            n_valid = min(10, n_train)
            n_test = min(10, n_train)
        else:  # for demo
            n_train = n_valid = n_test = 3

        if noise_rate is None or window_size is None:
            window_size = 10  # 2 ~ 10 # number of characters extracted as features (same length as the label)
            noise_rate = max(0.1, 1 / window_size)  # 0.0 ~ 1.0 # noise_rate = noise characters / total characters (guarantees at least one noisy character per window)

        dropout_keep_rate = 1.0  # 0.0 ~ 1.0 # with one-hot vectors, training fails when dropout is applied.
        noise_sampling = 100  # number of noise samples generated per input; not used when noise is generated blank-style (per character).
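
A minimal sketch of how blank-style (per-character) noise could be applied with a `noise_rate` computed as above; the helper below is hypothetical and not part of the repository:

# Sketch (hypothetical helper): replace roughly noise_rate of the characters
# in a window with a blank, always corrupting at least one character.
import random

def add_blank_noise(chars, noise_rate, blank=' '):
    n_noise = max(1, int(len(chars) * noise_rate))  # at least one noisy character
    noisy = list(chars)
    for i in random.sample(range(len(chars)), n_noise):
        noisy[i] = blank
    return ''.join(noisy)

window_size = 10
noise_rate = max(0.1, 1 / window_size)
print(add_blank_noise('대한민국의 수도는 서울', noise_rate))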