Example #1
def main():
	# img_width, img_height = 48, 48
	img_width, img_height = 200, 60
	img_channels = 1 
	# batch_size = 1024
	batch_size = 32
	nb_epoch = 1000
	post_correction = False

	save_dir = 'save_model/' + str(datetime.now()).split('.')[0].split()[0] + '/'  # checkpoints are saved under a directory named after the current date
	train_data_dir = 'train_data/ip_train/'
	# train_data_dir = 'train_data/single_1000000/'
	val_data_dir = 'train_data/ip_val/'
	test_data_dir = 'test_data/'
	weights_file_path = 'save_model/2016-10-27/weights.11-1.58.hdf5'
	char_set, char2idx = get_char_set(train_data_dir)
	nb_classes = len(char_set)
	max_nb_char = get_maxnb_char(train_data_dir)
	label_set = get_label_set(train_data_dir)
	# print('char_set:', char_set)
	print('nb_classes:', nb_classes)
	print('max_nb_char:', max_nb_char)
	print('size_label_set:', len(label_set))
	model = build_shallow(img_channels, img_width, img_height, max_nb_char, nb_classes) # build CNN architecture
	# model.load_weights(weights_file_path) # load trained model

	val_data = load_data(val_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx)
	# val_data = None 
	train_data = load_data(train_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx) 
	train(model, batch_size, nb_epoch, save_dir, train_data, val_data, char_set)
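The train helper called above is not shown on this page. For orientation only, here is a minimal sketch of a compatible implementation, assuming load_data returns a single (X, y) pair; the checkpoint filename pattern is inferred from the weights.11-1.58.hdf5 path above, and the Keras 2 style fit/epochs arguments are assumptions rather than the original code.

import os
from keras.callbacks import ModelCheckpoint

def train(model, batch_size, nb_epoch, save_dir, train_data, val_data, char_set):
    # Hypothetical sketch, not the original train(): unpack the (X, y) pairs from load_data.
    # char_set is kept only for signature compatibility with the call above.
    X_train, y_train = train_data
    X_val, y_val = val_data
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Filename pattern inferred from 'weights.11-1.58.hdf5' above
    checkpoint = ModelCheckpoint(save_dir + 'weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                                 monitor='val_loss', save_best_only=True)
    model.fit(X_train, y_train,
              batch_size=batch_size,
              epochs=nb_epoch,  # nb_epoch=... in the Keras 1 API
              validation_data=(X_val, y_val),
              callbacks=[checkpoint])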
Example #2
def main():
    # img_width, img_height = 48, 48
    img_width, img_height = 200, 60
    img_channels = 1
    # batch_size = 1024
    batch_size = 32
    nb_epoch = 1000
    post_correction = False

    save_dir = 'save_model/' + str(datetime.now()).split('.')[0].split()[0] + '/'  # checkpoints are saved under a directory named after the current date
    train_data_dir = 'train_data/ip_train/'
    # train_data_dir = 'train_data/single_1000000/'
    val_data_dir = 'train_data/ip_val/'
    test_data_dir = 'test_data/'
    weights_file_path = 'save_model/2016-10-27/weights.11-1.58.hdf5'
    char_set, char2idx = get_char_set(train_data_dir)
    nb_classes = len(char_set)
    max_nb_char = get_maxnb_char(train_data_dir)
    label_set = get_label_set(train_data_dir)
    # print('char_set:', char_set)
    print('nb_classes:', nb_classes)
    print('max_nb_char:', max_nb_char)
    print('size_label_set:', len(label_set))
    model = build_shallow(img_channels, img_width, img_height, max_nb_char,
                          nb_classes)  # build CNN architecture
    # model.load_weights(weights_file_path) # load trained model

    val_data = load_data(val_data_dir, max_nb_char, img_width, img_height,
                         img_channels, char_set, char2idx)
    # val_data = None
    train_data = load_data(train_data_dir, max_nb_char, img_width, img_height,
                           img_channels, char_set, char2idx)
    train(model, batch_size, nb_epoch, save_dir, train_data, val_data,
          char_set)
Example #3
def main():
    img_width, img_height = 64, 64
    img_channels = 1
    batch_size = 32
    nb_epoch = 4
    post_correction = False

    save_dir = 'models/models/' + str(datetime.now()).split('.')[0].split()[0] + '/'  # checkpoints are saved under a directory named after the current date
    train_data_dir = './data/train/'
    val_data_dir = './data/val/'
    char_set, char2idx = get_char_set('./data/')
    # char_set = ['empty', ..., '鸵', '豸', '山', ...]; char2idx = {..., '弗': 3290, '毓': 6488, ...}
    nb_classes = len(char_set)
    print('nb_classes:', nb_classes)  #607

    print("===========Building Model:===============")
    model = build_shallow(img_channels, img_width, img_height,
                          nb_classes)  # build CNN architecture

    print("=========Begin Loading Val Data:=============\n")
    val_data = load_data(val_data_dir, img_width, img_height, img_channels,
                         char_set)

    print("=========Begin Loading Train Data:=============\n")
    train_data = load_data(train_data_dir, img_width, img_height, img_channels,
                           char_set)

    print("===========Begin Training=============:\n")
    train(model, batch_size, nb_epoch, save_dir, train_data, val_data,
          char_set)
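build_shallow lives in models/shallow.py and is not reproduced here. As a rough reference, a minimal Keras sketch with the same signature as the call in this example; the layer sizes, the channels-last input ordering, and the compile settings are assumptions, not the actual architecture.

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

def build_shallow(img_channels, img_width, img_height, nb_classes):
    # Hypothetical shallow CNN for single-character classification
    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='relu',
                     input_shape=(img_height, img_width, img_channels)))  # channels-last assumed
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model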
Example #4
# -*- coding: utf-8 -*-

import os
import time
import numpy as np
from datetime import datetime
from keras.callbacks import ModelCheckpoint
from util import one_hot_decoder, plot_loss_figure, load_data, get_char_set, load_img
from util import list2str
from post_correction import correction
from models.shallow import build_shallow
import pdb


char_set, char2idx = get_char_set('./data/')
nb_classes = len(char_set)
print('nb_classes:', nb_classes)  # 607


def pred(model, X, char_set, post_correction):
	pred_res = model.predict(X)
	pred_res = [one_hot_decoder(i, char_set) for i in pred_res]
	pred_res = [list2str(i) for i in pred_res]
	# post correction
	if post_correction:
		pred_res = correction(pred_res, char_set)
	return pred_res

def infer(path):
	img_width, img_height = 64, 64
	img_channels = 1
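The snippet is truncated at this point. A hypothetical continuation of infer, shown only to illustrate how the helpers above fit together; the weights path and the exact load_img signature are assumptions.

	# --- hypothetical continuation, not part of the original snippet ---
	model = build_shallow(img_channels, img_width, img_height, nb_classes)
	model.load_weights('models/models/.../weights.hdf5')  # placeholder path to trained weights
	X = load_img(path, img_width, img_height, img_channels)  # assumed to return a batch-shaped array
	return pred(model, X, char_set, post_correction=False)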
Example #5
def __init__(self):
    self.char_set = get_char_set(self.train_data_dir)[0]
    self.nb_classes = len(self.char_set)
    self.max_nb_char = get_maxnb_char(self.train_data_dir)
    self.label_set = get_label_set(self.train_data_dir)
    self.pred_probs = None
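Note that this __init__ reads self.train_data_dir without assigning it, so it only works if the directory is provided elsewhere, e.g. as a class attribute or by a subclass. A hypothetical sketch of such an enclosing class (the class name, attribute value, and import location are assumptions):

from util import get_char_set, get_maxnb_char, get_label_set  # assumed module for these helpers

class Recognizer(object):  # hypothetical class name
    train_data_dir = 'train_data/ip_train/'  # assumed value; must be set before __init__ runs

    def __init__(self):
        self.char_set = get_char_set(self.train_data_dir)[0]
        self.nb_classes = len(self.char_set)
        self.max_nb_char = get_maxnb_char(self.train_data_dir)
        self.label_set = get_label_set(self.train_data_dir)
        self.pred_probs = None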