예제 #1
0
def main(opts):
    """Run the joint detection + OCR training loop for the E2E-MLT model.

    Builds a ``ModelResNetSep2`` network and an Adam optimizer, optionally
    resumes from ``opts.model``, then iterates from the restored step up to
    ``opts.max_iters``. Every iteration combines a detection batch (from
    ``data_gen``) with an OCR batch (from ``ocr_gen``), back-propagates both
    losses and performs one optimizer step.

    Args:
        opts: Parsed options object. The attributes read here are
            ``base_lr``, ``model``, ``cuda``, ``num_readers``, ``input_size``,
            ``batch_size``, ``train_list``, ``geo_type``, ``ocr_batch_size``,
            ``ocr_feed_list``, ``max_iters``, ``debug`` and ``save_path``.

    Besides ``opts`` the function relies on the module-level names
    ``weight_decay``, ``norm_height``, ``disp_interval`` and
    ``batch_per_epoch``.

    Side effects:
        Prints progress to stdout, opens OpenCV windows when ``opts.debug`` is
        set, and writes a checkpoint into ``opts.save_path`` whenever ``step``
        is a multiple of ``batch_per_epoch``.
    """
    # Imported once here instead of inside every exception handler.
    import sys
    import traceback

    model_name = 'E2E-MLT'
    net = ModelResNetSep2(attention=True)
    print("Using {0}".format(model_name))

    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        # Use the ``opts`` parameter consistently; the previous code checked
        # ``opts.model`` but then loaded from a global ``args``.
        print('loading model from %s' % opts.model)
        # NOTE(review): the restored learning rate is only written back into
        # later checkpoints; the optimizer above keeps ``opts.base_lr``.
        step_start, learning_rate = net_utils.load_net(opts.model, net)

    if opts.cuda:
        net.cuda()

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_list,
                                        geo_type=opts.geo_type)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True)

    # Running sums since the last progress report; `cnt` counts the steps
    # that contributed to them.
    train_loss = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 0
    ctc_loss = CTCLoss()

    ctc_loss_val = 0
    good_all = 0
    gt_all = 0

    for step in range(step_start, opts.max_iters):

        # Fetch the next detection batch.
        (images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso,
         gt_idxs) = next(data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)

        # Wall-clock timing of the forward pass. `timeit.timeit()` would run
        # a one-million-iteration benchmark instead of reading a clock.
        start = timeit.default_timer()
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except Exception:
            # Skip a batch that fails in the forward pass, but let
            # KeyboardInterrupt / SystemExit propagate.
            traceback.print_exc(file=sys.stdout)
            continue
        end = timeit.default_timer()

        # Ground-truth tensors for the detection loss.
        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks,
                                                     is_cuda=opts.cuda)
        # Channel 4 of the geometry maps is fed as the angle target,
        # channels 0-3 as the box-geometry target.
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4],
                                            is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]],
                                          is_cuda=opts.cuda)

        try:
            loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred,
                            angle_gt, roi_pred, geo_gt)
        except Exception:
            traceback.print_exc(file=sys.stdout)
            continue

        bbox_loss += net.box_loss_value.data.cpu().numpy()
        seg_loss += net.segm_loss_value.data.cpu().numpy()
        angle_loss += net.angle_loss_value.data.cpu().numpy()

        train_loss += loss.data.cpu().numpy()
        optimizer.zero_grad()

        try:

            # This is just an extra augmentation step; in the early stage it
            # only slows down training.
            if step > 100000:
                ctcl, gt_b_good, gt_b_all = process_boxes(images,
                                                          im_data,
                                                          seg_pred[0],
                                                          roi_pred[0],
                                                          angle_pred[0],
                                                          score_maps,
                                                          gt_idxs,
                                                          gtso,
                                                          lbso,
                                                          features,
                                                          net,
                                                          ctc_loss,
                                                          opts,
                                                          debug=opts.debug)
                ctc_loss_val += ctcl.data.cpu().numpy()[0]
                loss = loss + ctcl
                gt_all += gt_b_all
                good_all += gt_b_good

            # OCR branch: a separate batch of images with their transcriptions.
            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso,
                                                   is_cuda=opts.cuda).permute(
                                                       0, 3, 1, 2)
            features = net.forward_features(im_data_ocr)
            labels_pred = net.forward_ocr(features)

            # Permute once and reuse; the CTC loss is given one identical
            # length entry (dim 0 size) per element along dim 1.
            ocr_acts = labels_pred.permute(2, 0, 1)
            probs_sizes = torch.IntTensor(
                [ocr_acts.size()[0]] * ocr_acts.size()[1])
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int())
            labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
            loss_ocr = ctc_loss(ocr_acts, labels, probs_sizes,
                                label_sizes) / im_data_ocr.size(0) * 0.5

            loss_ocr.backward()
            loss.backward()

            optimizer.step()
        except Exception:
            # Best effort: report the failure and move on to the next step.
            traceback.print_exc(file=sys.stdout)

        cnt += 1
        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0],
                                      (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                # Bring the first input image back to channel-last layout and
                # rescale it for display.
                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                x_data += 1
                x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    # Paint pixels with a positive ground-truth score.
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except Exception:
                    # The overlay is purely cosmetic; show the plain image.
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            # Turn the running sums into per-step averages for the report.
            train_loss /= cnt
            bbox_loss /= cnt
            seg_loss /= cnt
            angle_loss /= cnt
            ctc_loss_val /= cnt
            try:
                print(
                    'epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f in %.3f'
                    % (step / batch_per_epoch, step, train_loss, bbox_loss,
                       seg_loss, angle_loss, ctc_loss_val,
                       good_all / max(1, gt_all), end - start))
            except Exception:
                traceback.print_exc(file=sys.stdout)

            # Start a fresh accumulation window.
            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            good_all = 0
            gt_all = 0

        #if step % valid_interval == 0:
        #  validate(opts.valid_list, net)
        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))
예제 #2
0
    return scaled, (resize_h, resize_w)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('-cuda', type=int, default=1)
    parser.add_argument('-model', default='e2e-mlt.h5')
    parser.add_argument('-image', default='image.jpg')
    parser.add_argument('-segm_thresh', default=0.5)

    font2 = ImageFont.truetype("Arial-Unicode-Regular.ttf", 18)

    args = parser.parse_args()

    net = ModelResNetSep2(attention=True)
    net_utils.load_net(args.model, net)
    net = net.eval()

    if args.cuda:
        print('Using cuda ...')
        net = net.cuda()

    frame_no = 0
    with torch.no_grad():
        while ret:
            im = cv2.imread("/content/1f2830794e9aec999772b750a53aded2.jpg")

            if im is not None:
                im_resized, (ratio_h, ratio_w) = resize_image(im,
                                                              scale_up=False)