コード例 #1
0
 def get_model(self):
     if self.model_type == 'fcn':
         self.input_length = 29 * 16000
         return Model.FCN()
     elif self.model_type == 'musicnn':
         self.input_length = 3 * 16000
         return Model.Musicnn(dataset=self.dataset)
     elif self.model_type == 'crnn':
         self.input_length = 29 * 16000
         return Model.CRNN()
     elif self.model_type == 'sample':
         self.input_length = 59049
         return Model.SampleCNN()
     elif self.model_type == 'se':
         self.input_length = 59049
         return Model.SampleCNNSE()
     elif self.model_type == 'short':
         self.input_length = 59049
         return Model.ShortChunkCNN()
     elif self.model_type == 'short_res':
         self.input_length = 59049
         return Model.ShortChunkCNN_Res()
     elif self.model_type == 'attention':
         self.input_length = 15 * 16000
         return Model.CNNSA()
     elif self.model_type == 'hcnn':
         self.input_length = 5 * 16000
         return Model.HarmonicCNN()
     else:
         print(
             'model_type has to be one of [fcn, musicnn, crnn, sample, se, short, short_res, attention]'
         )
コード例 #2
0
 def get_model(self):
     if self.model_type == 'fcn':
         return Model.FCN()
     elif self.model_type == 'musicnn':
         return Model.Musicnn(dataset=self.dataset)
     elif self.model_type == 'crnn':
         return Model.CRNN()
     elif self.model_type == 'sample':
         return Model.SampleCNN()
     elif self.model_type == 'se':
         return Model.SampleCNNSE()
     elif self.model_type == 'short':
         return Model.ShortChunkCNN()
     elif self.model_type == 'short_res':
         return Model.ShortChunkCNN_Res()
     elif self.model_type == 'attention':
         return Model.CNNSA()
     elif self.model_type == 'hcnn':
         return Model.HarmonicCNN()
コード例 #3
0
def main():

    # mode argument
    args = argparse.ArgumentParser()
    args.add_argument(
        "--letter",
        type=str,
        default=
        " ,.()\'\"?!01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ가각간갇갈갉갊감갑값갓갔강갖갗같갚갛개객갠갤갬갭갯갰갱갸갹갼걀걋걍걔걘걜거걱건걷걸걺검겁것겄겅겆겉겊겋게겐겔겜겝겟겠겡겨격겪견겯결겸겹겻겼경곁계곈곌곕곗고곡곤곧골곪곬곯곰곱곳공곶과곽관괄괆괌괍괏광괘괜괠괩괬괭괴괵괸괼굄굅굇굉교굔굘굡굣구국군굳굴굵굶굻굼굽굿궁궂궈궉권궐궜궝궤궷귀귁귄귈귐귑귓규균귤그극근귿글긁금급긋긍긔기긱긴긷길긺김깁깃깅깆깊까깍깎깐깔깖깜깝깟깠깡깥깨깩깬깰깸깹깻깼깽꺄꺅꺌꺼꺽꺾껀껄껌껍껏껐껑께껙껜껨껫껭껴껸껼꼇꼈꼍꼐꼬꼭꼰꼲꼴꼼꼽꼿꽁꽂꽃꽈꽉꽐꽜꽝꽤꽥꽹꾀꾄꾈꾐꾑꾕꾜꾸꾹꾼꿀꿇꿈꿉꿋꿍꿎꿔꿜꿨꿩꿰꿱꿴꿸뀀뀁뀄뀌뀐뀔뀜뀝뀨끄끅끈끊끌끎끓끔끕끗끙끝끼끽낀낄낌낍낏낑나낙낚난낟날낡낢남납낫났낭낮낯낱낳내낵낸낼냄냅냇냈냉냐냑냔냘냠냥너넉넋넌널넒넓넘넙넛넜넝넣네넥넨넬넴넵넷넸넹녀녁년녈념녑녔녕녘녜녠노녹논놀놂놈놉놋농높놓놔놘놜놨뇌뇐뇔뇜뇝뇟뇨뇩뇬뇰뇹뇻뇽누눅눈눋눌눔눕눗눙눠눴눼뉘뉜뉠뉨뉩뉴뉵뉼늄늅늉느늑는늘늙늚늠늡늣능늦늪늬늰늴니닉닌닐닒님닙닛닝닢다닥닦단닫달닭닮닯닳담답닷닸당닺닻닿대댁댄댈댐댑댓댔댕댜더덕덖던덛덜덞덟덤덥덧덩덫덮데덱덴델뎀뎁뎃뎄뎅뎌뎐뎔뎠뎡뎨뎬도독돈돋돌돎돐돔돕돗동돛돝돠돤돨돼됐되된될됨됩됫됴두둑둔둘둠둡둣둥둬뒀뒈뒝뒤뒨뒬뒵뒷뒹듀듄듈듐듕드득든듣들듦듬듭듯등듸디딕딘딛딜딤딥딧딨딩딪따딱딴딸땀땁땃땄땅땋때땍땐땔땜땝땟땠땡떠떡떤떨떪떫떰떱떳떴떵떻떼떽뗀뗄뗌뗍뗏뗐뗑뗘뗬또똑똔똘똥똬똴뙈뙤뙨뚜뚝뚠뚤뚫뚬뚱뛔뛰뛴뛸뜀뜁뜅뜨뜩뜬뜯뜰뜸뜹뜻띄띈띌띔띕띠띤띨띰띱띳띵라락란랄람랍랏랐랑랒랖랗래랙랜랠램랩랫랬랭랴략랸럇량러럭런럴럼럽럿렀렁렇레렉렌렐렘렙렛렝려력련렬렴렵렷렸령례롄롑롓로록론롤롬롭롯롱롸롼뢍뢨뢰뢴뢸룀룁룃룅료룐룔룝룟룡루룩룬룰룸룹룻룽뤄뤘뤠뤼뤽륀륄륌륏륑류륙륜률륨륩륫륭르륵른를름릅릇릉릊릍릎리릭린릴림립릿링마막만많맏말맑맒맘맙맛망맞맡맣매맥맨맬맴맵맷맸맹맺먀먁먈먕머먹먼멀멂멈멉멋멍멎멓메멕멘멜멤멥멧멨멩며멱면멸몃몄명몇몌모목몫몬몰몲몸몹못몽뫄뫈뫘뫙뫼묀묄묍묏묑묘묜묠묩묫무묵묶문묻물묽묾뭄뭅뭇뭉뭍뭏뭐뭔뭘뭡뭣뭬뮈뮌뮐뮤뮨뮬뮴뮷므믄믈믐믓미믹민믿밀밂밈밉밋밌밍및밑바박밖밗반받발밝밞밟밤밥밧방밭배백밴밸뱀뱁뱃뱄뱅뱉뱌뱍뱐뱝버벅번벋벌벎범법벗벙벚베벡벤벧벨벰벱벳벴벵벼벽변별볍볏볐병볕볘볜보복볶본볼봄봅봇봉봐봔봤봬뵀뵈뵉뵌뵐뵘뵙뵤뵨부북분붇불붉붊붐붑붓붕붙붚붜붤붰붸뷔뷕뷘뷜뷩뷰뷴뷸븀븃븅브븍븐블븜븝븟비빅빈빌빎빔빕빗빙빚빛빠빡빤빨빪빰빱빳빴빵빻빼빽뺀뺄뺌뺍뺏뺐뺑뺘뺙뺨뻐뻑뻔뻗뻘뻠뻣뻤뻥뻬뼁뼈뼉뼘뼙뼛뼜뼝뽀뽁뽄뽈뽐뽑뽕뾔뾰뿅뿌뿍뿐뿔뿜뿟뿡쀼쁑쁘쁜쁠쁨쁩삐삑삔삘삠삡삣삥사삭삯산삳살삵삶삼삽삿샀상샅새색샌샐샘샙샛샜생샤샥샨샬샴샵샷샹섀섄섈섐섕서석섞섟선섣설섦섧섬섭섯섰성섶세섹센셀셈셉셋셌셍셔셕션셜셤셥셧셨셩셰셴셸솅소속솎손솔솖솜솝솟송솥솨솩솬솰솽쇄쇈쇌쇔쇗쇘쇠쇤쇨쇰쇱쇳쇼쇽숀숄숌숍숏숑수숙순숟술숨숩숫숭숯숱숲숴쉈쉐쉑쉔쉘쉠쉥쉬쉭쉰쉴쉼쉽쉿슁슈슉슐슘슛슝스슥슨슬슭슴습슷승시식신싣실싫심십싯싱싶싸싹싻싼쌀쌈쌉쌌쌍쌓쌔쌕쌘쌜쌤쌥쌨쌩썅써썩썬썰썲썸썹썼썽쎄쎈쎌쏀쏘쏙쏜쏟쏠쏢쏨쏩쏭쏴쏵쏸쐈쐐쐤쐬쐰쐴쐼쐽쑈쑤쑥쑨쑬쑴쑵쑹쒀쒔쒜쒸쒼쓩쓰쓱쓴쓸쓺쓿씀씁씌씐씔씜씨씩씬씰씸씹씻씽아악안앉않알앍앎앓암압앗았앙앝앞애액앤앨앰앱앳앴앵야약얀얄얇얌얍얏양얕얗얘얜얠얩어억언얹얻얼얽얾엄업없엇었엉엊엌엎에엑엔엘엠엡엣엥여역엮연열엶엷염엽엾엿였영옅옆옇예옌옐옘옙옛옜오옥온올옭옮옰옳옴옵옷옹옻와왁완왈왐왑왓왔왕왜왝왠왬왯왱외왹왼욀욈욉욋욍요욕욘욜욤욥욧용우욱운울욹욺움웁웃웅워웍원월웜웝웠웡웨웩웬웰웸웹웽위윅윈윌윔윕윗윙유육윤율윰윱윳융윷으윽은을읊음읍읏응읒읓읔읕읖읗의읩읜읠읨읫이익인일읽읾잃임입잇있잉잊잎자작잔잖잗잘잚잠잡잣잤장잦재잭잰잴잼잽잿쟀쟁쟈쟉쟌쟎쟐쟘쟝쟤쟨쟬저적전절젊점접젓정젖제젝젠젤젬젭젯젱져젼졀졈졉졌졍졔조족존졸졺좀좁좃종좆좇좋좌좍좔좝좟좡좨좼좽죄죈죌죔죕죗죙죠죡죤죵주죽준줄줅줆줌줍줏중줘줬줴쥐쥑쥔쥘쥠쥡쥣쥬쥰쥴쥼즈즉즌즐즘즙즛증지직진짇질짊짐집짓징짖짙짚짜짝짠짢짤짧짬짭짯짰짱째짹짼쨀쨈쨉쨋쨌쨍쨔쨘쨩쩌쩍쩐쩔쩜쩝쩟쩠쩡쩨쩽쪄쪘쪼쪽쫀쫄쫌쫍쫏쫑쫓쫘쫙쫠쫬쫴쬈쬐쬔쬘쬠쬡쭁쭈쭉쭌쭐쭘쭙쭝쭤쭸쭹쮜쮸쯔쯤쯧쯩찌찍찐찔찜찝찡찢찧차착찬찮찰참찹찻찼창찾채책챈챌챔챕챗챘챙챠챤챦챨챰챵처척천철첨첩첫첬청체첵첸첼쳄쳅쳇쳉쳐쳔쳤쳬쳰촁초촉촌촐촘촙촛총촤촨촬촹최쵠쵤쵬쵭쵯쵱쵸춈추축춘출춤춥춧충춰췄췌췐취췬췰췸췹췻췽츄츈츌츔츙츠측츤츨츰츱츳층치칙친칟칠칡침칩칫칭카칵칸칼캄캅캇캉캐캑캔캘캠캡캣캤캥캬캭컁커컥컨컫컬컴컵컷컸컹케켁켄켈켐켑켓켕켜켠켤켬켭켯켰켱켸코콕콘콜콤콥콧콩콰콱콴콸쾀쾅쾌쾡쾨쾰쿄쿠쿡쿤쿨쿰쿱쿳쿵쿼퀀퀄퀑퀘퀭퀴퀵퀸퀼큄큅큇큉큐큔큘큠크큭큰클큼큽킁키킥킨킬킴킵킷킹타탁탄탈탉탐탑탓탔탕태택탠탤탬탭탯탰탱탸턍터턱턴털턺텀텁텃텄텅테텍텐텔템텝텟텡텨텬텼톄톈토톡톤톨톰톱톳통톺톼퇀퇘퇴퇸툇툉툐투툭툰툴툼툽툿퉁퉈퉜퉤튀튁튄튈튐튑튕튜튠튤튬튱트특튼튿틀틂틈틉틋틔틘틜틤틥티틱틴틸팀팁팃팅파팍팎판팔팖팜팝팟팠팡팥패팩팬팰팸팹팻팼팽퍄퍅퍼퍽펀펄펌펍펏펐펑페펙펜펠펨펩펫펭펴편펼폄폅폈평폐폘폡폣포폭폰폴폼폽폿퐁퐈퐝푀푄표푠푤푭푯푸푹푼푿풀풂품풉풋풍풔풩퓌퓐퓔퓜퓟퓨퓬퓰퓸퓻퓽프픈플픔픕픗피픽핀필핌핍핏핑하학한할핥함합핫항해핵핸핼햄햅햇했행햐향허헉헌헐헒험헙헛헝헤헥헨헬헴헵헷헹혀혁현혈혐협혓혔형혜혠혤혭호혹혼홀홅홈홉홋홍홑화확환활홧황홰홱홴횃횅회획횐횔횝횟횡효횬횰횹횻후훅훈훌훑훔훗훙훠훤훨훰훵훼훽휀휄휑휘휙휜휠휨휩휫휭휴휵휸휼흄흇흉흐흑흔흖흗흘흙흠흡흣흥흩희흰흴흼흽힁히힉힌힐힘힙힛힝"
    )
    args.add_argument("--lr", type=float, default=0.0001)
    args.add_argument("--cuda", type=bool, default=True)
    args.add_argument("--num_epochs", type=int, default=50000)
    args.add_argument("--model_name", type=str, default="55_10")
    args.add_argument("--batch", type=int, default=2)
    args.add_argument("--mode", type=str, default="test")
    args.add_argument("--prediction_dir", type=str, default="prediction")
    args.add_argument("--print_iter", type=int, default=10)

    config = args.parse_args()

    letter = config.letter
    lr = config.lr
    cuda = config.cuda
    num_epochs = config.num_epochs
    model_name = config.model_name
    batch = config.batch
    mode = config.mode
    prediction_dir = config.prediction_dir
    print_iter = config.print_iter
    imgH = 32
    imgW = 200
    nclass = len(letter) + 1
    nc = 1

    new_model = model.CRNN(imgH, nc, nclass, 256)
    new_model.apply(model.weights_init)
    device = torch.device('cuda') if cuda else torch.device('cpu')

    converter = dataloader.strLabelConverter(letter)

    images = torch.FloatTensor(batch, 1, imgH, imgW)
    texts = torch.IntTensor(batch * 1000)
    lengths = torch.IntTensor(batch)

    images = Variable(images)
    texts = Variable(texts)
    lengths = Variable(lengths)

    #check parameter of model
    print("------------------------------------------------------------")
    total_params = sum(p.numel() for p in new_model.parameters())
    print("num of parameter : ", total_params)
    trainable_params = sum(p.numel() for p in new_model.parameters()
                           if p.requires_grad)
    print("num of trainable_ parameter :", trainable_params)
    print("------------------------------------------------------------")

    if mode == 'train':
        print('trian start')
        train_loader = data_loader(DATASET_PATH,
                                   batch,
                                   imgH,
                                   imgW,
                                   phase='train')
        val_loader = data_loader(DATASET_PATH, batch, imgH, imgW, phase='val')
        params = [p for p in new_model.parameters() if p.requires_grad]
        optimizer = optim.Adam(params, lr=lr, betas=(0.5, 0.999))
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                 step_size=40,
                                                 gamma=0.1)
        train(num_epochs, new_model, device, train_loader, val_loader, images,
              texts, lengths, converter, optimizer, lr_scheduler,
              prediction_dir, print_iter)

    elif mode == 'test':
        print('test start')
        test_loader = data_loader(DATASET_PATH, 1, imgH, imgW, phase='test')
        load_model(model_name, new_model)
        test(new_model, device, test_loader, images, texts, lengths, converter,
             prediction_dir)
コード例 #4
0
import model

if __name__ == '__main__':
    crnn = model.CRNN()
    # crnn.train("train/")
    crnn.infer("val/")
コード例 #5
0
ファイル: engine.py プロジェクト: akshit61/CRNN_OCR
from tqdm import tqdm
import torch
import config
from torch import nn
import utils
import model
from torch.nn import functional as f
from torch.autograd import Variable

criterion = nn.CTCLoss(blank=0, zero_infinity=True)
converter = utils.strLabelConverter(config.ALPHABETS)
crnn = model.CRNN(config.IMG_HEIGHT, nc=3)

image = torch.FloatTensor(config.BATCH_SIZE, 3, config.IMG_HEIGHT,
                          config.IMG_WIDTH)
text = torch.LongTensor(config.BATCH_SIZE * 5)
length = torch.LongTensor(config.BATCH_SIZE)

if config.DEVICE == 'cuda' and torch.cuda.is_available():
    criterion = criterion.cuda()
    image = image.cuda()
    text = text.cuda()

image = Variable(image)
text = Variable(text)
length = Variable(length)


def train_fn(model, data_loader, optimizer):
    model.train()
    tk = tqdm(data_loader, total=len(data_loader))
コード例 #6
0
ファイル: train.py プロジェクト: zzmcdc/document-ocr
def main():
  train_tf_record = os.path.join(FLAGS.data_dir, 'ocr-train-*.tfrecord')
  eval_tf_record = os.path.join(FLAGS.data_dir, 'ocr-validation-*.tfrecord')

  char_map_dict = load_char_map()
  train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
  model_name = 'crnn_ctc_ocr_{:s}.ckpt'.format(str(train_start_time))
  model_save_path = os.path.join(FLAGS.model_dir, model_name)

  config = Config()
  config.batch_size = FLAGS.batch_size
  config.num_classes = len(char_map_dict) + 1
  train_input_fn = input_fn.input_fn(train_tf_record, FLAGS.batch_size, channel_size=FLAGS.channel_size)

  crnn_model = model.CRNN(config)
  saver = tf.train.Saver()
  if not os.path.exists(FLAGS.model_dir):
    os.makedirs(FLAGS.model_dir)

  global_step = tf.train.get_or_create_global_step()
  learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                             global_step,
                                             FLAGS.decay_steps,
                                             FLAGS.decay_rate,
                                             staircase = True)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op= tf.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate).minimize(crnn_model.loss, 
            global_step=global_step)
    train_op = tf.group([train_op, update_ops])
  decoded, log_prob = tf.nn.ctc_greedy_decoder(crnn_model.logits, crnn_model.sequence_length)
  pred_str_labels = tf.as_string(decoded[0].values)
  pred_tensor = tf.SparseTensor(indices=decoded[0].indices, values=pred_str_labels, dense_shape=decoded[0].dense_shape)
  true_str_labels = tf.as_string(crnn_model.labels.values)
  true_tensor = tf.SparseTensor(indices=crnn_model.labels.indices, values=true_str_labels, dense_shape=crnn_model.labels.dense_shape)
  edit_distance = tf.reduce_mean(tf.edit_distance(pred_tensor, true_tensor, normalize=True), name='distance')
  tf.summary.scalar(name='edit_distance', tensor= edit_distance)
  tf.summary.scalar(name='ctc_loss', tensor=crnn_model.loss)
  #tf.summary.scalar(name='learning_rate', tensor=learning_rate)
  merge_summary_op = tf.summary.merge_all()
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    summary_writer.add_graph(sess.graph)
    train_next_batch = train_input_fn.get_next()

    save_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    if save_path:
      saver.restore(sess=sess, save_path=save_path)
      print("restore from %s"%(save_path) )
      st = int(save_path.split("-")[-1])
      sess.run(global_step.assign(st))

    for s in range(FLAGS.max_train_steps):
      batch = sess.run(train_next_batch)
      images = batch['images']
      labels = batch['labels']
      sequence_length = batch['sequence_length']
      _, loss , lr,  summary, step, logits, dis = sess.run(
          [train_op, crnn_model.loss, learning_rate, merge_summary_op, global_step , crnn_model.logits , edit_distance ],
          feed_dict = {
            crnn_model.images:images, 
            crnn_model.labels:labels, 
            crnn_model.sequence_length:sequence_length, 
            crnn_model.keep_prob:0.5, 
            crnn_model.is_training:True})

      print("step: {step} lr: {lr} loss: {loss} acc: {dis} ".format(step=step, lr=lr, loss=loss, dis=(1-dis) ))
      if step % FLAGS.step_per_save == 0:
        summary_writer.add_summary(summary=summary, global_step=step)
        saver.save(sess=sess, save_path=model_save_path, global_step=step)

      if False and step % FLAGS.step_per_eval == 0:
        eval_input_fn = input_fn.input_fn(eval_tf_record, FLAGS.batch_size, False, channel_size=FLAGS.channel_size )
        eval_next_batch = eval_input_fn.get_next()
        all_distance =  []
        while True:
          try:
            eval_batch = sess.run(eval_next_batch)
            images = batch['images']
            labels = batch['labels']
            sequence_length = batch['sequence_length']
            train_distance = sess.run([edit_distance], 
                    feed_dict={
                      crnn_model.images:images, 
                      crnn_model.labels:labels, 
                      crnn_model.keep_prob:1.0, 
                      crnn_model.is_training:True, 
                      crnn_model.sequence_length: sequence_length})
            all_distance.append(train_distance[0])
          except tf.errors.OutOfRangeError as e:
            print("eval acc: ", 1 - np.mean(np.array(all_distance)))
            break
コード例 #7
0
    image = image.to(device)
    # 将图片数据输入模型
    output = model(image)
    output_log_softmax = F.log_softmax(output, dim=-1)
    # print('predict result:{}\n'.format(decode(output_log_softmax)))
    # 对结果进行解码
    pred_labels = ctc_greedy_decoder(output_log_softmax)
    pred_texts = ''.join(Idx2Word(pred_labels))
    print('predict result: {}\n'.format(pred_texts))


if __name__ == '__main__':
    transform = transforms.ToTensor()

    # 模型定义
    model = model.CRNN(num_classes)
    model = model.to(device)
    print(model)

    # 读取参数
    if os.path.exists('checkpoint.pth.tar'):
        checkpoint = torch.load('checkpoint.pth.tar', map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        print('model has restored')

    # 依次推断文件夹中每张图片
    files = sorted(os.listdir(IMG_ROOT))
    # for循环依次读取文件夹中的每张图片
    for file in files:
        image_path = os.path.join(IMG_ROOT, file)
        # 将图片读取进来
コード例 #8
0
ファイル: eval.py プロジェクト: zzmcdc/document-ocr
def eval():
    tf.reset_default_graph()
    char_map_dict = load_char_map()
    config = Config()
    config.num_classes = len(char_map_dict) + 1
    id_to_char = {v: k for k, v in char_map_dict.items()}
    crnn_net = model.CRNN(config)
    with open(FLAGS.image_list, 'r') as fd:
        image_names = []
        true_labels = []
        for i, line in enumerate(fd):
            seg = " "
            line = line.strip().split(seg)
            image_names.append(line[0])
            true_labels.append(seg.join(line[1:]))

    index_list = random.choices(list(range(len(image_names))), k=50)
    image_names = [image_names[i] for i in index_list]
    labels = [true_labels[i] for i in index_list]

    saver = tf.train.Saver()
    save_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    with tf.Session() as sess:
        saver.restore(sess=sess, save_path=save_path)
        print("restored from %s" % (save_path))
        decoded, log_prob = tf.nn.ctc_greedy_decoder(crnn_net.logits,
                                                     crnn_net.sequence_length,
                                                     merge_repeated=True)
        if FLAGS.export:
            tensor_image_input_info = tf.saved_model.utils.build_tensor_info(
                crnn_net.images)
            tensor_seq_len_input_info = tf.saved_model.utils.build_tensor_info(
                crnn_net.sequence_length)
            tensor_is_traing_info = tf.saved_model.utils.build_tensor_info(
                crnn_net.is_training)
            tensor_keep_prob = tf.saved_model.utils.build_tensor_info(
                crnn_net.keep_prob)
            output_info = tf.saved_model.utils.build_tensor_info(decoded[0])
            signature = tf.saved_model.signature_def_utils.build_signature_def(
                inputs={
                    'images': tensor_image_input_info,
                    'sequence_length': tensor_seq_len_input_info,
                    "is_training": tensor_is_traing_info,
                    "keep_prob": tensor_keep_prob
                },
                outputs={'decoded': output_info})

            ex_dir = str(int(time.time()))
            builder = tf.saved_model.builder.SavedModelBuilder(
                "./all_exported_models/%s/" % (ex_dir, ))
            builder.add_meta_graph_and_variables(
                sess=sess,
                tags=[tag_constants.SERVING],
                signature_def_map={"predict": signature})
            builder.save()
            print("exported model at ")
        ignore = 0
        error_count = 0
        total_count = 0
        for i, image_name in enumerate(image_names):
            image_path = os.path.join(FLAGS.image_dir, image_name)
            if FLAGS.channel_size == 3:
                image = cv2.imread(image_path)
            else:
                image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                print('ignore')
                ignore += 1
                continue
            h, w = image.shape[:2]
            height = _IMAGE_HEIGHT
            width = int(w * height / h)
            image = cv2.resize(image, (width, height))
            image = np.array(image, dtype=np.float32)
            image = image / 255.0
            seq_len = np.array([width / 4], dtype=np.int32)
            print("length: ", seq_len)
            if FLAGS.channel_size == 1:
                image = image[:, :, np.newaxis]
                cv2.imwrite("test.png", image * 255.0)
            image = np.expand_dims(image, axis=0)
            start = time.time()
            print(image.shape)
            logit, preds, prob = sess.run(
                [crnn_net.logits, decoded, log_prob],
                feed_dict={
                    crnn_net.images: image,
                    crnn_net.sequence_length: seq_len,
                    crnn_net.keep_prob: 1.0,
                    crnn_net.is_training: False
                })
            preds = _sparse_matrix_to_list(preds[0], id_to_char)
            cost_time = time.time() - start
            res_text = preds[0]
            res_text = merge_text(res_text)
            err_count = Levenshtein.distance(labels[i], res_text)
            total_count += len(labels[i])
            error_count += err_count
            print(image_name)
            print(
                'true label {:s} \n predict result: {:s} cost:{:f} \n error_count:{:d}'
                .format(
                    labels[i],
                    preds[0],
                    cost_time,
                    err_count,
                ))
        print(1 - 1.0 * error_count / total_count)
コード例 #9
0
ファイル: train.py プロジェクト: oujunke/keras_crnn
parser.add_argument('--epochs', type=int, default=100,
                    help='upper epoch limit')

args = parser.parse_args()

# data pre-processing
dataset = Dataset(args)
dataset.data_preprocess()
dataset.rescale()
dataset.generate_key()
dataset.random_get_val()

train = dataset.generate(args.json_path, args.save_path, args.key_path, args.batch_size, args.max_label_length, (args.image_height, args.image_width))
val = dataset.generate(args.json_val_path, args.save_path, args.key_path, args.batch_size, args.max_label_length, (args.image_height, args.image_width))

crnn = model.CRNN(args)
y_pred = crnn.model()
loss = crnn.get_loss(y_pred)
inputs = crnn.inputs
labels = crnn.labels
input_length = crnn.input_length
label_length = crnn.label_length
model = Model(inputs=[inputs, labels, input_length, label_length], outputs=loss)
adam = Adam()
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=adam,metrics=['accuracy'])
checkpoint = ModelCheckpoint(args.model_path + (r'weights-{epoch:02d}.hdf5'),
                           save_weights_only=True)
earlystop = EarlyStopping(patience=10)
tensorboard = TensorBoard(args.model_path + '/tflog',write_graph=True)

res = model.fit_generator(train,