예제 #1
0
    # Persist the vocabulary as JSON so it can be reloaded later
    # (`vocab` is built earlier, outside the visible part of this function).
    with open('vocab.json', 'w') as fp:
        json.dump(vocab, fp)

    # Record the two sizes used to construct the model below:
    # line 1 = vocab_size, line 2 = max_length (no trailing newline).
    with open('config.txt', 'w') as f:
        f.write(str(vocab_size) + '\n')
        f.write(str(max_length))

    # Build the TensorFlow session and the model.
    # allow_growth makes TF claim GPU memory on demand rather than
    # reserving the whole device up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    model = CNN(sess=sess,
                vocab_size=vocab_size,
                sequence_length=max_length,
                trainable=True)
    # Load the embedding prepared earlier into the model
    # (presumably pretrained vectors -- confirm against the caller).
    model.embedding_assign(embedding)
    # batch_iter is defined elsewhere; each item it yields is unpacked below
    # as a sequence of (x, y) pairs.
    batches = batch_iter(list(zip(x_input, y_input)),
                         batch_size=64,
                         num_epochs=5)
    # Keep the 3 most recent checkpoints, plus one long-lived checkpoint for
    # every 0.5 hours of training. NOTE(review): `saver` is not used within
    # the visible lines -- the save call is presumably further down.
    saver = tf.train.Saver(max_to_keep=3, keep_checkpoint_every_n_hours=0.5)

    # train model
    print('모델 훈련을 시작합니다.')
    # Collects the loss of every step since training began.
    avgLoss = []
    for step, batch in enumerate(batches):
        # Split the batch of (x, y) pairs into parallel tuples.
        x_train, y_train = zip(*batch)
        # Convert the raw inputs using vocab and max_length (helper defined
        # elsewhere; presumably maps morphemes to padded index sequences --
        # TODO confirm).
        x_train = sentence_to_index_morphs(x_train, vocab, max_length)
        # model.train returns two values; the first is treated as the loss,
        # the second is discarded.
        l, _ = model.train(x_train, y_train)
        avgLoss.append(l)
        # Report every 500 steps (including step 0). The value printed is the
        # running mean over all steps so far, not just the last 500.
        if step % 500 == 0:
            print('batch:', '%04d' % step, 'loss:', '%05f' % np.mean(avgLoss))
예제 #2
0
    # save configuration
    # Line 1 = vocab_size, line 2 = max_length (no trailing newline); these
    # are the same two values passed to the CNN constructor below.
    with open('config.txt', 'w') as f:
        f.write(str(vocab_size) + '\n')
        f.write(str(max_length))

    # open session
    # allow_growth makes TF claim GPU memory on demand rather than
    # reserving the whole device up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # make model instance
    model = CNN(sess=sess, vocab_size=vocab_size, sequence_length=max_length, trainable=True)

    # assign pretrained embedding vectors
    # (`embedding` is prepared earlier, outside the visible part of this function)
    model.embedding_assign(embedding)

    # make train batches
    # batch_iter is defined elsewhere; each item it yields is unpacked below
    # as a sequence of (x, y) pairs.
    batches = batch_iter(list(zip(x_input, y_input)), batch_size=64, num_epochs=5)

    # model saver
    # Keeps the 3 most recent checkpoints, plus one long-lived checkpoint for
    # every 0.5 hours of training. NOTE(review): `saver` is not used within
    # the visible lines -- the save call is presumably further down.
    saver = tf.train.Saver(max_to_keep=3, keep_checkpoint_every_n_hours=0.5)

    # train model
    print('모델 훈련을 시작합니다.')
    # Collects the loss of every step since training began.
    avgLoss = []
    for step, batch in enumerate(batches):
        # Split the batch of (x, y) pairs into parallel tuples.
        x_train, y_train = zip(*batch)
        # Convert the raw inputs using vocab and max_length (helper defined
        # elsewhere; presumably maps morphemes to padded index sequences --
        # TODO confirm).
        x_train = sentence_to_index_morphs(x_train, vocab, max_length)
        # model.train returns two values; the first is treated as the loss,
        # the second is discarded.
        l, _ = model.train(x_train, y_train)
        avgLoss.append(l)