예제 #1
0
파일: train.py 프로젝트: LHCyGan/CTC_CRNN
def trainBatch(net, criterion, optimizer):
    """Run one optimisation step on the next training batch.

    Relies on module-level state defined elsewhere in the file: the batch
    iterator ``train_iter``, the pre-allocated buffers ``image``, ``text``
    and ``length``, the label ``converter`` and the ``utils`` helpers.

    Args:
        net: Model to train; it is called on the module-level ``image``
            buffer and has its gradients reset before the backward pass.
        criterion: Loss callable invoked as
            ``criterion(preds, text, preds_size, length)``.
        optimizer: Optimizer whose ``step()`` applies the update.

    Returns:
        The criterion output divided by the batch size.

    Raises:
        StopIteration: If ``train_iter`` is exhausted.
    """
    # The builtin next() works on every iterator; the ``.next()`` method is
    # a Python 2 leftover that current DataLoader iterators do not provide.
    cpu_images, cpu_texts = next(train_iter)
    batch_size = cpu_images.size(0)

    # Copy the batch into the shared buffers.
    utils.loadData(image, cpu_images)
    encoded_text, text_lengths = converter.encode(cpu_texts)
    utils.loadData(text, encoded_text)
    utils.loadData(length, text_lengths)

    # Use the model that was passed in rather than a module-level global, so
    # the function trains whatever network the caller hands over.
    preds = net(image)
    # Every sample in the batch is given the same prediction length,
    # preds.size(0).
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size

    net.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
예제 #2
0
def main(opts):
  """Train a CRNN model with CTC loss and periodically save checkpoints.

  Args:
    opts: Parsed options object. The attributes read here are ``cuda``,
      ``base_lr``, ``model``, ``batch_size``, ``max_iters`` and
      ``save_path``.

  Besides ``opts``, the function reads the module-level names
  ``weight_decay``, ``disp_interval`` and ``batch_per_epoch``.
  """
  import sys
  import traceback

  alphabet = '0123456789.'
  # One class more than there are characters (presumably the CTC blank --
  # confirm against the CRNN definition).
  nclass = len(alphabet) + 1
  model_name = 'crnn'
  net = CRNN(nclass)
  print("Using {0}".format(model_name))

  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)

  # Default to a fresh run; a checkpoint, when present, overrides both the
  # starting step and the recorded learning rate. The default must be set
  # BEFORE loading, otherwise the restored step is thrown away.
  step_start = 0
  if os.path.exists(opts.model):
    # Use the ``opts`` parameter, not a module-level ``args`` that may not
    # exist when main() is called from other code.
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)

  # Dataset
  converter = strLabelConverter(alphabet)
  dataset = ImgDataset(
      root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
      csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
      transform=None,
      target_transform=converter.encode
  )
  ocrdataloader = torch.utils.data.DataLoader(
      dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
  )

  net.train()
  ctc_loss = CTCLoss()

  data_iter = iter(ocrdataloader)
  for step in range(step_start, opts.max_iters):
    # Restart the loader when an epoch is exhausted. Only StopIteration is
    # handled so that genuine data-loading errors are not silently masked.
    try:
      data = next(data_iter)
    except StopIteration:
      data_iter = iter(ocrdataloader)
      data = next(data_iter)

    im_data, gt_boxes, text = data
    # Keep the input on the same device as the model.
    if opts.cuda:
      im_data = im_data.cuda()

    # Best-effort step: a failing batch is reported and skipped, but
    # KeyboardInterrupt / SystemExit still stop the run.
    loss = None
    try:
      loss = process_crnn(im_data, gt_boxes, text, net, ctc_loss, converter, training=True)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    except Exception:
      traceback.print_exc(file=sys.stdout)

    # Nothing to report when the step failed before producing a loss.
    if step % disp_interval == 0 and loss is not None:
      print('step:%d || loss %.4f' % (step, loss))

    if step > step_start and (step % batch_per_epoch == 0):
      save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
      state = {
          'step': step,
          'learning_rate': learning_rate,
          'state_dict': net.state_dict(),
          'optimizer': optimizer.state_dict(),
      }
      torch.save(state, save_name)
      print('save model: {}'.format(save_name))