Code example #1
def main(opts):
  # alphabet = '0123456789.'
  nclass = len(alphabet) + 1
  model_name = 'E2E-CRNN'
  net = OwnModel(attention=True, nclass=nclass)
  print("Using {0}".format(model_name))

  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)
  optimizer = optim.Adam(net.parameters(), lr=opts.base_lr, betas=(0.5, 0.999))
  step_start = 0

  ### Option 1: only replace the conv11 layer's dimensions
  # model_dict = net.state_dict()
  # if os.path.exists(opts.model):
  #     print('loading pretrained model from %s' % opts.model)
  #     pretrained_model = OwnModel(attention=True, nclass=12)
  #     pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
  #     pretrained_dict = pretrained_model.state_dict()
  #
  #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'rnn' not in k and 'conv11' not in k}
  #     model_dict.update(pretrained_dict)
  #     net.load_state_dict(model_dict)

  if os.path.exists(opts.model):
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)

  ## ICDAR2015 dataset
  e2edata = E2Edataset(train_list=opts.train_list)
  e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=opts.batch_size, shuffle=True, collate_fn=E2Ecollate, num_workers=4)

  net.train()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  for step in range(step_start, opts.max_iters):

    for index, data in enumerate(e2edataloader):
      im_data, gtso, lbso = data
      im_data = im_data.cuda()

      try:
        loss = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)

        net.zero_grad()
        # optimizer.zero_grad()
        loss.backward()
        optimizer.step()
      except:
        import sys, traceback
        traceback.print_exc(file=sys.stdout)
        pass


      if index % disp_interval == 0:
        try:
          print('epoch:%d || step:%d || loss %.4f' % (step, index, loss))
        except:
          import sys, traceback
          traceback.print_exc(file=sys.stdout)
          pass
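The loop above unpacks each batch as (im_data, gtso, lbso), which is the shape that collate_fn=E2Ecollate has to produce. E2Ecollate itself is not shown in these examples, so the following is only a minimal sketch of such a collate function, assuming every dataset item is an (image_tensor, boxes, labels) tuple with equally sized images; the real E2Edataset/E2Ecollate may behave differently.

import torch

def E2Ecollate_sketch(batch):
    # batch: list of (image_tensor, boxes, labels) tuples produced by the dataset.
    # Images are stacked into one tensor; boxes and labels stay as plain lists,
    # since each sample can contain a different number of text regions.
    images = torch.stack([item[0] for item in batch], dim=0)
    gt_boxes = [item[1] for item in batch]
    labels = [item[2] for item in batch]
    return images, gt_boxes, labels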
Code example #2
def main(opts):
  alphabet = '0123456789.'
  nclass = len(alphabet) + 1
  model_name = 'crnn'
  net = CRNN(nclass)
  print("Using {0}".format(model_name))

  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)

  if os.path.exists(opts.model):
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)

  ## Dataset
  converter = strLabelConverter(alphabet)
  dataset = ImgDataset(
      root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
      csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
      transform=None,
      target_transform=converter.encode
  )
  ocrdataloader = torch.utils.data.DataLoader(
      dataset, batch_size=1, shuffle=False, collate_fn=own_collate
  )

  num_count = 0
  net = net.eval()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  for step in range(len(dataset)):

    try:
      data = next(data_iter)
    except:
      data_iter = iter(ocrdataloader)
      data = next(data_iter)

    im_data, gt_boxes, text = data
    im_data = im_data.cuda()

    try:
      res = process_crnn(im_data, gt_boxes, text, net, ctc_loss, converter, training=False)

      pred, target = res
      if pred == target[0]:
        num_count += 1
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      pass


    print('correct/total:%d/%d'%(num_count, len(dataset)))
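Both the training and evaluation snippets rely on strLabelConverter to map between strings and the integer sequences CTC works with, which is also why nclass = len(alphabet) + 1 (one extra class for the CTC blank). The converter used by the project is not shown here; below is a minimal sketch of such a class, under the assumption that index 0 is the blank, with a hypothetical name.

import torch

class StrLabelConverterSketch(object):
    """Hypothetical string <-> CTC label converter (index 0 = blank)."""

    def __init__(self, alphabet):
        self.alphabet = alphabet
        # Characters are numbered from 1; 0 is reserved for the CTC blank.
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}

    def encode(self, texts):
        # texts: list of strings -> flat label tensor plus per-string lengths.
        lengths = [len(t) for t in texts]
        labels = [self.dict[ch] for t in texts for ch in t]
        return torch.IntTensor(labels), torch.IntTensor(lengths)

    def decode(self, indices):
        # Greedy CTC decoding: collapse repeats, then drop blanks (index 0).
        chars, prev = [], 0
        for idx in indices:
            idx = int(idx)
            if idx != 0 and idx != prev:
                chars.append(self.alphabet[idx - 1])
            prev = idx
        return ''.join(chars)

With alphabet = '0123456789.', encode(['3.14']) would give labels [4, 11, 2, 5] and lengths [4].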
Code example #3
def main(opts):

  nclass = len(alphabet) + 1
  model_name = 'E2E-MLT'
  net = OwnModel(attention=True, nclass=nclass)
  print("Using {0}".format(model_name))
  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)

  ### Option 1: only replace the conv11 layer's dimensions
  # model_dict = net.state_dict()
  # if os.path.exists(opts.model):
  #     # load the pretrained model
  #     print('loading pretrained model from %s' % opts.model)
  #     # pretrained_model = OwnModel(attention=True, nclass=7325)
  #     pretrained_model = ModelResNetSep2(attention=True, nclass=7500)
  #     pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
  #     pretrained_dict = pretrained_model.state_dict()
  #
  #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'conv11' not in k and 'rnn' not in k}
  #     # 2. overwrite entries in the existing state dict
  #     model_dict.update(pretrained_dict)
  #     # 3. load the new state dict
  #     net.load_state_dict(model_dict)

  ### Option 2: resume training directly from the previous checkpoint
  if os.path.exists(opts.model):
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
  ### 
  
  step_start = 0
  net.train()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  e2edata = E2Edataset(train_list=opts.train_list)
  e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=4, shuffle=True, collate_fn=E2Ecollate)
  
  train_loss = 0
  bbox_loss, seg_loss, angle_loss = 0., 0., 0.
  cnt = 0
  ctc_loss_val = 0
  ctc_loss_val2 = 0
  box_loss_val = 0
  gt_g_target = 0
  gt_g_proc = 0
  
  
  for step in range(step_start, opts.max_iters):

    loss = 0

    # batch
    images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(data_generator)
    im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=opts.cuda)
    # im_data = torch.from_numpy(images).type(torch.FloatTensor).permute(0, 3, 1, 2).cuda()           # the order of permute(0, 3, 1, 2) and .cuda() matters
    start = timeit.default_timer()
    try:
      seg_pred, roi_pred, angle_pred, features = net(im_data)
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      continue
    end = timeit.default_timer()
    
    # for EAST loss
    smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
    training_mask_var = net_utils.np_to_variable(training_masks, is_cuda=opts.cuda)
    angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4], is_cuda=opts.cuda)
    geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]], is_cuda=opts.cuda)
    
    try:
      loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred, angle_gt, roi_pred, geo_gt)
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      continue
      
    bbox_loss += net.box_loss_value.data.cpu().numpy() 
    seg_loss += net.segm_loss_value.data.cpu().numpy()
    angle_loss += net.angle_loss_value.data.cpu().numpy()  
    train_loss += loss.data.cpu().numpy()
    
       
    try:
      # before step 10000, train only on the annotated text regions
      if step > 10000 or True: # this is just an extra augmentation step ... in the early stage it only slows down training
        # ctcl, gt_target , gt_proc = process_boxes(images, im_data, seg_pred[0], roi_pred[0], angle_pred[0], score_maps, gt_idxs, gtso, lbso, features, net, ctc_loss, opts, converter, debug=opts.debug)
        ctcl = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)
        gt_target = 1
        gt_proc = 1

        ctc_loss_val += ctcl.data.cpu().numpy()[0]
        loss = ctcl
        gt_g_target = gt_target
        gt_g_proc = gt_proc
        train_loss += ctcl.item()
      
      # - when training the OCR recognition part, use a separate data generator
      # imageso, labels, label_length = next(dg_ocr)              # this should include rectification of slanted text
      # im_data_ocr = net_utils.np_to_variable(imageso, is_cuda=opts.cuda).permute(0, 3, 1, 2)
      # features = net.forward_features(im_data_ocr)
      # labels_pred = net.forward_ocr(features)
      # probs_sizes =  torch.IntTensor( [(labels_pred.permute(2,0,1).size()[0])] * (labels_pred.permute(2,0,1).size()[1]) )
      # label_sizes = torch.IntTensor( torch.from_numpy(np.array(label_length)).int() )
      # labels = torch.IntTensor( torch.from_numpy(np.array(labels)).int() )
      # loss_ocr = ctc_loss(labels_pred.permute(2,0,1), labels, probs_sizes, label_sizes) / im_data_ocr.size(0) * 0.5
      # loss_ocr.backward()
      # ctc_loss_val2 += loss_ocr.item()

      net.zero_grad()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      pass


    cnt += 1
    if step % disp_interval == 0:
      if opts.debug:
        
        segm = seg_pred[0].data.cpu()[0].numpy()
        segm = segm.squeeze(0)
        cv2.imshow('segm_map', segm)
        
        segm_res = cv2.resize(score_maps[0], (images.shape[2], images.shape[1]))
        mask = np.argwhere(segm_res > 0)
        
        x_data = im_data.data.cpu().numpy()[0]
        x_data = x_data.swapaxes(0, 2)
        x_data = x_data.swapaxes(0, 1)
        
        x_data += 1
        x_data *= 128
        x_data = np.asarray(x_data, dtype=np.uint8)
        x_data = x_data[:, :, ::-1]
        
        im_show = x_data
        try:
          im_show[mask[:, 0], mask[:, 1], 1] = 255 
          im_show[mask[:, 0], mask[:, 1], 0] = 0 
          im_show[mask[:, 0], mask[:, 1], 2] = 0
        except:
          pass
        
        cv2.imshow('img0', im_show) 
        cv2.imshow('score_maps', score_maps[0] * 255)
        cv2.imshow('train_mask', training_masks[0] * 255)
        cv2.waitKey(10)
      
      train_loss /= cnt
      bbox_loss /= cnt
      seg_loss /= cnt
      angle_loss /= cnt
      ctc_loss_val /= cnt
      ctc_loss_val2 /= cnt
      box_loss_val /= cnt
      try:
        print('epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, gt_t/gt_proc:[%d/%d] lv2 %.3f' % (
          step / batch_per_epoch, step, train_loss, bbox_loss, seg_loss, angle_loss, ctc_loss_val, gt_g_target, gt_g_proc , ctc_loss_val2))
      except:
        import sys, traceback
        traceback.print_exc(file=sys.stdout)
        pass
    
      train_loss = 0
      bbox_loss, seg_loss, angle_loss = 0., 0., 0.
      cnt = 0
      ctc_loss_val = 0
      good_all = 0
      gt_all = 0
      box_loss_val = 0
      
    # for save mode
    #  validate(opts.valid_list, net)
    if step > step_start and (step % batch_per_epoch == 0):
      save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
      state = {'step': step,
               'learning_rate': learning_rate,
              'state_dict': net.state_dict(),
              'optimizer': optimizer.state_dict()}
      torch.save(state, save_name)
      print('save model: {}'.format(save_name))
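The checkpoint written above stores 'step', 'learning_rate', 'state_dict' and 'optimizer'. net_utils.load_net, called near the top of each example, is not shown; a minimal sketch consistent with that saved dict could look like the following (the project's actual helper may differ, e.g. in device mapping or partial loading).

import torch

def load_net_sketch(model_path, net, optimizer=None):
    # Hypothetical counterpart to the torch.save(state, save_name) call above.
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    # Return what main() expects: the step to resume from and the saved learning rate.
    return checkpoint['step'], checkpoint['learning_rate']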
Code example #4
def main(opts):
  # alphabet = '0123456789.'
  nclass = len(alphabet) + 1
  model_name = 'E2E-CRNN'
  net = OwnModel(attention=True, nclass=nclass)
  print("Using {0}".format(model_name))

  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)
  optimizer = optim.Adam(net.parameters(), lr=opts.base_lr, betas=(0.5, 0.999))
  step_start = 0

  ### Option 1: only replace the conv11 layer's dimensions
  # model_dict = net.state_dict()
  # if os.path.exists(opts.model):
  #     print('loading pretrained model from %s' % opts.model)
  #     pretrained_model = OwnModel(attention=True, nclass=12)
  #     pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
  #     pretrained_dict = pretrained_model.state_dict()

  #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'rnn' not in k and 'conv11' not in k}
  #     model_dict.update(pretrained_dict)
  #     net.load_state_dict(model_dict)

  if os.path.exists(opts.model):
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)

  ## Dataset
  e2edata = E2Edataset(train_list=opts.train_list)
  e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=opts.batch_size, shuffle=False, collate_fn=E2Ecollate, num_workers=4)
  
  # Electricity-meter dataset
  # converter = strLabelConverter(alphabet)
  # dataset = ImgDataset(
  #     root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
  #     csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
  #     transform=None,
  #     target_transform=converter.encode
  # )
  # ocrdataloader = torch.utils.data.DataLoader(
  #     dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
  # )
  
  net.eval()
  num_count = 0

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  for index, data in enumerate(e2edataloader):
    im_data, gtso, lbso = data
    im_data = im_data.cuda()

    try:
      with torch.no_grad():
        res = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=False)

      pred, target = res
      target = ''.join(target)
      if pred == target:
        num_count += 1
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      pass
Code example #5
def main(opts):

  nclass = len(alphabet) + 1
  model_name = 'E2E-MLT'
  net = OwnModel(attention=True, nclass=nclass)
  print("Using {0}".format(model_name))
  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)

  ### Option 1: only replace the conv11 layer's dimensions
  # model_dict = net.state_dict()
  # if os.path.exists(opts.model):
  #     # load the pretrained model
  #     print('loading pretrained model from %s' % opts.model)
  #     # pretrained_model = OwnModel(attention=True, nclass=7325)
  #     pretrained_model = ModelResNetSep2(attention=True, nclass=7500)
  #     pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
  #     pretrained_dict = pretrained_model.state_dict()
  #
  #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'conv11' not in k and 'rnn' not in k}
  #     # 2. overwrite entries in the existing state dict
  #     model_dict.update(pretrained_dict)
  #     # 3. load the new state dict
  #     net.load_state_dict(model_dict)

  ### Option 2: resume training directly from the previous checkpoint
  if os.path.exists(opts.model):
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
  ### 
  
  step_start = 0
  net.train()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  e2edata = E2Edataset(train_list=opts.train_list)
  e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=4, shuffle=True, collate_fn=E2Ecollate)
  
  train_loss = 0
  bbox_loss, seg_loss, angle_loss = 0., 0., 0.
  cnt = 0
  ctc_loss_val = 0
  ctc_loss_val2 = 0
  box_loss_val = 0
  gt_g_target = 0
  gt_g_proc = 0
  
  
  for step in range(step_start, opts.max_iters):

    loss = 0

    # batch
    images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(data_generator)
    im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=opts.cuda)
    # im_data = torch.from_numpy(images).type(torch.FloatTensor).permute(0, 3, 1, 2).cuda()       # the order of permute(0, 3, 1, 2) and .cuda() matters
    start = timeit.default_timer()
    try:
      seg_pred, roi_pred, angle_pred, features = net(im_data)
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      continue
    end = timeit.default_timer()
    
    # for EAST loss
    smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
    training_mask_var = net_utils.np_to_variable(training_masks, is_cuda=opts.cuda)
    angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4], is_cuda=opts.cuda)
    geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]], is_cuda=opts.cuda)
    
    try:
      loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred, angle_gt, roi_pred, geo_gt)
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      continue
      
    bbox_loss += net.box_loss_value.data.cpu().numpy() 
    seg_loss += net.segm_loss_value.data.cpu().numpy()
    angle_loss += net.angle_loss_value.data.cpu().numpy()  
    train_loss += loss.data.cpu().numpy()
    
       
    try:
      # before step 10000, train only on the annotated text regions
      if step > 10000 or True: # this is just an extra augmentation step ... in the early stage it only slows down training
        # ctcl, gt_target , gt_proc = process_boxes(images, im_data, seg_pred[0], roi_pred[0], angle_pred[0], score_maps, gt_idxs, gtso, lbso, features, net, ctc_loss, opts, converter, debug=opts.debug)
        ctcl = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)
        gt_target = 1
        gt_proc = 1

        ctc_loss_val += ctcl.data.cpu().numpy()[0]
        loss = ctcl
        gt_g_target = gt_target
        gt_g_proc = gt_proc
        train_loss += ctcl.item()
      
      # - when training the OCR recognition part, use a separate data generator
      # imageso, labels, label_length = next(dg_ocr)          # this should include rectification of slanted text
      # im_data_ocr = net_utils.np_to_variable(imageso, is_cuda=opts.cuda).permute(0, 3, 1, 2)
      # features = net.forward_features(im_data_ocr)
      # labels_pred = net.forward_ocr(features)
      # probs_sizes =  torch.IntTensor( [(labels_pred.permute(2,0,1).size()[0])] * (labels_pred.permute(2,0,1).size()[1]) )
      # label_sizes = torch.IntTensor( torch.from_numpy(np.array(label_length)).int() )
      # labels = torch.IntTensor( torch.from_numpy(np.array(labels)).int() )
      # loss_ocr = ctc_loss(labels_pred.permute(2,0,1), labels, probs_sizes, label_sizes) / im_data_ocr.size(0) * 0.5
      # loss_ocr.backward()
      # ctc_loss_val2 += loss_ocr.item()

      net.zero_grad()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      pass
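Examples #3 and #5 feed numpy batches to the network through net_utils.np_to_variable. Its implementation is not included in these snippets; the sketch below only captures the behaviour implied by its call sites (numpy array in, float tensor out, optionally on the GPU) and may not match the real helper.

import numpy as np
import torch

def np_to_variable_sketch(x, is_cuda=True, dtype=torch.float32):
    # Hypothetical helper: numpy array -> torch tensor, optionally moved to the GPU.
    v = torch.from_numpy(np.ascontiguousarray(x)).to(dtype)
    if is_cuda:
        v = v.cuda()
    return v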
Code example #6
def main(opts):
    # alphabet = '0123456789.'
    nclass = len(alphabet) + 1
    model_name = 'E2E-CRNN'
    net = OwnModel(attention=True, nclass=nclass)
    print("Using {0}".format(model_name))

    if opts.cuda:
        net.cuda()
    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    optimizer = optim.Adam(net.parameters(),
                           lr=opts.base_lr,
                           betas=(0.5, 0.999))
    step_start = 0

    ### Option 1: only replace the conv11 layer's dimensions
    # model_dict = net.state_dict()
    # if os.path.exists(opts.model):
    #     print('loading pretrained model from %s' % opts.model)
    #     pretrained_model = OwnModel(attention=True, nclass=12)
    #     pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
    #     pretrained_dict = pretrained_model.state_dict()
    #
    #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'rnn' not in k and 'conv11' not in k}
    #     model_dict.update(pretrained_dict)
    #     net.load_state_dict(model_dict)

    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net,
                                                       optimizer)

    ## Dataset
    e2edata = E2Edataset(train_list=opts.train_list)
    e2edataloader = torch.utils.data.DataLoader(e2edata,
                                                batch_size=opts.batch_size,
                                                shuffle=True,
                                                collate_fn=E2Ecollate,
                                                num_workers=4)

    # Electricity-meter dataset
    # converter = strLabelConverter(alphabet)
    # dataset = ImgDataset(
    #     root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
    #     csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
    #     transform=None,
    #     target_transform=converter.encode
    # )
    # ocrdataloader = torch.utils.data.DataLoader(
    #     dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
    # )

    net.train()

    converter = strLabelConverter(alphabet)
    ctc_loss = CTCLoss()

    for step in range(step_start, opts.max_iters):

        for index, data in enumerate(e2edataloader):
            im_data, gtso, lbso = data
            im_data = im_data.cuda()

            try:
                loss = process_crnn(im_data,
                                    gtso,
                                    lbso,
                                    net,
                                    ctc_loss,
                                    converter,
                                    training=True)

                net.zero_grad()
                # optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            except:
                import sys, traceback
                traceback.print_exc(file=sys.stdout)
                pass

            if index % disp_interval == 0:
                try:
                    print('epoch:%d || step:%d || loss %.4f' %
                          (step, index, loss))
                except:
                    import sys, traceback
                    traceback.print_exc(file=sys.stdout)
                    pass

        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))
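All of the main(opts) functions above read the same handful of attributes from opts (cuda, base_lr, model, train_list, batch_size, max_iters, save_path, debug, ...). The original scripts define their own option parser, which is not shown; a hypothetical argparse setup matching those attribute names might look like the sketch below (names taken from the call sites, defaults are guesses).

import argparse

def parse_opts_sketch():
    # Hypothetical parser covering the opts.* attributes used in the examples above.
    parser = argparse.ArgumentParser(description='E2E-CRNN / E2E-MLT training options (sketch)')
    parser.add_argument('--cuda', action='store_true', help='run the model on the GPU')
    parser.add_argument('--base_lr', type=float, default=1e-4, help='initial learning rate')
    parser.add_argument('--model', type=str, default='', help='checkpoint to resume from')
    parser.add_argument('--train_list', type=str, default='', help='training list for E2Edataset')
    parser.add_argument('--valid_list', type=str, default='')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--max_iters', type=int, default=300000)
    parser.add_argument('--save_path', type=str, default='./checkpoints')
    parser.add_argument('--debug', action='store_true', help='show cv2 debug visualizations')
    return parser.parse_args()

# if __name__ == '__main__':
#     main(parse_opts_sketch())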
Code example #7
def main(opts):
  alphabet = '0123456789.'
  nclass = len(alphabet) + 1
  model_name = 'crnn'
  net = CRNN(nclass)
  print("Using {0}".format(model_name))

  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)

  if os.path.exists(opts.model):
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)

  ## Dataset
  converter = strLabelConverter(alphabet)
  dataset = ImgDataset(
      root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
      csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
      transform=None,
      target_transform=converter.encode
  )
  ocrdataloader = torch.utils.data.DataLoader(
      dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
  )
  
  step_start = 0
  net.train()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  for step in range(step_start, opts.max_iters):

    try:
      data = next(data_iter)
    except:
      data_iter = iter(ocrdataloader)
      data = next(data_iter)
    
    im_data, gt_boxes, text = data
    im_data = im_data.cuda()
       
    try:
      loss = process_crnn(im_data, gt_boxes, text, net, ctc_loss, converter, training=True)

      net.zero_grad()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      pass


    if step % disp_interval == 0:
      try:
        print('step:%d || loss %.4f' % (step, loss))
      except:
        import sys, traceback
        traceback.print_exc(file=sys.stdout)
        pass
    
    if step > step_start and (step % batch_per_epoch == 0):
      save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
      state = {'step': step,
               'learning_rate': learning_rate,
               'state_dict': net.state_dict(),
               'optimizer': optimizer.state_dict()}
      torch.save(state, save_name)
      print('save model: {}'.format(save_name))
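Every example delegates the recognition step to process_crnn(im_data, boxes, texts, net, ctc_loss, converter, training=...), which returns a CTC loss in training mode and a (pred, target) pair in evaluation mode. Its implementation is not part of these snippets; the sketch below is only a guess at its general shape, using torch.nn.CTCLoss(blank=0) instead of the project's CTCLoss class and a blank-0 converter like the one sketched after code example #2, and assuming net(im_data) directly yields per-timestep scores of shape (T, N, nclass).

import torch
import torch.nn.functional as F

def process_crnn_sketch(im_data, gt_boxes, texts, net, ctc_loss, converter, training=True):
    # Hypothetical shape of the helper used above. gt_boxes is unused in this
    # simplified sketch; the real helper presumably crops text regions first.
    preds = net(im_data)                                 # (T, N, nclass) logits
    T, N, _ = preds.size()

    if training:
        labels, label_lengths = converter.encode(texts)  # flat targets + lengths
        log_probs = F.log_softmax(preds, dim=2)
        pred_lengths = torch.full((N,), T, dtype=torch.long)
        # torch.nn.CTCLoss(log_probs, targets, input_lengths, target_lengths)
        return ctc_loss(log_probs, labels, pred_lengths, label_lengths)

    # Evaluation: greedy CTC decode of the first sample in the batch.
    _, best = preds.max(2)                               # (T, N) argmax over classes
    pred_text = converter.decode(best[:, 0].cpu().numpy())
    return pred_text, texts

Here ctc_loss is assumed to be torch.nn.CTCLoss(blank=0); the CTCLoss() and strLabelConverter actually imported by these scripts may expose different interfaces.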