def process_boxes(images,
                  im_data,
                  iou_pred,
                  roi_pred,
                  angle_pred,
                  score_maps,
                  gt_idxs,
                  gtso,
                  lbso,
                  features,
                  net,
                  ctc_loss,
                  opts,
                  debug=False):
    """Crop text regions and accumulate CTC loss for the OCR branch.

    Two passes are made per batch element:

    1. *Predicted* boxes — pixels where the predicted IoU map agrees with
       the ground-truth score map are sampled, a rotated rectangle is
       reconstructed from the geometry head, and crops whose box matches
       the ground truth well (IoU >= 0.9) are fed to the OCR head.  This
       pass is skipped entirely when ``opts.geo_type == 1``.
    2. *Ground-truth* boxes — every readable GT polygon is cropped (with
       small random jitter as augmentation) and fed to the OCR head.

    Args:
        images: batch of images as numpy arrays (debug drawing only).
        im_data: network input tensor of shape (B, C, H, W), on GPU.
        iou_pred: predicted text score map per image.
        roi_pred: predicted box-edge distances per pixel.
        angle_pred: predicted (sin, cos) text angle per pixel.
        score_maps: ground-truth score maps.
        gt_idxs: per-pixel index of the GT box covering that pixel.
        gtso: per-image lists of GT quadrilaterals (4x2 arrays).
        lbso: per-image lists of transcriptions; a leading '##' marks
            unreadable text which is skipped.
        features: unused here (features are recomputed per crop); kept
            for call-site compatibility.
        net: model exposing ``forward_features`` and ``forward_ocr``.
        ctc_loss: CTC loss callable.
        opts: options namespace (``geo_type`` is read here).
        debug: show crops in OpenCV windows and print decodings.

    Returns:
        tuple ``(loss, gt_good, gt_proc)``: mean CTC loss over sampled
        crops, number of GT crops decoded exactly, and number of GT
        crops processed.  NOTE(review): the two counters are reset per
        batch element, so they refer to the LAST element only — confirm
        whether that is intended.
    """
    ctc_loss_count = 0
    loss = torch.from_numpy(np.asarray([0])).type(torch.FloatTensor).cuda()

    for bid in range(iou_pred.size(0)):

        gts = gtso[bid]
        lbs = lbso[bid]

        gt_proc = 0
        gt_good = 0

        # At most 3 sampled crops per GT box (see the > 2 check below).
        gts_count = {}

        iou_pred_np = iou_pred[bid].data.cpu().numpy()
        iou_map = score_maps[bid]
        # Walk only pixels that are text in the GT map AND confidently
        # predicted as text (score > 0.5).
        to_walk = iou_pred_np.squeeze(0) * iou_map * (iou_pred_np.squeeze(0) >
                                                      0.5)

        roi_p_bid = roi_pred[bid].data.cpu().numpy()
        gt_idx = gt_idxs[bid]

        if debug:
            img = images[bid]
            img += 1
            img *= 128
            img = np.asarray(img, dtype=np.uint8)

        xy_text = np.argwhere(to_walk > 0)
        # np.random.shuffle, NOT random.shuffle: the stdlib shuffle swaps
        # rows of a 2-D ndarray through views and silently duplicates /
        # corrupts rows.
        np.random.shuffle(xy_text)
        xy_text = xy_text[0:min(xy_text.shape[0], 100)]

        for i in range(0, xy_text.shape[0]):
            if opts.geo_type == 1:
                break
            pos = xy_text[i, :]

            gt_id = gt_idx[pos[0], pos[1]]

            if gt_id not in gts_count:
                gts_count[gt_id] = 0

            if gts_count[gt_id] > 2:
                continue

            gt = gts[gt_id]
            gt_txt = lbs[gt_id]
            if gt_txt.startswith('##'):
                # '##' marks unreadable ground truth.
                continue

            angle_sin = angle_pred[bid, 0, pos[0], pos[1]]
            angle_cos = angle_pred[bid, 1, pos[0], pos[1]]

            angle = math.atan2(angle_sin, angle_cos)

            # GT angle: average slope of the two horizontal box edges.
            angle_gt = (math.atan2(
                (gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2(
                    (gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0])) / 2

            if math.fabs(angle_gt - angle) > math.pi / 16:
                continue

            # Reconstruct the rotated box corners from the 4 predicted
            # edge distances; the feature-map stride is 4, hence the *4.
            offset = roi_p_bid[:, pos[0], pos[1]]
            posp = pos + 0.25
            pos_g = np.array([(posp[1] - offset[0] * math.sin(angle)) * 4,
                              (posp[0] - offset[0] * math.cos(angle)) * 4])
            pos_g2 = np.array([(posp[1] + offset[1] * math.sin(angle)) * 4,
                               (posp[0] + offset[1] * math.cos(angle)) * 4])

            pos_r = np.array([(posp[1] - offset[2] * math.cos(angle)) * 4,
                              (posp[0] - offset[2] * math.sin(angle)) * 4])
            pos_r2 = np.array([(posp[1] + offset[3] * math.cos(angle)) * 4,
                               (posp[0] + offset[3] * math.sin(angle)) * 4])

            # NOTE(review): sum/2 minus the pixel position, not the
            # plain /4 centroid (see the commented alternative in the
            # original) — confirm this is the intended center.
            center = (pos_g + pos_g2 + pos_r + pos_r2) / 2 - [
                4 * pos[1], 4 * pos[0]
            ]
            dw = pos_r - pos_r2
            dh = pos_g - pos_g2

            w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
            h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])

            dhgt = gt[1] - gt[0]

            h_gt = math.sqrt(dhgt[0] * dhgt[0] + dhgt[1] * dhgt[1])
            if h_gt < 10:
                # Text too small to read reliably.
                continue

            rect = ((center[0], center[1]), (w, h), angle * 180 / math.pi)
            pts = cv2.boxPoints(rect)

            # Axis-aligned bbox as [x1, y1, x2, y2].
            pred_bbox = cv2.boundingRect(pts)
            pred_bbox = [
                pred_bbox[0], pred_bbox[1], pred_bbox[2], pred_bbox[3]
            ]
            pred_bbox[2] += pred_bbox[0]
            pred_bbox[3] += pred_bbox[1]

            # NOTE(review): both coordinates are compared against the
            # width (size(3)); the y-check likely should use size(2) —
            # confirm before changing, kept as-is.
            if gt[:,
                  0].max() > im_data.size(3) or gt[:,
                                                   1].max() > im_data.size(3):
                continue

            gt_bbox = [
                gt[:, 0].min(), gt[:, 1].min(), gt[:, 0].max(), gt[:, 1].max()
            ]
            inter = intersect(pred_bbox, gt_bbox)

            uni = union(pred_bbox, gt_bbox)
            ratio = area(inter) / float(area(uni))

            # Train OCR only on well-localized predictions.
            if ratio < 0.90:
                continue

            hratio = min(h, h_gt) / max(h, h_gt)
            if hratio < 0.5:
                continue

            input_W = im_data.size(3)
            input_H = im_data.size(2)
            target_h = norm_height

            scale = target_h / h
            target_gw = (int(w * scale) + target_h)
            target_gw = max(8, int(round(target_gw / 4)) * 4)

            # Build the 2x3 affine matrix that maps the rotated crop
            # onto a (target_h x target_gw) canvas for grid_sample.
            scalex = (w + h) / input_W
            scaley = h / input_H

            th11 = scalex * math.cos(angle)
            th12 = -math.sin(angle) * scaley
            th13 = (2 * center[0] - input_W - 1) / (
                input_W - 1
            )

            th21 = math.sin(angle) * scalex
            th22 = scaley * math.cos(angle)
            th23 = (2 * center[1] - input_H - 1) / (
                input_H - 1
            )

            # np.float32: np.float was removed in NumPy 1.24; the array
            # is converted to FloatTensor immediately anyway.
            t = np.asarray([th11, th12, th13, th21, th22, th23],
                           dtype=np.float32)
            t = torch.from_numpy(t).type(torch.FloatTensor).cuda()

            theta = t.view(-1, 2, 3)

            grid = F.affine_grid(
                theta, torch.Size((1, 3, int(target_h), int(target_gw))))

            x = F.grid_sample(im_data[bid].unsqueeze(0), grid)

            if debug:
                x_c = x.data.cpu().numpy()[0]
                # CHW -> HWC for OpenCV display.
                x_data_draw = x_c.swapaxes(0, 2)
                x_data_draw = x_data_draw.swapaxes(0, 1)

                x_data_draw += 1
                x_data_draw *= 128
                x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
                x_data_draw = x_data_draw[:, :, ::-1]

                cv2.circle(img, (int(center[0]), int(center[1])), 5,
                           (0, 255, 0))
                cv2.imshow('im_data', x_data_draw)

                draw_box_points(img, pts)
                draw_box_points(img, gt, color=(0, 0, 255))

                cv2.imshow('img', img)
                cv2.waitKey(100)

            # Encode the transcription, padded with a blank (space) on
            # both ends.
            gt_labels = []
            gt_labels.append(codec_rev[' '])
            for k in range(len(gt_txt)):
                if gt_txt[k] in codec_rev:
                    gt_labels.append(codec_rev[gt_txt[k]])
                else:
                    print('Unknown char: {0}'.format(gt_txt[k]))
                    gt_labels.append(3)
            gt_labels.append(codec_rev[' '])

            # Reverse AFTER padding both ends so RTL text keeps a space
            # on each side — this matches the ground-truth pass below
            # (the original reversed before the trailing pad here).
            if 'ARABIC' in ud.name(gt_txt[0]):
                gt_labels = gt_labels[::-1]

            features = net.forward_features(x)
            labels_pred = net.forward_ocr(features)

            # (T, N, C) layout for CTC; one length per batch column.
            label_length = []
            label_length.append(len(gt_labels))
            probs_sizes = torch.IntTensor(
                [(labels_pred.permute(2, 0, 1).size()[0])] *
                (labels_pred.permute(2, 0, 1).size()[1]))
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int())
            labels = torch.IntTensor(
                torch.from_numpy(np.array(gt_labels)).int())

            loss = loss + ctc_loss(labels_pred.permute(2, 0, 1), labels,
                                   probs_sizes, label_sizes).cuda()
            ctc_loss_count += 1

            if debug:
                ctc_f = labels_pred.data.cpu().numpy()
                ctc_f = ctc_f.swapaxes(1, 2)

                labels = ctc_f.argmax(2)
                det_text, conf, dec_s, splits = print_seq_ext(
                    labels[0, :], codec)

                print('{0} \t {1}'.format(det_text, gt_txt))

            gts_count[gt_id] += 1

            if ctc_loss_count > 64 or debug:
                break

        # ---- second pass: crops taken directly from the GT polygons ----
        for gt_id in range(0, len(gts)):

            gt = gts[gt_id]
            gt_txt = lbs[gt_id]

            gt_txt_low = gt_txt.lower()
            if gt_txt.startswith('##'):
                continue

            # NOTE(review): same width-only bound check as above.
            if gt[:,
                  0].max() > im_data.size(3) or gt[:,
                                                   1].max() > im_data.size(3):
                continue

            if gt.min() < 0:
                continue

            center = (gt[0, :] + gt[1, :] + gt[2, :] + gt[3, :]) / 4
            dw = gt[2, :] - gt[1, :]
            dh = gt[1, :] - gt[0, :]

            w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
            # Random +-2px height jitter as augmentation.
            h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1]) + random.randint(
                -2, 2)

            if h < 8:
                continue

            angle_gt = (math.atan2(
                (gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2(
                    (gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0])) / 2

            input_W = im_data.size(3)
            input_H = im_data.size(2)
            target_h = norm_height

            scale = target_h / h
            # Random extra width as augmentation.
            target_gw = int(w * scale) + random.randint(0, int(target_h))
            target_gw = max(8, int(round(target_gw / 4)) * 4)

            xc = center[0]
            yc = center[1]
            w2 = w
            h2 = h

            # Affine matrix for the GT crop (randomly widened).
            scalex = (w2 + random.randint(0, int(h2))) / input_W
            scaley = h2 / input_H

            th11 = scalex * math.cos(angle_gt)
            th12 = -math.sin(angle_gt) * scaley
            th13 = (2 * xc - input_W - 1) / (
                input_W - 1
            )

            th21 = math.sin(angle_gt) * scalex
            th22 = scaley * math.cos(angle_gt)
            th23 = (2 * yc - input_H - 1) / (
                input_H - 1
            )

            # np.float32: np.float was removed in NumPy 1.24.
            t = np.asarray([th11, th12, th13, th21, th22, th23],
                           dtype=np.float32)
            t = torch.from_numpy(t).type(torch.FloatTensor)
            t = t.cuda()
            theta = t.view(-1, 2, 3)

            grid = F.affine_grid(
                theta, torch.Size((1, 3, int(target_h), int(target_gw))))
            x = F.grid_sample(im_data[bid].unsqueeze(0), grid)

            gt_labels = []
            gt_labels.append(codec_rev[' '])
            for k in range(len(gt_txt)):
                if gt_txt[k] in codec_rev:
                    gt_labels.append(codec_rev[gt_txt[k]])
                else:
                    print('Unknown char: {0}'.format(gt_txt[k]))
                    gt_labels.append(3)
            gt_labels.append(codec_rev[' '])

            # NOTE(review): ud.name raises ValueError for unnamed code
            # points — assumes transcriptions never start with one.
            if 'ARABIC' in ud.name(gt_txt[0]):
                gt_labels = gt_labels[::-1]

            features = net.forward_features(x)
            labels_pred = net.forward_ocr(features)

            label_length = []
            label_length.append(len(gt_labels))
            probs_sizes = torch.IntTensor(
                [(labels_pred.permute(2, 0, 1).size()[0])] *
                (labels_pred.permute(2, 0, 1).size()[1]))
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int())
            labels = torch.IntTensor(
                torch.from_numpy(np.array(gt_labels)).int())

            loss = loss + ctc_loss(labels_pred.permute(2, 0, 1), labels,
                                   probs_sizes, label_sizes).cuda()
            ctc_loss_count += 1

            if debug:
                x_d = x.data.cpu().numpy()[0]
                x_data_draw = x_d.swapaxes(0, 2)
                x_data_draw = x_data_draw.swapaxes(0, 1)

                x_data_draw += 1
                x_data_draw *= 128
                x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
                x_data_draw = x_data_draw[:, :, ::-1]
                cv2.imshow('im_data_gt', x_data_draw)
                cv2.waitKey(100)

            gt_proc += 1
            # Greedy-decode and count exact (case-insensitive) matches.
            ctc_f = labels_pred.data.cpu().numpy()
            ctc_f = ctc_f.swapaxes(1, 2)

            labels = ctc_f.argmax(2)
            det_text, conf, dec_s, splits = print_seq_ext(
                labels[0, :], codec)
            if debug:
                print('{0} \t {1}'.format(det_text, gt_txt))
            if det_text.lower() == gt_txt.lower():
                gt_good += 1

            if ctc_loss_count > 128 or debug:
                break

    if ctc_loss_count > 0:
        loss /= ctc_loss_count

    return loss, gt_good, gt_proc
# Beispiel #2
# 0
def main(opts):
    """Train the CRNN OCR model on word crops with CTC loss.

    Builds the model, optimizer and cyclic LR scheduler, optionally
    restores a checkpoint from ``opts.model``, then runs the training
    loop with gradient clipping, periodic loss logging, and per-epoch
    checkpointing + CER/WER evaluation on the validation set.
    """
    train_loss = 0
    train_loss_lr = 0
    cnt = 1
    cntt = 0
    time_total = 0
    now = time.time()
    converter = strLabelConverter(codec)

    model_name = 'E2E-MLT'
    net = ModelResNetSep_crnn(
        attention=True,
        multi_scale=True,
        num_classes=400,
        fixed_height=opts.norm_height,
        net='densenet',
    )
    ctc_loss = nn.CTCLoss()
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=base_lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                                                  base_lr=0.00007,
                                                  max_lr=0.0003,
                                                  step_size_up=3000,
                                                  cycle_momentum=False)
    step_start = 0
    if opts.cuda:
        net.to(device)
        ctc_loss.to(device)
    if os.path.exists(opts.model):
        # Fixed: previously read `args.model`, but `args` is undefined
        # in this scope — the parameter is `opts`.
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net,
                                                       optimizer)
    else:
        learning_rate = base_lr

    # Force the configured base LR regardless of the checkpoint value.
    for param_group in optimizer.param_groups:
        param_group['lr'] = base_lr
        learning_rate = param_group['lr']
        print(param_group['lr'])

    # NOTE(review): deliberately discards the restored step counter.
    step_start = 0

    net.train()

    data_dataset = ocrDataset(root=opts.train_list,
                              norm_height=opts.norm_height,
                              in_train=True)
    data_generator1 = torch.utils.data.DataLoader(data_dataset,
                                                  batch_size=opts.batch_size,
                                                  shuffle=True,
                                                  collate_fn=alignCollate())
    val_dataset = ocrDataset(root=opts.valid_list,
                             norm_height=opts.norm_height,
                             in_train=False)
    val_generator1 = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 collate_fn=alignCollate())

    # Defined once up front so the epoch-boundary write below can never
    # hit an undefined name (it used to be created only inside the
    # display-interval branch).
    save_log = os.path.join(opts.save_path, 'loss_ocr.txt')

    for step in range(step_start, 300000):
        try:
            images, label = next(dataloader_iterator)
        except (NameError, StopIteration):
            # First iteration (iterator not created yet) or epoch end:
            # (re)start the loader.  Was a bare `except:` which would
            # also have swallowed unrelated errors.
            dataloader_iterator = iter(data_generator1)
            images, label = next(dataloader_iterator)
        labels, label_length = converter.encode(label)
        im_data = images.to(device)
        labels_pred = net.forward_ocr(im_data)

        # CTC expects (T, N, C); one sequence length per batch column.
        probs_sizes = torch.IntTensor([(labels_pred.size()[0])] *
                                      (labels_pred.size()[1]))
        label_sizes = torch.IntTensor(
            torch.from_numpy(np.array(label_length)).int())
        labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
        loss = ctc_loss(labels_pred, labels, probs_sizes,
                        label_sizes) / im_data.size(0)
        optimizer.zero_grad()
        loss.backward()

        clipping_value = 1.0
        torch.nn.utils.clip_grad_norm_(net.parameters(), clipping_value)
        # Skip the optimizer step entirely on NaN/inf losses.
        if not (torch.isnan(loss) or torch.isinf(loss)):
            optimizer.step()
            scheduler.step()
            train_loss += loss.data.cpu().numpy()
            cnt += 1

        if opts.debug:
            dbg = labels_pred.permute(1, 2, 0).data.cpu().numpy()
            ctc_f = dbg.swapaxes(1, 2)
            labels = ctc_f.argmax(2)
            det_text, conf, dec_s, _ = print_seq_ext(labels[0, :], codec)

            print('{0} \t'.format(det_text))

        if step % disp_interval == 0:
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']
            train_loss /= cnt
            train_loss_lr += train_loss
            cntt += 1
            time_now = time.time() - now
            time_total += time_now
            now = time.time()
            with open(save_log, 'a') as f:
                f.write(
                    'epoch %d[%d], loss_ctc: %.3f,time: %.2f s, lr: %.5f, cnt: %d\n'
                    % (step / batch_per_epoch, step, train_loss, time_now,
                       learning_rate, cnt))

            print(
                'epoch %d[%d], loss_ctc: %.3f,time: %.2f s, lr: %.5f, cnt: %d\n'
                % (step / batch_per_epoch, step, train_loss, time_now,
                   learning_rate, cnt))

            train_loss = 0
            cnt = 1

        if step > step_start and (step % batch_per_epoch == 0):

            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']

            save_name = os.path.join(opts.save_path,
                                     'OCR_{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            # Evaluate on the validation set once per epoch.
            CER, WER = eval_ocr_crnn(val_generator1, net)
            with open(save_log, 'a') as f:
                f.write(
                    'time epoch [%d]: %.2f s, loss_total: %.3f, CER = %f, WER = %f'
                    % (step / batch_per_epoch, time_total,
                       train_loss_lr / cntt, CER, WER))
            print(
                'time epoch [%d]: %.2f s, loss_total: %.3f, CER = %f, WER = %f \n'
                % (step / batch_per_epoch, time_total, train_loss_lr / cntt,
                   CER, WER))
            print('save model: {}'.format(save_name))
            # eval switches to eval mode; switch back for training.
            net.train()
            time_total = 0
            cntt = 0
            train_loss_lr = 0
def main(opts):
    """Train the OctGatedMLT OCR branch with CTC loss.

    Optionally restores a checkpoint (with selective loading/freezing of
    the shared, OCR and detection sub-networks), then runs the training
    loop with periodic loss logging and per-epoch checkpointing.
    """
    model_name = 'OctGatedMLT'
    net = OctMLT(attention=True)
    acc = []

    if opts.cuda:
        net.cuda()

    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        # Fixed: previously read `args.model`, but `args` is undefined
        # in this scope — the parameter is `opts`.
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(
            opts.model,
            net,
            optimizer,
            load_ocr=opts.load_ocr,
            load_detection=opts.load_detection,
            load_shared=opts.load_shared,
            load_optimizer=opts.load_optimizer,
            reset_step=opts.load_reset_step)
    else:
        learning_rate = base_lr

    # NOTE(review): deliberately discards the restored step counter.
    step_start = 0

    net.train()

    if opts.freeze_shared:
        net_utils.freeze_shared(net)

    if opts.freeze_ocr:
        net_utils.freeze_ocr(net)

    if opts.freeze_detection:
        net_utils.freeze_detection(net)

    ctc_loss = CTCLoss()

    data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
                                       batch_size=opts.batch_size,
                                       train_list=opts.train_list,
                                       in_train=True,
                                       norm_height=opts.norm_height,
                                       rgb=True)

    train_loss = 0
    cnt = 0

    for step in range(step_start, 300000):
        images, labels, label_length = next(data_generator)
        # NHWC numpy batch -> NCHW tensor.
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        features = net.forward_features(im_data)
        labels_pred = net.forward_ocr(features)

        # CTC expects (T, N, C); one sequence length per batch column,
        # labels flattened into a single 1-D target sequence.
        probs_sizes = torch.IntTensor(
            [(labels_pred.permute(2, 0, 1).size()[0])] *
            (labels_pred.permute(2, 0, 1).size()[1]))
        label_sizes = torch.IntTensor(
            torch.from_numpy(np.array(label_length)).int())
        labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
        loss = ctc_loss(labels_pred.permute(2, 0, 1), labels, probs_sizes,
                        label_sizes) / im_data.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if not np.isinf(loss.data.cpu().numpy()):
            train_loss += loss.data.cpu().numpy()[0] if isinstance(
                loss.data.cpu().numpy(), list) else loss.data.cpu().numpy()
            cnt += 1

        if opts.debug:
            dbg = labels_pred.data.cpu().numpy()
            ctc_f = dbg.swapaxes(1, 2)
            labels = ctc_f.argmax(2)
            # Fixed: print_seq_ext returns 4 values everywhere else in
            # this file — unpacking 3 raised ValueError in debug mode.
            det_text, conf, dec_s, _ = print_seq_ext(labels[0, :], codec)

            print('{0} \t'.format(det_text))

        if step % disp_interval == 0:

            # Guard against cnt == 0 (possible when every loss so far
            # was inf) — the unguarded division crashed with
            # ZeroDivisionError.
            train_loss /= max(cnt, 1)
            print('epoch %d[%d], loss: %.3f, lr: %.5f ' %
                  (step / batch_per_epoch, step, train_loss, learning_rate))

            train_loss = 0
            cnt = 0

        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))

            np.savez('train_acc_{0}'.format(model_name), acc=acc)
# Beispiel #4
# 0
def main(opts):
    """Train the E2E-MLT OCR branch with CTC loss and a cyclic LR.

    Optionally restores a checkpoint from ``opts.model``, then runs the
    training loop with gradient clipping, periodic loss logging, and
    per-epoch checkpointing + CER/WER evaluation.
    """
    model_name = 'E2E-MLT'
    net = ModelResNetSep_final(attention=True)
    acc = []
    ctc_loss = nn.CTCLoss()
    if opts.cuda:
        net.cuda()
        ctc_loss.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=base_lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0005,
                                                  max_lr=0.001,
                                                  step_size_up=3000,
                                                  cycle_momentum=False)
    step_start = 0
    if os.path.exists(opts.model):
        # Fixed: previously read `args.model`, but `args` is undefined
        # in this scope — the parameter is `opts`.
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net,
                                                       optimizer)
    else:
        learning_rate = base_lr

    # Force the configured base LR regardless of the checkpoint value.
    for param_group in optimizer.param_groups:
        param_group['lr'] = base_lr
        learning_rate = param_group['lr']
        print(param_group['lr'])

    # NOTE(review): deliberately discards the restored step counter.
    step_start = 0

    net.train()

    # Removed a duplicate `ctc_loss = nn.CTCLoss()` here: it replaced
    # the instance that had already been moved to the GPU above.

    data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
                                       batch_size=opts.batch_size,
                                       train_list=opts.train_list,
                                       in_train=True,
                                       norm_height=opts.norm_height,
                                       rgb=True, normalize=True)

    val_dataset = ocrDataset(root=opts.valid_list,
                             norm_height=opts.norm_height,
                             in_train=False, is_crnn=False)
    val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1,
                                                shuffle=False,
                                                collate_fn=alignCollate())

    cnt = 1
    cntt = 0
    train_loss_lr = 0
    time_total = 0
    train_loss = 0
    now = time.time()

    for step in range(step_start, 300000):
        images, labels, label_length = next(data_generator)
        # NHWC numpy batch -> NCHW tensor.
        im_data = net_utils.np_to_variable(
            images, is_cuda=opts.cuda).permute(0, 3, 1, 2)
        features = net.forward_features(im_data)
        labels_pred = net.forward_ocr(features)

        # CTC expects (T, N, C); one sequence length per batch column,
        # labels flattened into a single 1-D target sequence.
        probs_sizes = torch.IntTensor(
            [(labels_pred.permute(2, 0, 1).size()[0])] *
            (labels_pred.permute(2, 0, 1).size()[1])).long()
        label_sizes = torch.IntTensor(
            torch.from_numpy(np.array(label_length)).int()).long()
        labels = torch.IntTensor(
            torch.from_numpy(np.array(labels)).int()).long()
        loss = ctc_loss(labels_pred.permute(2, 0, 1), labels, probs_sizes,
                        label_sizes) / im_data.size(0)
        optimizer.zero_grad()
        loss.backward()

        clipping_value = 0.5
        torch.nn.utils.clip_grad_norm_(net.parameters(), clipping_value)
        # Skip the optimizer step entirely on NaN/inf losses.
        if not (torch.isnan(loss) or torch.isinf(loss)):
            optimizer.step()
            scheduler.step()
            train_loss += loss.data.cpu().numpy()
            cnt += 1

        if opts.debug:
            dbg = labels_pred.data.cpu().numpy()
            ctc_f = dbg.swapaxes(1, 2)
            labels = ctc_f.argmax(2)
            det_text, conf, dec_s, _ = print_seq_ext(labels[0, :], codec)

            print('{0} \t'.format(det_text))

        if step % disp_interval == 0:
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']

            train_loss /= cnt
            train_loss_lr += train_loss
            cntt += 1
            time_now = time.time() - now
            time_total += time_now
            now = time.time()
            save_log = os.path.join(opts.save_path, 'loss_ocr.txt')
            with open(save_log, 'a') as f:
                f.write(
                    'epoch %d[%d], loss_ctc: %.3f,time: %.2f s, lr: %.5f, cnt: %d\n'
                    % (step / batch_per_epoch, step, train_loss, time_now,
                       learning_rate, cnt))

            print(
                'epoch %d[%d], loss_ctc: %.3f,time: %.2f s, lr: %.5f, cnt: %d\n'
                % (step / batch_per_epoch, step, train_loss, time_now,
                   learning_rate, cnt))

            train_loss = 0
            cnt = 1

        if step > step_start and (step % batch_per_epoch == 0):
            # Evaluate, then switch back to training mode.
            CER, WER = eval_ocr(val_generator, net)
            net.train()
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']

            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))
            save_logg = os.path.join(opts.save_path, 'note_eval.txt')
            with open(save_logg, 'a') as fe:
                fe.write(
                    'time epoch [%d]: %.2f s, loss_total: %.3f, CER = %f, WER = %f\n'
                    % (step / batch_per_epoch, time_total,
                       train_loss_lr / cntt, CER, WER))
            print(
                'time epoch [%d]: %.2f s, loss_total: %.3f, CER = %f, WER = %f'
                % (step / batch_per_epoch, time_total, train_loss_lr / cntt,
                   CER, WER))
            time_total = 0
            cntt = 0
            train_loss_lr = 0