예제 #1
0
def eval_detection(opts, net=None):
  """Evaluate text detection on an annotated image list and print P/R/F.

  Every image returned by ``load_annotation(opts.eval_list)`` is run through
  ``net``.  The quadrilaterals produced by ``get_boxes`` are then greedily
  matched, one-to-one, against the not-yet-matched ground-truth
  quadrilaterals of the same image.  A pair counts as a match when its
  intersection-over-union exceeds 0.5.

  Args:
    opts: options object.  Reads ``eval_list``, ``cuda``, ``segm_thresh`` and
      ``debug``, plus ``model`` when ``net`` is None.
    net: detector to evaluate.  When None, an ``OctShuffleMLT(attention=True)``
      is built and its weights are loaded from ``opts.model``.

  Returns:
    ``(precision, recall, f_score)``.  Each value is 0.0 when its denominator
    is zero (e.g. no predictions or no ground-truth boxes at all).
  """
  if net is None:
    net = OctShuffleMLT(attention=True)
    net_utils.load_net(opts.model, net)
    if opts.cuda:
      net.cuda()

  images, gt_boxes = load_annotation(opts.eval_list)
  true_positives = 0
  false_positives = 0
  false_negatives = 0

  for i in range(images.shape[0]):
    image = np.expand_dims(images[i], axis=0)
    image_boxes_gt = np.array(gt_boxes[i])

    im_data = net_utils.np_to_variable(image, is_cuda=opts.cuda).permute(0, 3, 1, 2)
    seg_pred, rboxs, angle_pred, features = net(im_data)

    # Move the first axis of the box-geometry map to the end
    # (presumably (C, H, W) -> (H, W, C); TODO confirm).
    rbox = rboxs[0].data.cpu()[0].numpy()
    rbox = rbox.swapaxes(0, 1)
    rbox = rbox.swapaxes(1, 2)
    angle_pred = angle_pred[0].data.cpu()[0].numpy()
    segm = seg_pred[0].data.cpu()[0].numpy()
    segm = segm.squeeze(0)

    boxes = get_boxes(segm, rbox, angle_pred, opts.segm_thresh)

    if opts.debug:
      print(boxes.shape)
      print(image_boxes_gt.shape)
      print("============")

    # Every prediction starts as a false positive and every ground-truth box
    # as a false negative; each match turns one of each into a true positive.
    false_positives += boxes.shape[0]
    false_negatives += image_boxes_gt.shape[0]
    gt_polys = [Polygon.Polygon(box_gt[0:8].reshape(4, -1))
                for box_gt in image_boxes_gt]
    # Track matched GT boxes by index so that identical duplicate GT boxes
    # are not all consumed by a single match.
    matched_gt = set()
    for box in boxes:
      poly = Polygon.Polygon(box[0:8].reshape(4, -1))
      for gt_idx, poly_gt in enumerate(gt_polys):
        if gt_idx in matched_gt:
          continue
        # Polygon operators: `&` is intersection, `|` is union.
        union_area = (poly_gt | poly).area()
        if union_area <= 0:
          continue
        iou = (poly_gt & poly).area() / union_area
        if iou > 0.5:
          true_positives += 1
          false_negatives -= 1
          false_positives -= 1
          matched_gt.add(gt_idx)
          break

  print("tp: {} fp: {} fn: {}".format(true_positives, false_positives, false_negatives))
  n_detections = true_positives + false_positives
  n_ground_truth = true_positives + false_negatives
  precision = float(true_positives) / n_detections if n_detections else 0.0
  recall = float(true_positives) / n_ground_truth if n_ground_truth else 0.0
  if precision + recall > 0:
    f_score = 2 * precision * recall / (precision + recall)
  else:
    f_score = 0.0
  print("PRECISION: {} \t RECALL: {} \t F SCORE: {}".format(precision, recall, f_score))
  return precision, recall, f_score
예제 #2
0
def dice_loss(segm_preds, score_maps, training_masks, multi_scale=False):
    """Return the negative smoothed Dice coefficient of a segmentation.

    The Dice term is ``(2 * sum(P * T) + 1) / (sum(P) + sum(T) + 1)``, where
    P and T are the prediction and the target multiplied by the training
    mask.  The value returned is its negation, so minimising it maximises
    overlap.

    Args:
        segm_preds: sequence of prediction tensors.  ``segm_preds[0]`` is
            always used; ``segm_preds[1]`` is read only when ``multi_scale``
            is True.  Each is squeezed on dim 1 (presumably (N, 1, H, W) ->
            (N, H, W); TODO confirm).
        score_maps: array-like target maps, cast to ``uint8`` before upload.
        training_masks: array-like masks, cast to ``uint8``.  Pixels where
            the mask is 0 are excluded from both prediction and target.
        multi_scale: if True, also add the Dice term for ``segm_preds[1]``.
            For that term the targets and masks are bilinearly resized to
            that prediction's spatial size.

    Returns:
        Scalar tensor: the negative Dice term, or the sum of two negative
        Dice terms when ``multi_scale`` is True.
    """
    smooth = 1.

    def _neg_dice(pred, target, mask):
        # Masked, smoothed, negated Dice coefficient of one scale.
        iflat = (pred * mask).view(-1)
        tflat = (target * mask).view(-1)
        intersection = (iflat * tflat).sum()
        return -((2. * intersection + smooth) /
                 (iflat.sum() + tflat.sum() + smooth))

    segm_pred = segm_preds[0].squeeze(1)
    # Upload targets to the prediction's device.  They used to be forced onto
    # the CPU, which breaks the element-wise products for CUDA predictions.
    on_cuda = segm_pred.is_cuda

    score_maps = np.asarray(score_maps, dtype=np.uint8)
    training_masks = np.asarray(training_masks, dtype=np.uint8)
    smaps_var = net_utils.np_to_variable(score_maps, is_cuda=on_cuda)
    training_mask_var = net_utils.np_to_variable(training_masks, is_cuda=on_cuda)

    result = _neg_dice(segm_pred, smaps_var, training_mask_var)
    if multi_scale:
        # The second-scale prediction is only accessed when it is needed.
        segm_pred1 = segm_preds[1].squeeze(1)
        size = (segm_pred1.size(1), segm_pred1.size(2))
        iou_gts = F.interpolate(smaps_var.unsqueeze(1),
                                size=size,
                                mode='bilinear',
                                align_corners=True).squeeze(1)
        iou_masks = F.interpolate(training_mask_var.unsqueeze(1),
                                  size=size,
                                  mode='bilinear',
                                  align_corners=True).squeeze(1)
        result = result + _neg_dice(segm_pred1, iou_gts, iou_masks)

    return result
예제 #3
0
def main(opts):
    """Train the OCT-E2E-MLT text detector and recogniser.

    Each iteration does the following:
      * computes the detection losses (``net.loss``) on a batch from
        ``data_gen``;
      * once ``step`` exceeds 10000, adds a CTC loss on detected word crops
        (``process_boxes``);
      * adds a CTC loss on a batch of word images from ``ocr_gen``.

    Running averages are printed every ``disp_interval`` steps, and a
    checkpoint is written every ``batch_per_epoch`` steps.  Training stops
    early when the averaged training loss has not improved for more than
    ``max_patience`` steps.  In that case the best snapshot is saved as
    ``BEST_<model>_<step>.h5`` and the final ``<model>.h5`` is not written.

    Args:
        opts: options object.  Reads ``base_lr``, ``model``, ``cuda``,
            ``num_readers``, ``input_size``, ``batch_size``, ``train_list``,
            ``geo_type``, ``ocr_batch_size``, ``ocr_feed_list``,
            ``max_iters``, ``debug`` and ``save_path``.  On early stop,
            ``opts.max_iters`` is set to the stopping step.

    Relies on module-level ``weight_decay``, ``norm_height``,
    ``disp_interval`` and ``batch_per_epoch``.
    """
    import copy
    import sys
    import traceback

    def _max_memory():
        # Do not touch the CUDA allocator on CPU-only runs.
        return torch.cuda.max_memory_allocated() if opts.cuda else 0

    model_name = 'OCT-E2E-MLT'
    net = OctMLT(attention=True)
    print("Using {0}".format(model_name))

    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net)

    if opts.cuda:
        net.cuda()

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_list,
                                        geo_type=opts.geo_type)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True)

    train_loss = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 0
    ctc_loss = CTCLoss()

    ctc_loss_val = 0
    box_loss_val = 0
    good_all = 0
    gt_all = 0

    best_step = step_start
    best_loss = 1000000
    # state_dict() tensors share storage with the live model and optimizer.
    # Deep-copy them so the "best" snapshot is not overwritten by later steps.
    best_model = copy.deepcopy(net.state_dict())
    best_optimizer = copy.deepcopy(optimizer.state_dict())
    best_learning_rate = learning_rate
    max_patience = 3000
    early_stop = False
    step = step_start  # keeps `step` defined even if the loop never runs

    for step in range(step_start, opts.max_iters):

        # batch
        images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(
            data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        # timeit.timeit() benchmarks an empty statement; default_timer() is a clock.
        start = timeit.default_timer()
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except Exception:
            traceback.print_exc(file=sys.stdout)
            continue
        end = timeit.default_timer()

        # backward

        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks,
                                                     is_cuda=opts.cuda)
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4],
                                            is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]],
                                          is_cuda=opts.cuda)

        try:
            loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred,
                            angle_gt, roi_pred, geo_gt)
        except Exception:
            traceback.print_exc(file=sys.stdout)
            continue

        bbox_loss += net.box_loss_value.data.cpu().numpy()
        seg_loss += net.segm_loss_value.data.cpu().numpy()
        angle_loss += net.angle_loss_value.data.cpu().numpy()

        train_loss += loss.data.cpu().numpy()
        optimizer.zero_grad()

        try:
            # Recognition on detected crops is an extra augmentation step;
            # early in training it mostly slows things down.
            if step > 10000:
                ctcl, gt_b_good, gt_b_all = process_boxes(images,
                                                          im_data,
                                                          seg_pred[0],
                                                          roi_pred[0],
                                                          angle_pred[0],
                                                          score_maps,
                                                          gt_idxs,
                                                          gtso,
                                                          lbso,
                                                          features,
                                                          net,
                                                          ctc_loss,
                                                          opts,
                                                          debug=opts.debug)
                # .item() works for both 0-d and 1-element loss tensors.
                ctc_loss_val += ctcl.item()
                loss = loss + ctcl
                gt_all += gt_b_all
                good_all += gt_b_good

            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso,
                                                   is_cuda=opts.cuda).permute(
                                                       0, 3, 1, 2)
            features = net.forward_features(im_data_ocr)
            labels_pred = net.forward_ocr(features)

            probs_sizes = torch.IntTensor(
                [(labels_pred.permute(2, 0, 1).size()[0])] *
                (labels_pred.permute(2, 0, 1).size()[1]))
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int())
            labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
            loss_ocr = ctc_loss(labels_pred.permute(2, 0,
                                                    1), labels, probs_sizes,
                                label_sizes) / im_data_ocr.size(0) * 0.5

            loss_ocr.backward()
            loss.backward()

            optimizer.step()
        except Exception:
            # Best effort: a failing recognition step skips this update.
            traceback.print_exc(file=sys.stdout)
        cnt += 1
        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0],
                                      (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                x_data += 1
                x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except IndexError:
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            train_loss /= cnt
            bbox_loss /= cnt
            seg_loss /= cnt
            angle_loss /= cnt
            ctc_loss_val /= cnt
            box_loss_val /= cnt

            if train_loss < best_loss:
                best_step = step
                best_model = copy.deepcopy(net.state_dict())
                best_loss = train_loss
                best_learning_rate = learning_rate
                best_optimizer = copy.deepcopy(optimizer.state_dict())
            # Patience is measured as steps elapsed since the best loss.
            if step - best_step > max_patience:
                print("Early stopped criteria achieved.")
                save_name = os.path.join(
                    opts.save_path,
                    'BEST_{}_{}.h5'.format(model_name, best_step))
                state = {
                    'step': best_step,
                    'learning_rate': best_learning_rate,
                    'state_dict': best_model,
                    'optimizer': best_optimizer
                }
                torch.save(state, save_name)
                print('save model: {}'.format(save_name))
                opts.max_iters = step
                early_stop = True
            try:
                print(
                    'epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f in %.3f'
                    % (step / batch_per_epoch, step, train_loss, bbox_loss,
                       seg_loss, angle_loss, ctc_loss_val,
                       good_all / max(1, gt_all), end - start))
                print('max_memory_allocated {}'.format(_max_memory()))
            except Exception:
                traceback.print_exc(file=sys.stdout)

            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            good_all = 0
            gt_all = 0
            box_loss_val = 0

            # Changing opts.max_iters does not shorten the running range();
            # leave the loop explicitly.
            if early_stop:
                break

        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'max_memory_allocated': _max_memory()
            }
            torch.save(state, save_name)
            print('save model: {}\tmax memory: {}'.format(
                save_name, _max_memory()))
    if not early_stop:
        save_name = os.path.join(opts.save_path, '{}.h5'.format(model_name))
        state = {
            'step': step,
            'learning_rate': learning_rate,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, save_name)
        print('save model: {}'.format(save_name))
예제 #4
0
def test(net,
         codec,
         args,
         list_file='/home/busta/data/icdar_ch8_validation/ocr_valid.txt',
         norm_height=32,
         max_samples=1000000):
    """Evaluate the recognition (OCR) branch of ``net`` on word crops.

    List format:
        Each line of ``list_file`` holds an image path, relative to the list
        file's directory, followed by its transcription.  The separator is
        ``,`` or, if there is no comma, a space.  Surrounding double quotes
        on the transcription are removed.

    Processing:
        Each image is resized to height ``norm_height``, with the width
        rounded to a multiple of 4 and at least 8.  Pixels are mapped with
        ``x / 128 - 1`` and the result is decoded greedily through
        ``print_seq_ext``.

    Side effects:
        * Per-sample results go to ``/tmp/ch8_valid.txt`` and
          ``/tmp/ocr_valid.txt``.
        * Per-script tables go to ``per_script_accuracy.csv``,
          ``conf_matrix.csv`` and ``conf_matrix_out.csv`` in the current
          directory.
        * HTML reports of wrong predictions are written next to
          ``list_file``.
        * ``net`` is switched to eval mode and back to train mode when the
          evaluation loop ends, also on error.

    Args:
        net: model exposing ``eval``, ``train``, ``forward_features`` and
            ``forward_ocr``.
        codec: character table passed to ``print_seq_ext``.
        args: options object; only ``args.cuda`` is read.
        list_file: path of the annotation list.
        norm_height: input height fed to the network.
        max_samples: stop once the list-line index exceeds this value.

    Returns:
        ``(accuracy, total_edit_distance)``.  Accuracy is the
        case-insensitive exact-match rate (0.0 when no sample was evaluated).
    """
    scripts = [
        '', 'DIGIT', 'LATIN', 'ARABIC', 'BENGALI', 'HANGUL', 'CJK', 'HIRAGANA',
        'KATAKANA'
    ]

    def _parse_line(line):
        """Split one list line into ``(image_name, gt_txt)``."""
        spl = line.split(",")
        delim = ","
        if len(spl) == 1:
            spl = line.split(" ")
            delim = " "
        image_name = spl[0].strip()
        gt_txt = ''
        if len(spl) > 1:
            gt_txt = spl[1].strip()
            if len(spl) > 2:
                gt_txt += delim + spl[2]
            if len(gt_txt) > 1 and gt_txt[0] == '"' and gt_txt[-1] == '"':
                gt_txt = gt_txt[1:-1]
        return image_name, gt_txt

    def _script_index(text):
        """Return the index into ``scripts`` of the most frequent script in *text*.

        A character counts for the first script whose name occurs in its
        Unicode character name.  Characters with no match, or with no Unicode
        name at all, count for the empty script at index 0.  Ties go to the
        lowest index.
        """
        counts = np.zeros(len(scripts), dtype=int)
        for c_char in text:
            symbol_name = ud.name(c_char, '')
            for idx in range(1, len(scripts)):
                if scripts[idx] in symbol_name:
                    counts[idx] += 1
                    break
            else:
                counts[0] += 1
        return int(np.argmax(counts))

    def _recognize(img):
        """Greedy-decode the word in *img* (as loaded by cv2.imread); '' on failure."""
        scale = norm_height / float(img.shape[0])
        width = int(img.shape[1] * scale)
        width = max(8, int(round(width / 4)) * 4)

        scaled = cv2.resize(img, (int(width), norm_height))
        scaled = np.expand_dims(scaled, axis=0)
        scaled = np.asarray(scaled, dtype=float)
        scaled /= 128
        scaled -= 1

        try:
            scaled_var = net_utils.np_to_variable(scaled,
                                                  is_cuda=args.cuda).permute(
                                                      0, 3, 1, 2)
            x = net.forward_features(scaled_var)
            ctc_f = net.forward_ocr(x)
            ctc_f = ctc_f.data.cpu().numpy()
            ctc_f = ctc_f.swapaxes(1, 2)

            labels = ctc_f.argmax(2)
            det_text, conf, dec_s, _ = print_seq_ext(labels[0, :], codec)
        except Exception:
            print('bad image')
            det_text = ''
        return det_text

    dir_name = os.path.dirname(list_file)
    with open(list_file, "r") as ins:
        images = [line.strip() for line in ins]

    # np.int / np.float / np.object were removed in NumPy 1.24; use builtins.
    conf_matrix = np.zeros((len(scripts), len(scripts)), dtype=int)

    gt_script = dict.fromkeys(scripts, 0)
    ed_script = dict.fromkeys(scripts, 0)
    correct_ed1_script = dict.fromkeys(scripts, 0)
    correct_script = dict.fromkeys(scripts, 0)
    count_script = dict.fromkeys(scripts, 0)

    it = 0
    it2 = 0
    correct = 0
    correct_ed1 = 0
    ted = 0
    gt_all = 0
    images_count = 0
    bad_words = []

    net = net.eval()
    try:
        with open('/tmp/ch8_valid.txt', 'w') as fout, \
                open('/tmp/ocr_valid.txt', 'w') as fout_ocr:
            for line in images:
                # it2 equals the index of the current list line.
                if it2 >= len(images) or it2 > max_samples:
                    break
                it2 += 1

                image_name, gt_txt = _parse_line(line)
                if len(gt_txt) == 0:
                    print(line)
                    continue

                if image_name.endswith(','):
                    image_name = image_name[0:-1]

                img_nameo = image_name
                image_name = '{0}/{1}'.format(dir_name, image_name)
                img = cv2.imread(image_name)
                if img is None:
                    print(image_name)
                    continue

                det_text = _recognize(img).strip()
                gt_txt = gt_txt.strip()

                # When the ground truth starts with an Arabic character the
                # prediction is reversed before comparison (presumably to
                # match right-to-left order; TODO confirm).
                try:
                    if 'ARABIC' in ud.name(gt_txt[0]):
                        det_text = det_text[::-1]
                except (IndexError, ValueError):
                    # Empty after stripping, or unnamed first char: skip sample.
                    continue

                it += 1

                gt_idx = _script_index(gt_txt)
                det_idx = _script_index(det_text)
                script = scripts[gt_idx]
                conf_matrix[gt_idx, det_idx] += 1

                edit_dist = distance(det_text.lower(), gt_txt.lower())
                ted += edit_dist
                gt_all += len(gt_txt)

                gt_script[script] += len(gt_txt)
                ed_script[script] += edit_dist
                images_count += 1

                fout_ocr.write('{0}, "{1}"\n'.format(os.path.basename(image_name),
                                                     det_text.strip()))

                if det_text.lower() == gt_txt.lower():
                    correct += 1
                    correct_ed1 += 1
                    correct_script[script] += 1
                    correct_ed1_script[script] += 1
                else:
                    if edit_dist == 1:
                        correct_ed1 += 1
                        correct_ed1_script[script] += 1
                    image_prev = "<img src=\"{0}\" height=\"32\" />".format(img_nameo)
                    bad_words.append(
                        (gt_txt, det_text, edit_dist, image_prev, img_nameo))
                    print('{0} - {1} / {2:.2f} - {3:.2f}'.format(
                        det_text, gt_txt, correct / float(it), ted / 3.0))

                count_script[script] += 1
                fout.write('{0}|{1}|{2}|{3}\n'.format(os.path.basename(image_name),
                                                      gt_txt, det_text, edit_dist))
    finally:
        net.train()

    n_images = float(max(images_count, 1))
    print('Test accuracy: {0:.3f}, {1:.2f}, {2:.3f}'.format(
        correct / n_images, ted / 3.0, ted / float(max(gt_all, 1))))

    with open("per_script_accuracy.csv", "w") as itf:
        # Header columns follow the order of the values written below
        # (ed1 accuracy comes before the normalised edit distance).
        itf.write(
            'Script & Accuracy & ed1 & Edit Distance & Ch instances & Im Instances \\\\\n'
        )
        for scr in scripts:
            n_scr = count_script[scr]
            ted_scr = ed_script[scr]
            gt_all_scr = gt_script[scr]
            print(' Script:{3} Acc : {0:.3f}, {1:.2f}, {2:.3f}, {4}'.format(
                correct_script[scr] / float(max(n_scr, 1)), ted_scr / 3.0,
                ted_scr / float(max(gt_all_scr, 1)), scr, gt_all_scr))

            itf.write(
                '{0} & {1:.3f} & {5:.3f} &  {2:.3f} & {3} & {4} \\\\\n'.format(
                    scr.title(), correct_script[scr] / float(max(n_scr, 1)),
                    ted_scr / float(max(gt_all_scr, 1)), gt_all_scr, n_scr,
                    correct_ed1_script[scr] / float(max(n_scr, 1))))

        itf.write('{0} & {1:.3f} & {5:.3f} &  {2:.3f} & {3} & {4} \\\\\n'.format(
            'Total', correct / n_images,
            ted / float(max(gt_all, 1)), gt_all, images_count,
            correct_ed1 / n_images))

    print(conf_matrix)
    np.savetxt("conf_matrix.csv",
               conf_matrix,
               delimiter=' & ',
               fmt='%d',
               newline=' \\\\\n')

    with open("conf_matrix_out.csv", "w") as itf:
        itf.write(' & ')
        itf.write(" & ".join(scr.title() for scr in scripts))
        itf.write('\\\\\n')
        with open("conf_matrix.csv", "r") as ins:
            # zip() stops after one row per script.
            for scr, line in zip(scripts, ins):
                itf.write(scr.title() + " & " + line)

    pd.options.display.max_rows = 9999

    if bad_words:
        # Two reports: ascending and descending by edit distance.
        for reverse, report_name in ((False, 'ocr_bad.html'),
                                     (True, 'ocr_not_sobad.html')):
            wworst = sorted(bad_words, key=lambda x: x[2], reverse=reverse)
            ww = np.asarray(wworst, object)
            ww = ww[0:1500, :]
            df2 = pd.DataFrame({
                'gt': ww[:, 0],
                'pred': ww[:, 1],
                'ed': ww[:, 2],
                'image': ww[:, 3]
            })
            with open('{0}/{1}'.format(dir_name, report_name), 'w') as report:
                report.write(df2.to_html(escape=False))

    return correct / n_images, ted
예제 #5
0
def main(opts):
  """Train E2E-MLT, optimising the recognition (CTC) loss on word regions.

  Steps:
    * Builds ``OwnModel`` with ``len(alphabet) + 1`` output classes.
    * Resumes network and optimizer state from ``opts.model`` when that file
      exists.
    * For every step, computes the detection losses (``net.loss``) for
      logging and back-propagates the CTC loss returned by ``process_crnn``.
    * Prints running averages every ``disp_interval`` steps and writes a
      checkpoint every ``batch_per_epoch`` steps.

  Args:
    opts: options object.  Reads ``cuda``, ``base_lr``, ``model``,
      ``train_list``, ``max_iters``, ``debug`` and ``save_path``.

  Relies on module-level ``alphabet``, ``weight_decay``, ``disp_interval``
  and ``batch_per_epoch``.

  NOTE(review): batches are taken from ``data_generator``, which is not
  defined in this function (presumably a module-level generator; confirm).
  The ``E2Edataset`` loader built below is currently unused.
  """
  import sys
  import traceback

  nclass = len(alphabet) + 1
  model_name = 'E2E-MLT'
  net = OwnModel(attention=True, nclass=nclass)
  print("Using {0}".format(model_name))
  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)

  # Continue training from an existing checkpoint.  step_start must be
  # initialised *before* loading; resetting it afterwards discarded the
  # restored step.  (A previously commented-out alternative loaded a
  # pretrained ModelResNetSep2 and copied every weight except the 'conv11'
  # and 'rnn' layers, to change only the output dimension.)
  step_start = 0
  if os.path.exists(opts.model):
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)

  net.train()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  e2edata = E2Edataset(train_list=opts.train_list)
  e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=4, shuffle=True, collate_fn=E2Ecollate)

  train_loss = 0
  bbox_loss, seg_loss, angle_loss = 0., 0., 0.
  cnt = 0
  ctc_loss_val = 0
  ctc_loss_val2 = 0
  box_loss_val = 0
  gt_g_target = 0
  gt_g_proc = 0

  for step in range(step_start, opts.max_iters):

    # batch
    images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(data_generator)
    im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=opts.cuda)
    try:
      seg_pred, roi_pred, angle_pred, features = net(im_data)
    except Exception:
      traceback.print_exc(file=sys.stdout)
      continue

    # Detection (EAST-style) losses.
    smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
    training_mask_var = net_utils.np_to_variable(training_masks, is_cuda=opts.cuda)
    angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4], is_cuda=opts.cuda)
    geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]], is_cuda=opts.cuda)

    try:
      loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred, angle_gt, roi_pred, geo_gt)
    except Exception:
      traceback.print_exc(file=sys.stdout)
      continue

    bbox_loss += net.box_loss_value.data.cpu().numpy()
    seg_loss += net.segm_loss_value.data.cpu().numpy()
    angle_loss += net.angle_loss_value.data.cpu().numpy()
    train_loss += loss.data.cpu().numpy()

    try:
      # Recognition loss on the annotated text regions.  This used to sit
      # behind `if step > 10000 or True`, i.e. it always ran.
      ctcl = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)
      gt_g_target = 1
      gt_g_proc = 1

      # .item() works for both 0-d and 1-element loss tensors.
      ctc_loss_val += ctcl.item()
      # NOTE(review): only the CTC loss is optimised; the detection loss
      # above is used for logging only.  Confirm this is intended.
      loss = ctcl
      train_loss += ctcl.item()

      net.zero_grad()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    except Exception:
      # Best effort: a failing recognition step skips this update.
      traceback.print_exc(file=sys.stdout)

    cnt += 1
    if step % disp_interval == 0:
      if opts.debug:

        segm = seg_pred[0].data.cpu()[0].numpy()
        segm = segm.squeeze(0)
        cv2.imshow('segm_map', segm)

        segm_res = cv2.resize(score_maps[0], (images.shape[2], images.shape[1]))
        mask = np.argwhere(segm_res > 0)

        x_data = im_data.data.cpu().numpy()[0]
        x_data = x_data.swapaxes(0, 2)
        x_data = x_data.swapaxes(0, 1)

        x_data += 1
        x_data *= 128
        x_data = np.asarray(x_data, dtype=np.uint8)
        x_data = x_data[:, :, ::-1]

        im_show = x_data
        try:
          im_show[mask[:, 0], mask[:, 1], 1] = 255
          im_show[mask[:, 0], mask[:, 1], 0] = 0
          im_show[mask[:, 0], mask[:, 1], 2] = 0
        except IndexError:
          pass

        cv2.imshow('img0', im_show)
        cv2.imshow('score_maps', score_maps[0] * 255)
        cv2.imshow('train_mask', training_masks[0] * 255)
        cv2.waitKey(10)

      train_loss /= cnt
      bbox_loss /= cnt
      seg_loss /= cnt
      angle_loss /= cnt
      ctc_loss_val /= cnt
      ctc_loss_val2 /= cnt
      box_loss_val /= cnt
      try:
        print('epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, gt_t/gt_proc:[%d/%d] lv2 %.3f' % (
          step / batch_per_epoch, step, train_loss, bbox_loss, seg_loss, angle_loss, ctc_loss_val, gt_g_target, gt_g_proc, ctc_loss_val2))
      except Exception:
        traceback.print_exc(file=sys.stdout)

      # Reset every running average after it has been reported.
      train_loss = 0
      bbox_loss, seg_loss, angle_loss = 0., 0., 0.
      cnt = 0
      ctc_loss_val = 0
      ctc_loss_val2 = 0
      box_loss_val = 0

    # Periodic checkpoint.
    if step > step_start and (step % batch_per_epoch == 0):
      save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
      state = {'step': step,
               'learning_rate': learning_rate,
               'state_dict': net.state_dict(),
               'optimizer': optimizer.state_dict()}
      torch.save(state, save_name)
      print('save model: {}'.format(save_name))
예제 #6
0
  if args.cuda:
    print('Using cuda ...')
    net = net.cuda()

  imagelist = glob.glob(args.test_folder)
  with torch.no_grad():
    for path in imagelist:
      # path = '/home/yangna/deepblue/OCR/data/ICDAR2015/ch4_test_images/img_405.jpg'
      im = cv2.imread(path)

      im_resized, (ratio_h, ratio_w) = resize_image(im, scale_up=False)
      images = np.asarray([im_resized], dtype=np.float)
      images /= 128
      images -= 1
      im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=args.cuda)
      seg_pred, rboxs, angle_pred, features = net(im_data)

      rbox = rboxs[0].data.cpu()[0].numpy()                   # 转变成h,w,c
      rbox = rbox.swapaxes(0, 1)
      rbox = rbox.swapaxes(1, 2)

      angle_pred = angle_pred[0].data.cpu()[0].numpy()

      segm = seg_pred[0].data.cpu()[0].numpy()
      segm = segm.squeeze(0)

      draw2 = np.copy(im_resized)
      boxes =  get_boxes(segm, rbox, angle_pred, args.segm_thresh)

      img = Image.fromarray(draw2)
예제 #7
0
파일: eval.py 프로젝트: wisdal/NAVI-STR
            print(img_name)

            img = cv2.imread(img_name)

            #font = cv2.FONT_HERSHEY_SIMPLEX
            #cv2.putText(img,'cs',(10,img.shape[0] -40), font, 0.8,(255,255,255),2,cv2.LINE_AA)

            im_resized, (ratio_h, ratio_w) = resize_image(
                img, max_size=1848 * 1024,
                scale_up=True)  #1348*1024 #1848*1024
            #im_resized = im_resized[:, :, ::-1]
            images = np.asarray([im_resized], dtype=np.float)
            images /= 128
            images -= 1
            im_data = net_utils.np_to_variable(images,
                                               is_cuda=args.cuda).permute(
                                                   0, 3, 1, 2)

            [iou_pred, iou_pred1], rboxs, angle_pred, features = net(im_data)
            iou = iou_pred.data.cpu()[0].numpy()
            iou = iou.squeeze(0)

            iou_pred1 = iou_pred1.data.cpu()[0].numpy()
            iou_pred1 = iou_pred1.squeeze(0)

            #ioud = segm_predd.data.cpu()[0].numpy()
            #ioud = ioud.squeeze(0)

            rbox = rboxs[0].data.cpu()[0].numpy()
            rbox = rbox.swapaxes(0, 1)
            rbox = rbox.swapaxes(1, 2)
def main(opts):
    """Train the stand-alone OCR recognizer (``OCRModel``) with CTC loss.

    Batches come from ``ocr_gen.get_batch``; training runs for up to
    10,000,000 steps.  Steps whose loss is NaN or +/-inf are skipped without
    touching the weights.  Every ``batch_per_epoch`` steps a checkpoint is
    written to ``opts.save_path`` and ``test(net)`` is run.

    Args:
        opts: parsed options; reads ``cuda``, ``model``, ``num_readers``,
            ``batch_size``, ``train_list`` and ``save_path``.

    Relies on module-level globals ``base_lr`` and ``batch_per_epoch``.
    """
    model_name = 'ICCV_OCR'
    net = OCRModel()

    if opts.cuda:
        net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net)
    else:
        learning_rate = base_lr
    print('train')
    net.train()

    # Move the loss to the GPU only when CUDA is requested, like the model.
    ctc_loss = CTCLoss(blank=0)
    if opts.cuda:
        ctc_loss = ctc_loss.cuda()

    # Global cuDNN flag: set once instead of on every step.
    torch.backends.cudnn.deterministic = True

    data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
                                       batch_size=opts.batch_size,
                                       train_list=opts.train_list,
                                       in_train=True)

    train_loss = 0
    cnt = 0
    tq = tqdm(range(step_start, 10000000))
    for step in tq:
        images, labels, label_length = next(data_generator)
        im_data = net_utils.np_to_variable(images,
                                           is_cuda=opts.cuda,
                                           volatile=False).permute(0, 3, 1, 2)

        # CTC takes activations as (seqLength x batch x outputDim); the
        # network output is permuted into that layout once.
        labels_pred = net(im_data).permute(2, 0, 1)
        seq_len, batch = labels_pred.size(0), labels_pred.size(1)

        # Every sequence in the batch uses the full output length.
        probs_sizes = torch.IntTensor([seq_len] * batch).long()
        label_sizes = torch.from_numpy(np.array(label_length)).long()
        targets = torch.from_numpy(np.array(labels)).long()

        optimizer.zero_grad()
        loss = ctc_loss(labels_pred, targets, probs_sizes,
                        label_sizes) / opts.batch_size  # change 1.9.
        # Skip any non-finite step; checking only ``== np.inf`` let NaN and
        # -inf losses be back-propagated into the weights.
        if torch.isnan(loss).any() or torch.isinf(loss).any():
            continue
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        cnt += 1
        tq.set_description(
            'epoch %d[%d], loss: %.3f, lr: %.5f ' %
            (step / batch_per_epoch, step, train_loss / cnt, learning_rate))

        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))

            test(net)
def test(net, list_file='/home/liepieshov/dataset/en_words/test.csv'):
    """Report word accuracy and edit distance of ``net`` on a CSV test list.

    Preprocessing and decoding:
      - Each listed image is read as grayscale.
      - When an image is more than twice as tall as wide and its label has
        more than 3 characters, it is transposed and flipped.
      - It is scaled to ``x / 128 - 1``.
      - It is decoded greedily: argmax per step, then ``print_seq_ext``.

    Per-image results are written to ``./valid.txt`` as
    ``basename|ground_truth|prediction|edit_distance``.  Unreadable images are
    reported and skipped.  The network is returned to training mode before
    this function exits, even on error.

    Args:
        net: recognizer.  Its output is swapped on axes 1/2 before the argmax
            over axis 2 (presumably (batch, classes, time) -> (batch, time,
            classes); TODO confirm).
        list_file: CSV read by ``ocr_gen.get_info_csv``.

    Relies on the module-level ``args.cuda`` flag.
    """
    net = net.eval()

    images, bucket, label = ocr_gen.get_info_csv(list_file)

    evaluated = 0
    correct = 0
    ted = 0
    gt_all = 0
    try:
        with open('./valid.txt', 'w') as fout, torch.no_grad():
            for idx, image_name in enumerate(images):
                gt_txt = label[idx]

                img = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE)
                if img is None:
                    # Report and move on.  The former while-loop ``continue``
                    # never advanced the index, so it looped forever here.
                    print(image_name)
                    continue
                if img.shape[0] > img.shape[1] * 2 and len(gt_txt) > 3:
                    img = np.transpose(img)
                    img = cv2.flip(img, flipCode=1)

                # (H, W) -> (1, H, W, 1) float64, scaled to x / 128 - 1.
                # (``np.float`` was removed in NumPy 1.24.)
                scaled = np.asarray(img, dtype=np.float64)[np.newaxis, :, :,
                                                           np.newaxis]
                scaled /= 128
                scaled -= 1

                scaled_var = net_utils.np_to_variable(
                    scaled, is_cuda=args.cuda,
                    volatile=False).permute(0, 3, 1, 2)
                ctc_f = net(scaled_var).data.cpu().numpy().swapaxes(1, 2)

                labels = ctc_f.argmax(2)
                det_text, conf, dec_s = print_seq_ext(labels[0, :])

                evaluated += 1

                det_lower = str(det_text).lower()
                gt_lower = str(gt_txt).lower()
                edit_dist = editdistance.eval(det_lower, gt_lower)
                ted += edit_dist
                gt_all += len(str(gt_txt))

                if det_lower == gt_lower:
                    correct += 1
                else:
                    print('{0} - {1} / {2:.2f} - {3:.2f}'.format(
                        det_text, gt_txt, correct / float(evaluated),
                        ted / 3.0))

                fout.write('{0}|{1}|{2}|{3}\n'.format(
                    os.path.basename(image_name), gt_txt, det_text,
                    edit_dist))

        # max(..., 1) avoids ZeroDivisionError on an empty or unreadable list.
        print('Test accuracy: {0:.3f}, {1:.2f}, {2:.3f}'.format(
            correct / float(max(evaluated, 1)), ted / 3.0,
            ted / float(max(gt_all, 1))))
    finally:
        net.train()
예제 #10
0
def main(opts):
  """Train the E2E-MLT model (``OwnModel``) with detection and CTC losses.

  NOTE(review): this example is truncated/mangled as scraped:
    - The lines from ``ctcl= process_crnn(...)`` through
      ``train_loss += ctcl.item()`` sit one indent level shallower than the
      enclosing ``try``/``if``.
    - The final ``except:`` has no body.
  So the function does not parse as shown.

  NOTE(review): ``data_generator`` is not defined in this function (the
  ``e2edataloader`` built below is never used); presumably a module-level
  global — confirm.

  Args:
    opts: parsed options; reads ``cuda``, ``base_lr``, ``model``,
      ``train_list`` and ``max_iters`` (the checkpoint path is read from a
      global ``args.model`` instead of ``opts.model``).
  """

  nclass = len(alphabet) + 1
  model_name = 'E2E-MLT'
  net = OwnModel(attention=True, nclass=nclass)
  print("Using {0}".format(model_name))
  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)

  ### Option 1: change only the conv11 dimensions (fine-tune from a pretrained model)
  # model_dict = net.state_dict()
  # if os.path.exists(opts.model):
  #     # load the pretrained model
  #     print('loading pretrained model from %s' % opts.model)
  #     # pretrained_model = OwnModel(attention=True, nclass=7325)
  #     pretrained_model = ModelResNetSep2(attention=True, nclass=7500)
  #     pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
  #     pretrained_dict = pretrained_model.state_dict()
  #
  #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'conv11' not in k and 'rnn' not in k}
  #     # 2. overwrite entries in the existing state dict
  #     model_dict.update(pretrained_dict)
  #     # 3. load the new state dict
  #     net.load_state_dict(model_dict)

  ### Option 2: resume training directly from the previous checkpoint
  if os.path.exists(opts.model):
    print('loading model from %s' % args.model)
    step_start, learning_rate = net_utils.load_net(args.model, net, optimizer)
  ###
  
  # NOTE(review): this unconditionally discards the step restored from the
  # checkpoint above, so training always restarts at step 0.
  step_start = 0
  net.train()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  # NOTE(review): the dataset/dataloader built here are not used below; the
  # loop reads from ``data_generator`` instead.
  e2edata = E2Edataset(train_list=opts.train_list)
  e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=4, shuffle=True, collate_fn=E2Ecollate)
  
  # Running loss/statistics accumulators.
  train_loss = 0
  bbox_loss, seg_loss, angle_loss = 0., 0., 0.
  cnt = 0
  ctc_loss_val = 0
  ctc_loss_val2 = 0
  box_loss_val = 0
  gt_g_target = 0
  gt_g_proc = 0
  
  
  for step in range(step_start, opts.max_iters):

    loss = 0

    # batch
    images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(data_generator)
    im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=opts.cuda)
    # im_data = torch.from_numpy(images).type(torch.FloatTensor).permute(0, 3, 1, 2).cuda()       # the order of permute(0,3,1,2) and .cuda() matters
    # NOTE(review): timeit.timeit() benchmarks a no-op statement (1,000,000
    # runs by default) and returns elapsed seconds; it is not a timestamp.
    # ``start``/``end`` are unused here.
    start = timeit.timeit()
    try:
      seg_pred, roi_pred, angle_pred, features = net(im_data)
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      continue
    end = timeit.timeit()
    
    # for EAST loss
    smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
    training_mask_var = net_utils.np_to_variable(training_masks, is_cuda=opts.cuda)
    # geo_maps channel 4 is used as the angle target, channels 0-3 as the box geometry target.
    angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4], is_cuda=opts.cuda)
    geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]], is_cuda=opts.cuda)
    
    try:
      loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred, angle_gt, roi_pred, geo_gt)
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      continue
      
    # Per-component loss values read from attributes on ``net`` (presumably set by net.loss — confirm).
    bbox_loss += net.box_loss_value.data.cpu().numpy() 
    seg_loss += net.segm_loss_value.data.cpu().numpy()
    angle_loss += net.angle_loss_value.data.cpu().numpy()  
    train_loss += loss.data.cpu().numpy()
    
       
    try:
      # Before step 10000 training used the annotated text regions.
      # NOTE(review): ``or True`` makes this condition always true.
      if step > 10000 or True: #this is just extra augumentation step ... in early stage just slows down training
    # ctcl, gt_target , gt_proc = process_boxes(images, im_data, seg_pred[0], roi_pred[0], angle_pred[0], score_maps, gt_idxs, gtso, lbso, features, net, ctc_loss, opts, converter, debug=opts.debug)
    # NOTE(review): the lines below are mis-indented relative to the ``if`` above (syntax error as scraped).
    ctcl= process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)
    gt_target = 1
    gt_proc = 1

    ctc_loss_val += ctcl.data.cpu().numpy()[0]
    # NOTE(review): this replaces the detection loss computed above, so only the CTC loss is back-propagated below.
    loss = ctcl
    gt_g_target = gt_target
    gt_g_proc = gt_proc
    train_loss += ctcl.item()
      
      # - When training the OCR recognition part, batches come from a separate data_generator
      # imageso, labels, label_length = next(dg_ocr)          # this should include rectification of skewed text
      # im_data_ocr = net_utils.np_to_variable(imageso, is_cuda=opts.cuda).permute(0, 3, 1, 2)
      # features = net.forward_features(im_data_ocr)
      # labels_pred = net.forward_ocr(features)
      # probs_sizes =  torch.IntTensor( [(labels_pred.permute(2,0,1).size()[0])] * (labels_pred.permute(2,0,1).size()[1]) )
      # label_sizes = torch.IntTensor( torch.from_numpy(np.array(label_length)).int() )
      # labels = torch.IntTensor( torch.from_numpy(np.array(labels)).int() )
      # loss_ocr = ctc_loss(labels_pred.permute(2,0,1), labels, probs_sizes, label_sizes) / im_data_ocr.size(0) * 0.5
      # loss_ocr.backward()
      # ctc_loss_val2 += loss_ocr.item()

      net.zero_grad()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    except:
      # NOTE(review): the handler body is missing in this truncated example.
예제 #11
0
def main(opts):
    """Train the OCR (CTC) branch of ``OctMLT`` on a list of line images.

    Batches come from ``ocr_gen.get_batch`` (RGB, height ``opts.norm_height``).
    Steps whose CTC loss is NaN or inf are skipped entirely: no backward pass
    and no optimizer step, so they cannot corrupt the weights.  Every
    ``disp_interval`` steps the mean loss over the counted steps is printed.
    Every ``batch_per_epoch`` steps a checkpoint is written to
    ``opts.save_path``.

    Args:
        opts: parsed options.  Reads ``cuda``, ``model``, the ``load_*`` and
            ``freeze_*`` flags, ``num_readers``, ``batch_size``,
            ``train_list``, ``norm_height``, ``debug`` and ``save_path``.

    Relies on module-level globals ``base_lr``, ``weight_decay``,
    ``disp_interval``, ``batch_per_epoch`` and (debug only) ``codec``.
    """
    model_name = 'OctGatedMLT'
    net = OctMLT(attention=True)
    acc = []

    if opts.cuda:
        net.cuda()

    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(
            opts.model,
            net,
            optimizer,
            load_ocr=opts.load_ocr,
            load_detection=opts.load_detection,
            load_shared=opts.load_shared,
            load_optimizer=opts.load_optimizer,
            reset_step=opts.load_reset_step)
    else:
        learning_rate = base_lr

    # The step counter always restarts at 0, even after loading a checkpoint
    # (kept from the original script).
    step_start = 0

    net.train()

    if opts.freeze_shared:
        net_utils.freeze_shared(net)

    if opts.freeze_ocr:
        net_utils.freeze_ocr(net)

    if opts.freeze_detection:
        net_utils.freeze_detection(net)

    ctc_loss = CTCLoss()

    data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
                                       batch_size=opts.batch_size,
                                       train_list=opts.train_list,
                                       in_train=True,
                                       norm_height=opts.norm_height,
                                       rgb=True)

    train_loss = 0
    cnt = 0

    for step in range(step_start, 300000):
        images, labels, label_length = next(data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        features = net.forward_features(im_data)
        labels_pred = net.forward_ocr(features)

        # CTC expects activations as (seqLength x batch x outputDim).  Every
        # sequence uses the full output length; lengths/targets are IntTensors.
        acts = labels_pred.permute(2, 0, 1)
        probs_sizes = torch.IntTensor([acts.size(0)] * acts.size(1))
        label_sizes = torch.from_numpy(np.array(label_length)).int()
        targets = torch.from_numpy(np.array(labels)).int()
        loss = ctc_loss(acts, targets, probs_sizes,
                        label_sizes) / im_data.size(0)  # change 1.9.

        optimizer.zero_grad()
        loss_is_finite = not (torch.isnan(loss).any()
                              or torch.isinf(loss).any())
        if loss_is_finite:
            # Back-propagating an inf/NaN CTC loss would write NaN gradients
            # into the weights, so only finite steps update the model.
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            cnt += 1

        if opts.debug:
            ctc_f = labels_pred.data.cpu().numpy().swapaxes(1, 2)
            best_path = ctc_f.argmax(2)
            det_text, conf, dec_s = print_seq_ext(best_path[0, :], codec)

            print('{0} \t'.format(det_text))

        if step % disp_interval == 0:
            # cnt can be 0 when every step in the interval was non-finite.
            train_loss /= max(cnt, 1)
            print('epoch %d[%d], loss: %.3f, lr: %.5f ' %
                  (step / batch_per_epoch, step, train_loss, learning_rate))

            train_loss = 0
            cnt = 0

        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))

            np.savez('train_acc_{0}'.format(model_name), acc=acc)
예제 #12
0
def main(opts):
    """Jointly train E2E-MLT text detection and CRNN recognition.

    Per step:
      1. Forward a detection batch from ``data_gen.get_batch`` and compute the
         detection loss with ``net.loss``.
      2. Add the CTC loss returned by ``process_boxes``.
      3. Back-propagate an extra recognition-only CTC loss (scaled by 0.5) on
         a batch from ``ocr_gen.get_batch``.
      4. Clip gradients to norm 0.5.  Only if every loss is finite, step the
         optimizer and the cyclic LR scheduler.

    Every ``disp_interval`` steps the averaged losses are printed and appended
    to ``<opts.save_path>/loss.txt``.  Every ``batch_per_epoch`` steps a
    checkpoint is saved and ``evaluate_e2e_crnn`` is run on
    ``opts.eval_path``.

    Args:
        opts: parsed options.  Reads ``cuda``, ``base_lr``, ``model``,
            ``num_readers``, ``input_size``, ``batch_size``, ``train_path``,
            ``geo_type``, ``normalize``, ``ocr_batch_size``,
            ``ocr_feed_list``, ``max_iters``, ``debug``, ``d1``,
            ``save_path`` and ``eval_path``.

    Relies on module-level globals ``norm_height``, ``device``,
    ``weight_decay``, ``disp_interval`` and ``batch_per_epoch``.
    """
    import sys
    import traceback

    model_name = 'E2E-MLT'
    net = ModelResNetSep_crnn(
        attention=True,
        multi_scale=True,
        num_classes=400,
        fixed_height=norm_height,
        net='densenet',
    )
    print("Using {0}".format(model_name))
    ctc_loss = nn.CTCLoss()
    if opts.cuda:
        net.to(device)
        ctc_loss.to(device)
    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                                                  base_lr=0.0006,
                                                  max_lr=0.001,
                                                  step_size_up=3000,
                                                  cycle_momentum=False)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net,
                                                       optimizer)
    net_utils.adjust_learning_rate(optimizer, learning_rate)

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_path,
                                        geo_type=opts.geo_type,
                                        normalize=opts.normalize)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True,
                               normalize=opts.normalize)

    # Defined up front: the checkpoint branch appends to it as well, and may
    # run before the first display.
    save_log = os.path.join(opts.save_path, 'loss.txt')

    # Per-display-interval accumulators (reset after every log line).
    train_loss = 0
    train_loss_temp = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 0
    ctc_loss_val = 0
    ctc_loss_val2 = 0
    good_all = 0
    gt_all = 0
    # Per-checkpoint accumulators (reset after every evaluation).
    train_loss_lr = 0
    cntt = 0
    time_total = 0
    now = time.time()

    for step in range(step_start, opts.max_iters):
        images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(
            data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        # ``except Exception`` instead of a bare ``except`` so that Ctrl-C
        # (KeyboardInterrupt) still stops training.
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except Exception:
            traceback.print_exc(file=sys.stdout)
            continue

        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks,
                                                     is_cuda=opts.cuda)
        # geo_maps channel 4 is the angle target; channels 0-3 the geometry.
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4],
                                            is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]],
                                          is_cuda=opts.cuda)

        try:
            loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred,
                            angle_gt, roi_pred, geo_gt)
        except Exception:
            traceback.print_exc(file=sys.stdout)
            continue

        if not (torch.isnan(loss) or torch.isinf(loss)):
            train_loss_temp += loss.data.cpu().numpy()

        optimizer.zero_grad()

        try:
            # CTC loss from process_boxes (the former
            # ``if step > 1000 or True`` guard was always true).
            ctcl, gt_b_good, gt_b_all = process_boxes(images,
                                                      im_data,
                                                      seg_pred[0],
                                                      roi_pred[0],
                                                      angle_pred[0],
                                                      score_maps,
                                                      gt_idxs,
                                                      gtso,
                                                      lbso,
                                                      features,
                                                      net,
                                                      ctc_loss,
                                                      opts,
                                                      debug=opts.debug)
            loss = loss + ctcl
            gt_all += gt_b_all
            good_all += gt_b_good

            # Extra recognition-only batch.
            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso,
                                                   is_cuda=opts.cuda).permute(
                                                       0, 3, 1, 2)
            labels_pred = net.forward_ocr(im_data_ocr)

            probs_sizes = torch.IntTensor([labels_pred.size(0)] *
                                          labels_pred.size(1)).long()
            label_sizes = torch.from_numpy(np.array(label_length)).long()
            labels = torch.from_numpy(np.array(labels)).long()
            loss_ocr = ctc_loss(labels_pred, labels, probs_sizes,
                                label_sizes) / im_data_ocr.size(0) * 0.5

            loss_ocr.backward()
            loss.backward()

            clipping_value = 0.5
            torch.nn.utils.clip_grad_norm_(net.parameters(), clipping_value)
            if opts.d1:
                print('loss_nan', torch.isnan(loss))
                print('loss_inf', torch.isinf(loss))
                print('lossocr_nan', torch.isnan(loss_ocr))
                print('lossocr_inf', torch.isinf(loss_ocr))

            if not (torch.isnan(loss) or torch.isinf(loss)
                    or torch.isnan(loss_ocr) or torch.isinf(loss_ocr)):
                bbox_loss += net.box_loss_value.data.cpu().numpy()
                seg_loss += net.segm_loss_value.data.cpu().numpy()
                angle_loss += net.angle_loss_value.data.cpu().numpy()
                train_loss += train_loss_temp
                ctc_loss_val2 += loss_ocr.item()
                # reshape(-1)[0] works for both 0-d and 1-element tensors.
                # ``numpy()[0]`` raised on 0-d tensors, and that error (caught
                # below) silently skipped the optimizer step.
                ctc_loss_val += float(ctcl.detach().reshape(-1)[0])
                optimizer.step()
                scheduler.step()
                train_loss_temp = 0
                cnt += 1
        except Exception:
            traceback.print_exc(file=sys.stdout)

        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0],
                                      (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                if opts.normalize:
                    x_data += 1
                    x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except Exception:
                    # Best-effort overlay for the debug window only.
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            # cnt is 0 when no step since the last display succeeded.
            denom = max(cnt, 1)
            train_loss /= denom
            bbox_loss /= denom
            seg_loss /= denom
            angle_loss /= denom
            ctc_loss_val /= denom
            ctc_loss_val2 /= denom
            train_loss_lr += train_loss

            cntt += 1
            time_now = time.time() - now
            time_total += time_now
            now = time.time()
            learning_rate = optimizer.param_groups[-1]['lr']

            log_line = (
                'epoch %d[%d], lr: %f, loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f, lv2: %.3f, time: %.2f s, cnt: %d\n'
                % (step / batch_per_epoch, step, learning_rate, train_loss,
                   bbox_loss, seg_loss, angle_loss, ctc_loss_val,
                   good_all / max(1, gt_all), ctc_loss_val2, time_now, cnt))
            with open(save_log, 'a') as f:
                f.write(log_line)
            print(log_line)

            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            ctc_loss_val2 = 0
            good_all = 0
            gt_all = 0

        if step > step_start and (step % batch_per_epoch == 0):
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']
                print('learning_rate', learning_rate)
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)

            re_tpe2e, re_tp, re_e1, precision = evaluate_e2e_crnn(
                root=opts.eval_path,
                net=net,
                norm_height=norm_height,
                name_model=save_name,
                normalize=opts.normalize,
                save_dir=opts.save_path)

            # cntt is 0 if no display happened since the last checkpoint.
            mean_loss = train_loss_lr / max(cntt, 1)
            with open(save_log, 'a') as f:
                f.write(
                    'time epoch [%d]: %.2f s, loss_total: %.3f, lr:%f, re_tpe2e = %f, re_tp = %f, re_e1 = %f, precision = %f\n'
                    % (step / batch_per_epoch, time_total, mean_loss,
                       learning_rate, re_tpe2e, re_tp, re_e1, precision))
            print(
                'time epoch [%d]: %.2f s, loss_total: %.3f, re_tpe2e = %f, re_tp = %f, re_e1 = %f, precision = %f'
                % (step / batch_per_epoch, time_total, mean_loss, re_tpe2e,
                   re_tp, re_e1, precision))
            print('save model: {}'.format(save_name))
            time_total = 0
            cntt = 0
            train_loss_lr = 0
            net.train()
예제 #13
0
파일: demo.py 프로젝트: dipikakhullar/ocr
def run_model_input_image(im, show_boxes=False):
  """Detect and recognize text in one image with the E2E-MLT model.

  The options ``-cuda``, ``-model`` and ``-segm_thresh`` are read from
  ``sys.argv``.  Unknown arguments are ignored, so the function can be called
  from scripts that have their own command line.

  Args:
    im: image accepted by ``np.asarray``.  Only the first three channels are
      used.
    show_boxes: forwarded as ``show`` to ``show_image_with_boxes``.

  Returns:
    dict with key ``"frame"`` mapping to the list of lower-cased recognized
    strings.  Boxes whose recognized text is empty are skipped.
  """
  predictions = {}
  parser = argparse.ArgumentParser()
  parser.add_argument('-cuda', type=int, default=1)
  parser.add_argument('-model', default='e2e-mlt-rctw.h5')
  # type=float: a threshold given on the command line must be numeric.
  parser.add_argument('-segm_thresh', type=float, default=0.5)

  font2 = ImageFont.truetype("Arial-Unicode-Regular.ttf", 18)

  # parse_known_args: don't exit when the caller's argv has other options.
  args, _ = parser.parse_known_args()

  net = ModelResNetSep2(attention=True)
  net_utils.load_net(args.model, net)
  net = net.eval()

  if args.cuda:
    print('Using cuda ...')
    net = net.cuda()

  with torch.no_grad():
    im = np.asarray(im)
    im = im[..., :3]
    im_resized, (ratio_h, ratio_w) = resize_image(im, scale_up=False)
    # np.float64 instead of the ``np.float`` alias removed in NumPy 1.24.
    images = np.asarray([im_resized], dtype=np.float64)
    images /= 128
    images -= 1
    im_data = net_utils.np_to_variable(images, is_cuda=args.cuda).permute(0, 3, 1, 2)
    seg_pred, rboxs, angle_pred, features = net(im_data)

    # Move the box-geometry channel axis last.
    rbox = rboxs[0].data.cpu()[0].numpy()
    rbox = rbox.swapaxes(0, 1)
    rbox = rbox.swapaxes(1, 2)

    angle_pred = angle_pred[0].data.cpu()[0].numpy()

    segm = seg_pred[0].data.cpu()[0].numpy()
    segm = segm.squeeze(0)

    draw2 = np.copy(im_resized)
    boxes = get_boxes(segm, rbox, angle_pred, args.segm_thresh)

    img = Image.fromarray(draw2)
    draw = ImageDraw.Draw(img)

    out_boxes = []
    prediction_i = []
    for box in boxes:
        det_text, conf, dec_s = ocr_image(net, codec, im_data, box)
        if len(det_text) == 0:
            continue

        # Label is drawn at the box's first corner.  (The former unused
        # ``draw.textsize`` call was removed in Pillow 10 and raised there.)
        center = [box[0], box[1]]
        draw.text((center[0], center[1]), det_text, fill=(0, 255, 0), font=font2)
        out_boxes.append(box)

        # det_text is one prediction
        prediction_i.append(det_text.lower())

    predictions["frame"] = prediction_i

    # show each image boxes and output in pop up window.
    show_image_with_boxes(img, out_boxes, show=show_boxes)

  print(predictions)
  return predictions
예제 #14
0
def _find_gt_file(root, base_nam):
    """Locate the ground-truth file for image file name *base_nam* inside *root*.

    Tries ``gt_<stem>.txt`` and then ``gt_<stem without underscores>.txt``.
    Returns ``(path, True)`` for the first candidate that exists, otherwise
    ``(last_candidate, False)`` so the caller can report what was looked for.
    """
    stem = os.path.splitext(base_nam)[0]
    path = None
    for candidate in (stem, stem.replace('_', '')):
        path = '{0}/gt_{1}.txt'.format(root, candidate)
        if os.path.exists(path):
            return path, True
    return path, False


def _box_geometry(boxr):
    """Return ``(center, width, height, angle)`` of a 4-corner box given as a (4, 2) array.

    Width and height are the mean lengths of two opposite edges (height gets a
    +1 pixel margin); the angle (radians, via atan2) averages the directions of
    edges 1->2 and 0->3.
    """
    center = (boxr[0, :] + boxr[1, :] + boxr[2, :] + boxr[3, :]) / 4

    dw = boxr[2, :] - boxr[1, :]
    dw2 = boxr[0, :] - boxr[3, :]
    dh = boxr[1, :] - boxr[0, :]
    dh2 = boxr[3, :] - boxr[2, :]

    h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1]) + 1
    h2 = math.sqrt(dh2[0] * dh2[0] + dh2[1] * dh2[1]) + 1
    h = (h + h2) / 2
    w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
    w2 = math.sqrt(dw2[0] * dw2[0] + dw2[1] * dw2[1])
    w = (w + w2) / 2

    angle = math.atan2((boxr[2][1] - boxr[1][1]), boxr[2][0] - boxr[1][0])
    angle2 = math.atan2((boxr[3][1] - boxr[0][1]), boxr[3][0] - boxr[0][0])
    angle = (angle + angle2) / 2
    return center, w, h, angle


def evaluate_e2e_crnn(root,
                      net,
                      norm_height=48,
                      name_model='E2E',
                      normalize=False,
                      save=False,
                      cuda=True,
                      save_dir='eval'):
    """Evaluate an end-to-end (detection + CTC recognition) model on a folder of images.

    Every ``*.jpg`` / ``*.png`` / ``*.JPG`` file in *root* is passed through
    *net*. Detected boxes are cropped with ``F.affine_grid``/``F.grid_sample``,
    recognised with ``net.forward_ocr``, and scored by ``evaluate_image``
    against ``gt_<stem>.txt`` (or ``gt_<stem without underscores>.txt``).
    Images with no ground-truth file are scored against an empty GT set.
    Unreadable images are skipped. A summary line is appended to
    ``<save_dir>/note_eval.txt``.

    Relies on module-level names defined elsewhere: ``resize_image``,
    ``net_utils``, ``get_boxes``, ``load_gt``, ``evaluate_image``,
    ``print_seq_ext``, ``codec`` and ``cv2``.

    Args:
        root: directory holding the images and their ``gt_*.txt`` files.
        net: model exposing ``eval()``, ``__call__`` and ``forward_ocr``.
        norm_height: height (pixels) every text crop is resampled to for OCR.
        name_model: label written into the summary line.
        normalize: if True, pixels are mapped with ``x / 128 - 1``.
        save: if True, the image returned by ``evaluate_image`` is written
            to *save_dir* under the original file name.
        cuda: forwarded to ``net_utils.np_to_variable(is_cuda=...)``. Crop
            transforms follow the input tensor's device.
        save_dir: output directory, created (with parents) if missing.

    Returns:
        ``(tp_e2e / gt_e2e, tp / gt_e2e, tp_e2e_ed1 / gt_e2e, tp / n_detections)``,
        summed over all images. Each denominator is clamped to at least 1.
    """
    net = net.eval()

    # De-duplicate: on case-insensitive globbing '*.jpg' and '*.JPG' can match
    # the same files, which would double-count them.
    image_paths = []
    for pattern in ('*.jpg', '*.png', '*.JPG'):
        image_paths.extend(glob.glob(os.path.join(root, pattern)))
    image_paths = sorted(set(image_paths))

    tp_all = 0
    tp_e2e_all = 0
    gt_e2e_all = 0
    tp_e2e_ed1_all = 0
    detections_all = 0
    eval_text_length = 2
    segm_thresh = 0.5
    min_height = 8

    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for img_name in image_paths:
            base_nam = os.path.basename(img_name)

            gt_path, gt_found = _find_gt_file(root, base_nam)
            if gt_found:
                gt_rect, gt_txts = load_gt(gt_path)
            else:
                # No ground truth: score against an empty GT set; detections
                # on this image still enter the precision denominator.
                print('missing! {0}'.format(gt_path))
                gt_rect, gt_txts = [], []

            img = cv2.imread(img_name)
            if img is None:
                print('unreadable image! {0}'.format(img_name))
                continue

            im_resized, _ = resize_image(
                img, max_size=1848 * 1024,
                scale_up=False)  # 1348*1024 #1848*1024
            images = np.asarray([im_resized], dtype=np.float64)

            if normalize:
                images /= 128
                images -= 1
            im_data = net_utils.np_to_variable(images, is_cuda=cuda).permute(
                0, 3, 1, 2)
            input_H = im_data.size(2)
            input_W = im_data.size(3)

            (iou_pred, _), rboxs, angle_pred, _ = net(im_data)
            iou = iou_pred.data.cpu()[0].numpy().squeeze(0)

            rbox = rboxs[0].data.cpu()[0].numpy()
            rbox = rbox.swapaxes(0, 1).swapaxes(1, 2)

            detections = get_boxes(iou, rbox,
                                   angle_pred[0].data.cpu()[0].numpy(),
                                   segm_thresh)

            # Ratio between the resized network input and the original image.
            im_scalex = im_resized.shape[1] / img.shape[1]
            im_scaley = im_resized.shape[0] / img.shape[0]

            detections_out = []
            for det in np.copy(detections):
                # Box corners in resized-image pixels, one (x, y) row per point.
                boxr = det[0:8].reshape(-1, 2)
                center, w, h, angle = _box_geometry(boxr)

                # Drop boxes whose height in the original image is too small.
                if ((h - 1) / im_scaley) < min_height:
                    continue

                target_h = norm_height
                scale = target_h / h
                target_gw = int(w * scale + target_h / 4)
                target_gw = max(8, int(round(target_gw / 8)) * 8)
                xc = center[0]
                yc = center[1]

                # Affine matrix for affine_grid: rotation/scale of the box and
                # its center mapped from pixel to [-1, 1] coordinates.
                scalex = (w + h / 4) / input_W
                scaley = h / input_H

                th11 = scalex * math.cos(angle)
                th12 = -math.sin(angle) * scaley * input_H / input_W
                th13 = (2 * xc - input_W - 1) / (input_W - 1)

                th21 = math.sin(angle) * scalex * input_W / input_H
                th22 = scaley * math.cos(angle)
                th23 = (2 * yc - input_H - 1) / (input_H - 1)

                t = np.asarray([th11, th12, th13, th21, th22, th23],
                               dtype=np.float64)
                # Place theta on the same device as the image it samples.
                theta = torch.from_numpy(t).type(torch.FloatTensor).to(
                    im_data.device).view(-1, 2, 3)

                grid = F.affine_grid(
                    theta, torch.Size((1, 3, int(target_h), int(target_gw))))
                x = F.grid_sample(im_data, grid)

                labels_pred = net.forward_ocr(x)
                labels_pred = labels_pred.permute(1, 2, 0)

                ctc_f = labels_pred.data.cpu().numpy()
                ctc_f = ctc_f.swapaxes(1, 2)
                labels = ctc_f.argmax(2)

                # Mean exp(best score) over frames whose best class id is > 3.
                # NOTE(review): assumes log-probability outputs and that ids <= 3
                # are blank/special symbols — confirm against codec.
                conf = np.mean(np.exp(ctc_f.max(2)[labels > 3]))
                if conf < 0.02:
                    continue

                det_text, _, _, _ = print_seq_ext(labels[0, :], codec)
                dtxt = det_text.strip()
                if dtxt and len(dtxt) >= eval_text_length:
                    # Map the corners back to original-image coordinates.
                    boxw = np.copy(boxr)
                    boxw[:, 1] /= im_scaley
                    boxw[:, 0] /= im_scalex
                    detections_out.append([boxw.reshape(8), dtxt])

            tp, tp_e2e, gt_e2e, tp_e2e_ed1, _, pixx = evaluate_image(
                img,
                detections_out,
                gt_rect,
                gt_txts,
                eval_text_length=eval_text_length)
            tp_all += tp
            tp_e2e_all += tp_e2e
            gt_e2e_all += gt_e2e
            tp_e2e_ed1_all += tp_e2e_ed1
            detections_all += len(detections_out)
            if save:
                cv2.imwrite('{0}/{1}'.format(save_dir, base_nam), pixx)

    gt_denom = float(max(1, gt_e2e_all))
    results = (tp_e2e_all / gt_denom,
               tp_all / gt_denom,
               tp_e2e_ed1_all / gt_denom,
               tp_all / float(max(1, detections_all)))

    note_path = os.path.join(save_dir, 'note_eval.txt')
    with open(note_path, 'a') as note_file:
        note_file.write(
            'Model{4}---E2E recall tp_e2e:{0:.3f} / tp:{1:.3f} / e1:{2:.3f}, precision: {3:.3f} \n'
            .format(results[0], results[1], results[2], results[3], name_model))
    return results
예제 #15
0
def main(opts):
  
  model_name = 'E2E-MLT'
  net = ModelResNetSep_final(attention=True)
  acc = []
  ctc_loss = nn.CTCLoss()
  if opts.cuda:
    net.cuda()
    ctc_loss.cuda()
  optimizer = torch.optim.Adam(net.parameters(), lr=base_lr, weight_decay=weight_decay)
  scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0005, max_lr=0.001, step_size_up=3000,
                                                cycle_momentum=False)
  step_start = 0  
  if os.path.exists(opts.model):
    print('loading model from %s' % args.model)
    step_start, learning_rate = net_utils.load_net(args.model, net, optimizer)
  else:
    learning_rate = base_lr

  for param_group in optimizer.param_groups:
    param_group['lr'] = base_lr
    learning_rate = param_group['lr']
    print(param_group['lr'])
  
  step_start = 0  

  net.train()
  
  #acc_test = test(net, codec, opts, list_file=opts.valid_list, norm_height=opts.norm_height)
  #acc.append([0, acc_test])
    
  # ctc_loss = CTCLoss()
  ctc_loss = nn.CTCLoss()

  data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
          batch_size=opts.batch_size, 
          train_list=opts.train_list, in_train=True, norm_height=opts.norm_height, rgb = True, normalize= True)
  
  val_dataset = ocrDataset(root=opts.valid_list, norm_height=opts.norm_height , in_train=False,is_crnn=False)
  val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False,
                                                collate_fn=alignCollate())


  # val_generator1 = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=False,
  #                                              collate_fn=alignCollate())

  cnt = 1
  cntt = 0
  train_loss_lr = 0
  time_total = 0
  train_loss = 0
  now = time.time()

  for step in range(step_start, 300000):
    # batch
    images, labels, label_length = next(data_generator)
    im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(0, 3, 1, 2)
    features = net.forward_features(im_data)
    labels_pred = net.forward_ocr(features)
    
    # backward
    '''
    acts: Tensor of (seqLength x batch x outputDim) containing output from network
        labels: 1 dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing size of each output sequence from the network
        act_lens: Tensor of (batch) containing label length of each example
    '''
    
    probs_sizes =  torch.IntTensor([(labels_pred.permute(2, 0, 1).size()[0])] * (labels_pred.permute(2, 0, 1).size()[1])).long()
    label_sizes = torch.IntTensor(torch.from_numpy(np.array(label_length)).int()).long()
    labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int()).long()
    loss = ctc_loss(labels_pred.permute(2,0,1), labels, probs_sizes, label_sizes) / im_data.size(0) # change 1.9.
    optimizer.zero_grad()
    loss.backward()

    clipping_value = 0.5
    torch.nn.utils.clip_grad_norm_(net.parameters(),clipping_value)
    if not (torch.isnan(loss) or torch.isinf(loss)):
      optimizer.step()
      scheduler.step()
    # if not np.isinf(loss.data.cpu().numpy()):
      train_loss += loss.data.cpu().numpy() #net.bbox_loss.data.cpu().numpy()[0]
      # train_loss += loss.data.cpu().numpy()[0] #net.bbox_loss.data.cpu().numpy()[0]
      cnt += 1
    
    if opts.debug:
      dbg = labels_pred.data.cpu().numpy()
      ctc_f = dbg.swapaxes(1, 2)
      labels = ctc_f.argmax(2)
      det_text, conf, dec_s,_ = print_seq_ext(labels[0, :], codec)
      
      print('{0} \t'.format(det_text))
    
    
    
    if step % disp_interval == 0:
      for param_group in optimizer.param_groups:
        learning_rate = param_group['lr']
        
      train_loss /= cnt
      train_loss_lr += train_loss
      cntt += 1
      time_now = time.time() - now
      time_total += time_now
      now = time.time()
      save_log = os.path.join(opts.save_path, 'loss_ocr.txt')
      f = open(save_log, 'a')
      f.write(
        'epoch %d[%d], loss_ctc: %.3f,time: %.2f s, lr: %.5f, cnt: %d\n' % (
          step / batch_per_epoch, step, train_loss, time_now,learning_rate, cnt))
      f.close()

      print('epoch %d[%d], loss_ctc: %.3f,time: %.2f s, lr: %.5f, cnt: %d\n' % (
          step / batch_per_epoch, step, train_loss, time_now,learning_rate, cnt))

      train_loss = 0
      cnt = 1

    if step > step_start and (step % batch_per_epoch == 0):
      CER, WER = eval_ocr(val_generator, net)
      net.train()
      for param_group in optimizer.param_groups:
        learning_rate = param_group['lr']
        # print(learning_rate)

      save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
      state = {'step': step,
               'learning_rate': learning_rate,
              'state_dict': net.state_dict(),
              'optimizer': optimizer.state_dict()}
      torch.save(state, save_name)
      print('save model: {}'.format(save_name))
      save_logg = os.path.join(opts.save_path, 'note_eval.txt')
      fe = open(save_logg, 'a')
      fe.write('time epoch [%d]: %.2f s, loss_total: %.3f, CER = %f, WER = %f\n' % (
      step / batch_per_epoch, time_total, train_loss_lr / cntt, CER, WER))
      fe.close()
      print('time epoch [%d]: %.2f s, loss_total: %.3f, CER = %f, WER = %f' % (
      step / batch_per_epoch, time_total, train_loss_lr / cntt, CER, WER))
      time_total = 0
      cntt = 0
      train_loss_lr = 0