Example #1
def main(opts):

    model_name = 'OCT-E2E-MLT'
    net = OctMLT(attention=True)
    print("Using {0}".format(model_name))

    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net)

    if opts.cuda:
        net.cuda()

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_list,
                                        geo_type=opts.geo_type)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True)

    train_loss = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 0
    ctc_loss = CTCLoss()
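    # CTCLoss here is presumably the warp-ctc style binding, which takes
    # (acts, labels, act_lens, label_lens) with acts shaped (seqLength, batch, classes).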

    ctc_loss_val = 0
    box_loss_val = 0
    good_all = 0
    gt_all = 0

    best_step = step_start
    best_loss = 1000000
    best_model = net.state_dict()
    best_optimizer = optimizer.state_dict()
    best_learning_rate = learning_rate
    max_patience = 3000
    early_stop = False
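    # Early-stopping bookkeeping: the variables above track the best checkpoint seen
    # so far; training stops once no improvement has been recorded for max_patience steps.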

    for step in range(step_start, opts.max_iters):

        # batch
        images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(
            data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        start = timeit.timeit()
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue
        end = timeit.timeit()

        # backward

        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks,
                                                     is_cuda=opts.cuda)
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4],
                                            is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]],
                                          is_cuda=opts.cuda)

        try:
            loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred,
                            angle_gt, roi_pred, geo_gt)
        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue

        bbox_loss += net.box_loss_value.data.cpu().numpy()
        seg_loss += net.segm_loss_value.data.cpu().numpy()
        angle_loss += net.angle_loss_value.data.cpu().numpy()

        train_loss += loss.data.cpu().numpy()
        optimizer.zero_grad()

        try:

            if step > 10000:  # this is just an extra augmentation step ... in the early stage it just slows down training
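                # process_boxes presumably crops the predicted text regions, feeds them
                # through the recognition head and returns their CTC loss together with
                # counts of well-recognized and total ground-truth boxes.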
                ctcl, gt_b_good, gt_b_all = process_boxes(images,
                                                          im_data,
                                                          seg_pred[0],
                                                          roi_pred[0],
                                                          angle_pred[0],
                                                          score_maps,
                                                          gt_idxs,
                                                          gtso,
                                                          lbso,
                                                          features,
                                                          net,
                                                          ctc_loss,
                                                          opts,
                                                          debug=opts.debug)
                ctc_loss_val += ctcl.data.cpu().numpy()[0]
                loss = loss + ctcl
                gt_all += gt_b_all
                good_all += gt_b_good

            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso,
                                                   is_cuda=opts.cuda).permute(
                                                       0, 3, 1, 2)
            features = net.forward_features(im_data_ocr)
            labels_pred = net.forward_ocr(features)

            probs_sizes = torch.IntTensor(
                [(labels_pred.permute(2, 0, 1).size()[0])] *
                (labels_pred.permute(2, 0, 1).size()[1]))
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int())
            labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
            loss_ocr = ctc_loss(labels_pred.permute(2, 0,
                                                    1), labels, probs_sizes,
                                label_sizes) / im_data_ocr.size(0) * 0.5
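            # probs_sizes repeats the full output width for every sample, and the OCR
            # loss is averaged over the batch and weighted by 0.5 relative to the
            # detection loss accumulated above.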

            loss_ocr.backward()
            loss.backward()

            optimizer.step()
        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            pass
        cnt += 1
        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0],
                                      (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                x_data += 1
                x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except:
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            train_loss /= cnt
            bbox_loss /= cnt
            seg_loss /= cnt
            angle_loss /= cnt
            ctc_loss_val /= cnt
            box_loss_val /= cnt

            if train_loss < best_loss:
                best_step = step
                best_model = net.state_dict()
                best_loss = train_loss
                best_learning_rate = learning_rate
                best_optimizer = optimizer.state_dict()
            if step - best_step > max_patience:
                print("Early stopping criterion reached.")
                save_name = os.path.join(
                    opts.save_path,
                    'BEST_{}_{}.h5'.format(model_name, best_step))
                state = {
                    'step': best_step,
                    'learning_rate': best_learning_rate,
                    'state_dict': best_model,
                    'optimizer': best_optimizer
                }
                torch.save(state, save_name)
                print('save model: {}'.format(save_name))
                opts.max_iters = step
                early_stop = True
                break
            try:
                print(
                    'epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f in %.3f'
                    % (step / batch_per_epoch, step, train_loss, bbox_loss,
                       seg_loss, angle_loss, ctc_loss_val,
                       good_all / max(1, gt_all), end - start))
                print('max_memory_allocated {}'.format(
                    torch.cuda.max_memory_allocated()))
            except:
                import sys, traceback
                traceback.print_exc(file=sys.stdout)
                pass

            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            good_all = 0
            gt_all = 0
            box_loss_val = 0

        #if step % valid_interval == 0:
        #  validate(opts.valid_list, net)
        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'max_memory_allocated': torch.cuda.max_memory_allocated()
            }
            torch.save(state, save_name)
            print('save model: {}\tmax memory: {}'.format(
                save_name, torch.cuda.max_memory_allocated()))
    if not early_stop:
        save_name = os.path.join(opts.save_path, '{}.h5'.format(model_name))
        state = {
            'step': step,
            'learning_rate': learning_rate,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, save_name)
        print('save model: {}'.format(save_name))
Example #2
def main(opts):
    # pairs = c1, c2, label

    model_name = 'ICCV_OCR'
    net = OCRModel()

    if opts.cuda:
        net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(opts.model, net)
    else:
        learning_rate = base_lr
    print('train')
    net.train()

    # test(net)

    ctc_loss = CTCLoss(blank=0).cuda()

    data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
                                       batch_size=opts.batch_size,
                                       train_list=opts.train_list,
                                       in_train=True)

    train_loss = 0
    cnt = 0
    tq = tqdm(range(step_start, 10000000))
    for step in tq:

        # batch
        images, labels, label_length = next(data_generator)
        im_data = net_utils.np_to_variable(images,
                                           is_cuda=opts.cuda,
                                           volatile=False).permute(0, 3, 1, 2)
        labels_pred = net(im_data)

        # backward
        '''
        acts: Tensor of (seqLength x batch x outputDim) containing the output from the network
        labels: 1-dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing the size of each output sequence from the network
        label_lens: Tensor of size (batch) containing the label length of each example
        '''
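        # For example, with a batch of 3 crops whose transcriptions have 5, 8 and 4
        # characters, labels is a flat tensor of length 17, label_sizes is [5, 8, 4]
        # and probs_sizes repeats the output width T three times.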
        torch.backends.cudnn.deterministic = True
        probs_sizes = Variable(
            torch.IntTensor([(labels_pred.permute(2, 0, 1).size()[0])] *
                            (labels_pred.permute(2, 0, 1).size()[1]))).long()
        label_sizes = Variable(
            torch.IntTensor(torch.from_numpy(
                np.array(label_length)).int())).long()
        labels = Variable(
            torch.IntTensor(torch.from_numpy(np.array(labels)).int())).long()
        optimizer.zero_grad()
        #probs = nn.functional.log_softmax(labels_pred, dim=94)

        labels_pred = labels_pred.permute(2, 0, 1)

        loss = ctc_loss(labels_pred, labels, probs_sizes,
                        label_sizes) / opts.batch_size  # change 1.9.
        if loss.item() == np.inf:
            continue
        #
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        cnt += 1
        # if step % disp_interval == 0:
        #     train_loss /= cnt
        #     print('epoch %d[%d], loss: %.3f, lr: %.5f ' % (
        #         step / batch_per_epoch, step, train_loss, learning_rate))
        #
        #     train_loss = 0
        #     cnt = 0
        tq.set_description(
            'epoch %d[%d], loss: %.3f, lr: %.5f ' %
            (step / batch_per_epoch, step, train_loss / cnt, learning_rate))
        #
        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))

            test(net)
Example #3
def main(opts):
    model_name = 'E2E-MLT'
    # net = ModelResNetSep2(attention=True)
    net = ModelResNetSep_crnn(
        attention=True,
        multi_scale=True,
        num_classes=400,
        fixed_height=norm_height,
        net='densenet',
    )
    # net = ModelResNetSep_final(attention=True)
    print("Using {0}".format(model_name))
    ctc_loss = nn.CTCLoss()
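    # torch.nn.CTCLoss expects log-probabilities of shape (T, N, C) together with
    # per-sample input and target lengths; the recognition head is assumed to end
    # in a log_softmax.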
    if opts.cuda:
        net.to(device)
        ctc_loss.to(device)
    learning_rate = opts.base_lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=opts.base_lr,
                                 weight_decay=weight_decay)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='max', factor=0.5, patience=4, verbose=True)
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                                                  base_lr=0.0006,
                                                  max_lr=0.001,
                                                  step_size_up=3000,
                                                  cycle_momentum=False)
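    # CyclicLR sweeps the learning rate between base_lr and max_lr over step_size_up
    # scheduler steps; cycle_momentum=False because Adam has no momentum parameter
    # for the scheduler to cycle.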
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        # net_dict = net.state_dict()
        step_start, learning_rate = net_utils.load_net(opts.model, net,
                                                       optimizer)
    #     step_start, learning_rate = net_utils.load_net(args.model, net, None)
    #
    #   step_start = 0
    net_utils.adjust_learning_rate(optimizer, learning_rate)

    net.train()

    data_generator = data_gen.get_batch(num_workers=opts.num_readers,
                                        input_size=opts.input_size,
                                        batch_size=opts.batch_size,
                                        train_list=opts.train_path,
                                        geo_type=opts.geo_type,
                                        normalize=opts.normalize)

    dg_ocr = ocr_gen.get_batch(num_workers=2,
                               batch_size=opts.ocr_batch_size,
                               train_list=opts.ocr_feed_list,
                               in_train=True,
                               norm_height=norm_height,
                               rgb=True,
                               normalize=opts.normalize)

    # e2edata = E2Edataset(train_list=opts.eval_path, normalize= opts.normalize)
    # e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=opts.batch_size, shuffle=True, collate_fn=E2Ecollate
    #                                           )

    train_loss = 0
    train_loss_temp = 0
    bbox_loss, seg_loss, angle_loss = 0., 0., 0.
    cnt = 1

    # ctc_loss = CTCLoss()

    ctc_loss_val = 0
    ctc_loss_val2 = 0
    ctcl = torch.tensor([0])
    box_loss_val = 0
    good_all = 0
    gt_all = 0
    train_loss_lr = 0
    cntt = 0
    time_total = 0
    now = time.time()

    for step in range(step_start, opts.max_iters):
        # scheduler.batch_step()

        # batch
        images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(
            data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        start = timeit.timeit()
        # cv2.imshow('img', images)
        try:
            seg_pred, roi_pred, angle_pred, features = net(im_data)
        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue
        end = timeit.timeit()

        # backward

        smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
        training_mask_var = net_utils.np_to_variable(training_masks,
                                                     is_cuda=opts.cuda)
        angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4],
                                            is_cuda=opts.cuda)
        geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]],
                                          is_cuda=opts.cuda)

        try:
            # ? loss
            loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred,
                            angle_gt, roi_pred, geo_gt)
        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            continue

            # @ loss_val
        if not (torch.isnan(loss) or torch.isinf(loss)):
            train_loss_temp += loss.data.cpu().numpy()

        optimizer.zero_grad()

        try:

            if step > 1000 or True:  # this is just an extra augmentation step ... in the early stage it just slows down training
                ctcl, gt_b_good, gt_b_all = process_boxes(images,
                                                          im_data,
                                                          seg_pred[0],
                                                          roi_pred[0],
                                                          angle_pred[0],
                                                          score_maps,
                                                          gt_idxs,
                                                          gtso,
                                                          lbso,
                                                          features,
                                                          net,
                                                          ctc_loss,
                                                          opts,
                                                          debug=opts.debug)

                # ? loss
                loss = loss + ctcl
                gt_all += gt_b_all
                good_all += gt_b_good

            imageso, labels, label_length = next(dg_ocr)
            im_data_ocr = net_utils.np_to_variable(imageso,
                                                   is_cuda=opts.cuda).permute(
                                                       0, 3, 1, 2)
            # features = net.forward_features(im_data_ocr)
            labels_pred = net.forward_ocr(im_data_ocr)

            probs_sizes = torch.IntTensor([(labels_pred.size()[0])] *
                                          (labels_pred.size()[1])).long()
            label_sizes = torch.IntTensor(
                torch.from_numpy(np.array(label_length)).int()).long()
            labels = torch.IntTensor(torch.from_numpy(
                np.array(labels)).int()).long()
            loss_ocr = ctc_loss(labels_pred, labels, probs_sizes,
                                label_sizes) / im_data_ocr.size(0) * 0.5

            loss_ocr.backward()
            # @ loss_val
            # ctc_loss_val2 += loss_ocr.item()

            loss.backward()

            clipping_value = 0.5
            torch.nn.utils.clip_grad_norm_(net.parameters(), clipping_value)
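            # clip_grad_norm_ rescales the gradients in place whenever their global
            # L2 norm exceeds clipping_value, guarding against exploding CTC gradients.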
            if opts.d1:
                print('loss_nan', torch.isnan(loss))
                print('loss_inf', torch.isinf(loss))
                print('lossocr_nan', torch.isnan(loss_ocr))
                print('lossocr_inf', torch.isinf(loss_ocr))

            if not (torch.isnan(loss) or torch.isinf(loss)
                    or torch.isnan(loss_ocr) or torch.isinf(loss_ocr)):
                bbox_loss += net.box_loss_value.data.cpu().numpy()
                seg_loss += net.segm_loss_value.data.cpu().numpy()
                angle_loss += net.angle_loss_value.data.cpu().numpy()
                train_loss += train_loss_temp
                ctc_loss_val2 += loss_ocr.item()
                ctc_loss_val += ctcl.data.cpu().numpy()[0]
                # train_loss += loss.data.cpu().numpy()[0] #net.bbox_loss.data.cpu().numpy()[0]
                optimizer.step()
                scheduler.step()
                train_loss_temp = 0
                cnt += 1

        except:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            pass

        if step % disp_interval == 0:

            if opts.debug:

                segm = seg_pred[0].data.cpu()[0].numpy()
                segm = segm.squeeze(0)
                cv2.imshow('segm_map', segm)

                segm_res = cv2.resize(score_maps[0],
                                      (images.shape[2], images.shape[1]))
                mask = np.argwhere(segm_res > 0)

                x_data = im_data.data.cpu().numpy()[0]
                x_data = x_data.swapaxes(0, 2)
                x_data = x_data.swapaxes(0, 1)

                if opts.normalize:
                    x_data += 1
                    x_data *= 128
                x_data = np.asarray(x_data, dtype=np.uint8)
                x_data = x_data[:, :, ::-1]

                im_show = x_data
                try:
                    im_show[mask[:, 0], mask[:, 1], 1] = 255
                    im_show[mask[:, 0], mask[:, 1], 0] = 0
                    im_show[mask[:, 0], mask[:, 1], 2] = 0
                except:
                    pass

                cv2.imshow('img0', im_show)
                cv2.imshow('score_maps', score_maps[0] * 255)
                cv2.imshow('train_mask', training_masks[0] * 255)
                cv2.waitKey(10)

            train_loss /= cnt
            bbox_loss /= cnt
            seg_loss /= cnt
            angle_loss /= cnt
            ctc_loss_val /= cnt
            ctc_loss_val2 /= cnt
            box_loss_val /= cnt
            train_loss_lr += (train_loss)

            cntt += 1
            time_now = time.time() - now
            time_total += time_now
            now = time.time()
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']
            save_log = os.path.join(opts.save_path, 'loss.txt')

            f = open(save_log, 'a')
            f.write(
                'epoch %d[%d], lr: %f, loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f, lv2: %.3f, time: %.2f s, cnt: %d\n'
                % (step / batch_per_epoch, step, learning_rate, train_loss,
                   bbox_loss, seg_loss, angle_loss, ctc_loss_val,
                   good_all / max(1, gt_all), ctc_loss_val2, time_now, cnt))
            f.close()
            try:

                print(
                    'epoch %d[%d], lr: %f, loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, rec: %.5f, lv2: %.3f, time: %.2f s, cnt: %d\n'
                    %
                    (step / batch_per_epoch, step, learning_rate, train_loss,
                     bbox_loss, seg_loss, angle_loss, ctc_loss_val,
                     good_all / max(1, gt_all), ctc_loss_val2, time_now, cnt))
            except:
                import sys, traceback
                traceback.print_exc(file=sys.stdout)
                pass

            train_loss = 0
            bbox_loss, seg_loss, angle_loss = 0., 0., 0.
            cnt = 0
            ctc_loss_val = 0
            ctc_loss_val2 = 0
            good_all = 0
            gt_all = 0
            box_loss_val = 0

        # if step % valid_interval == 0:
        #  validate(opts.valid_list, net)
        if step > step_start and (step % batch_per_epoch == 0):
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']
                print('learning_rate', learning_rate)
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            #evaluate
            re_tpe2e, re_tp, re_e1, precision = evaluate_e2e_crnn(
                root=opts.eval_path,
                net=net,
                norm_height=norm_height,
                name_model=save_name,
                normalize=opts.normalize,
                save_dir=opts.save_path)
            # CER,WER = evaluate_crnn(e2edataloader,net)

            # scheduler.step(re_tpe2e)
            f = open(save_log, 'a')
            f.write(
                'time epoch [%d]: %.2f s, loss_total: %.3f, lr:%f, re_tpe2e = %f, re_tp = %f, re_e1 = %f, precision = %f\n'
                % (step / batch_per_epoch, time_total, train_loss_lr / cntt,
                   learning_rate, re_tpe2e, re_tp, re_e1, precision))
            f.close()
            print(
                'time epoch [%d]: %.2f s, loss_total: %.3f, re_tpe2e = %f, re_tp = %f, re_e1 = %f, precision = %f'
                % (step / batch_per_epoch, time_total, train_loss_lr / cntt,
                   re_tpe2e, re_tp, re_e1, precision))
            #print('time epoch [%d]: %.2f s, loss_total: %.3f' % (step / batch_per_epoch, time_total,train_loss_lr/cntt))
            print('save model: {}'.format(save_name))
            time_total = 0
            cntt = 0
            train_loss_lr = 0
            net.train()
Example #4
def main(opts):

    model_name = 'OctGatedMLT'
    net = OctMLT(attention=True)
    acc = []

    if opts.cuda:
        net.cuda()

    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=base_lr,
                                 weight_decay=weight_decay)
    step_start = 0
    if os.path.exists(opts.model):
        print('loading model from %s' % opts.model)
        step_start, learning_rate = net_utils.load_net(
            opts.model,
            net,
            optimizer,
            load_ocr=opts.load_ocr,
            load_detection=opts.load_detection,
            load_shared=opts.load_shared,
            load_optimizer=opts.load_optimizer,
            reset_step=opts.load_reset_step)
    else:
        learning_rate = base_lr

    step_start = 0

    net.train()

    if opts.freeze_shared:
        net_utils.freeze_shared(net)

    if opts.freeze_ocr:
        net_utils.freeze_ocr(net)

    if opts.freeze_detection:
        net_utils.freeze_detection(net)

    #acc_test = test(net, codec, opts, list_file=opts.valid_list, norm_height=opts.norm_height)
    #acc.append([0, acc_test])
    ctc_loss = CTCLoss()

    data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
                                       batch_size=opts.batch_size,
                                       train_list=opts.train_list,
                                       in_train=True,
                                       norm_height=opts.norm_height,
                                       rgb=True)

    train_loss = 0
    cnt = 0

    for step in range(step_start, 300000):
        # batch
        images, labels, label_length = next(data_generator)
        im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(
            0, 3, 1, 2)
        features = net.forward_features(im_data)
        labels_pred = net.forward_ocr(features)

        # backward
        '''
        acts: Tensor of (seqLength x batch x outputDim) containing the output from the network
        labels: 1-dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing the size of each output sequence from the network
        label_lens: Tensor of size (batch) containing the label length of each example
        '''
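        # forward_ocr evidently returns (batch, classes, width); permute(2, 0, 1) below
        # yields the (seqLength, batch, classes) layout CTC expects, and probs_sizes
        # repeats that sequence length once per batch element.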

        probs_sizes = torch.IntTensor(
            [(labels_pred.permute(2, 0, 1).size()[0])] *
            (labels_pred.permute(2, 0, 1).size()[1]))
        label_sizes = torch.IntTensor(
            torch.from_numpy(np.array(label_length)).int())
        labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int())
        loss = ctc_loss(labels_pred.permute(2, 0, 1), labels, probs_sizes,
                        label_sizes) / im_data.size(0)  # change 1.9.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if not np.isinf(loss.data.cpu().numpy()):
            train_loss += loss.data.cpu().numpy()  #net.bbox_loss.data.cpu().numpy()[0]
            cnt += 1

        if opts.debug:
            dbg = labels_pred.data.cpu().numpy()
            ctc_f = dbg.swapaxes(1, 2)
            labels = ctc_f.argmax(2)
            det_text, conf, dec_s = print_seq_ext(labels[0, :], codec)

            print('{0} \t'.format(det_text))

        if step % disp_interval == 0:

            train_loss /= cnt
            print('epoch %d[%d], loss: %.3f, lr: %.5f ' %
                  (step / batch_per_epoch, step, train_loss, learning_rate))

            train_loss = 0
            cnt = 0

        if step > step_start and (step % batch_per_epoch == 0):
            save_name = os.path.join(opts.save_path,
                                     '{}_{}.h5'.format(model_name, step))
            state = {
                'step': step,
                'learning_rate': learning_rate,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, save_name)
            print('save model: {}'.format(save_name))

            #acc_test, ted = test(net, codec, opts,  list_file=opts.valid_list, norm_height=opts.norm_height)
            #acc.append([0, acc_test, ted])
            np.savez('train_acc_{0}'.format(model_name), acc=acc)
Example #5
def main(opts):
  
  model_name = 'E2E-MLT'
  net = ModelResNetSep_final(attention=True)
  acc = []
  ctc_loss = nn.CTCLoss()
  if opts.cuda:
    net.cuda()
    ctc_loss.cuda()
  optimizer = torch.optim.Adam(net.parameters(), lr=base_lr, weight_decay=weight_decay)
  scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0005, max_lr=0.001, step_size_up=3000,
                                                cycle_momentum=False)
  step_start = 0  
  if os.path.exists(opts.model):
    print('loading model from %s' % opts.model)
    step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
  else:
    learning_rate = base_lr

  for param_group in optimizer.param_groups:
    param_group['lr'] = base_lr
    learning_rate = param_group['lr']
    print(param_group['lr'])
  
  step_start = 0  

  net.train()
  
  #acc_test = test(net, codec, opts, list_file=opts.valid_list, norm_height=opts.norm_height)
  #acc.append([0, acc_test])
    
  # ctc_loss = CTCLoss()
  ctc_loss = nn.CTCLoss()

  data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
          batch_size=opts.batch_size, 
          train_list=opts.train_list, in_train=True, norm_height=opts.norm_height, rgb = True, normalize= True)
  
  val_dataset = ocrDataset(root=opts.valid_list, norm_height=opts.norm_height , in_train=False,is_crnn=False)
  val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False,
                                                collate_fn=alignCollate())


  # val_generator1 = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=False,
  #                                              collate_fn=alignCollate())

  cnt = 1
  cntt = 0
  train_loss_lr = 0
  time_total = 0
  train_loss = 0
  now = time.time()

  for step in range(step_start, 300000):
    # batch
    images, labels, label_length = next(data_generator)
    im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(0, 3, 1, 2)
    features = net.forward_features(im_data)
    labels_pred = net.forward_ocr(features)
    
    # backward
    '''
    acts: Tensor of (seqLength x batch x outputDim) containing the output from the network
    labels: 1-dimensional Tensor containing all the targets of the batch in one sequence
    act_lens: Tensor of size (batch) containing the size of each output sequence from the network
    label_lens: Tensor of size (batch) containing the label length of each example
    '''
    
    probs_sizes =  torch.IntTensor([(labels_pred.permute(2, 0, 1).size()[0])] * (labels_pred.permute(2, 0, 1).size()[1])).long()
    label_sizes = torch.IntTensor(torch.from_numpy(np.array(label_length)).int()).long()
    labels = torch.IntTensor(torch.from_numpy(np.array(labels)).int()).long()
    loss = ctc_loss(labels_pred.permute(2,0,1), labels, probs_sizes, label_sizes) / im_data.size(0) # change 1.9.
    optimizer.zero_grad()
    loss.backward()

    clipping_value = 0.5
    torch.nn.utils.clip_grad_norm_(net.parameters(),clipping_value)
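    # Apply the optimizer/scheduler step and the running-loss bookkeeping only when
    # the CTC loss is finite; NaN or Inf batches are skipped entirely.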
    if not (torch.isnan(loss) or torch.isinf(loss)):
      optimizer.step()
      scheduler.step()
    # if not np.isinf(loss.data.cpu().numpy()):
      train_loss += loss.data.cpu().numpy() #net.bbox_loss.data.cpu().numpy()[0]
      # train_loss += loss.data.cpu().numpy()[0] #net.bbox_loss.data.cpu().numpy()[0]
      cnt += 1
    
    if opts.debug:
      dbg = labels_pred.data.cpu().numpy()
      ctc_f = dbg.swapaxes(1, 2)
      labels = ctc_f.argmax(2)
      det_text, conf, dec_s,_ = print_seq_ext(labels[0, :], codec)
      
      print('{0} \t'.format(det_text))
    
    
    
    if step % disp_interval == 0:
      for param_group in optimizer.param_groups:
        learning_rate = param_group['lr']
        
      train_loss /= cnt
      train_loss_lr += train_loss
      cntt += 1
      time_now = time.time() - now
      time_total += time_now
      now = time.time()
      save_log = os.path.join(opts.save_path, 'loss_ocr.txt')
      f = open(save_log, 'a')
      f.write(
        'epoch %d[%d], loss_ctc: %.3f, time: %.2f s, lr: %.5f, cnt: %d\n' % (
          step / batch_per_epoch, step, train_loss, time_now, learning_rate, cnt))
      f.close()

      print('epoch %d[%d], loss_ctc: %.3f, time: %.2f s, lr: %.5f, cnt: %d\n' % (
          step / batch_per_epoch, step, train_loss, time_now, learning_rate, cnt))

      train_loss = 0
      cnt = 1

    if step > step_start and (step % batch_per_epoch == 0):
      CER, WER = eval_ocr(val_generator, net)
      net.train()
      for param_group in optimizer.param_groups:
        learning_rate = param_group['lr']
        # print(learning_rate)

      save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
      state = {'step': step,
               'learning_rate': learning_rate,
              'state_dict': net.state_dict(),
              'optimizer': optimizer.state_dict()}
      torch.save(state, save_name)
      print('save model: {}'.format(save_name))
      save_logg = os.path.join(opts.save_path, 'note_eval.txt')
      fe = open(save_logg, 'a')
      fe.write('time epoch [%d]: %.2f s, loss_total: %.3f, CER = %f, WER = %f\n' % (
      step / batch_per_epoch, time_total, train_loss_lr / cntt, CER, WER))
      fe.close()
      print('time epoch [%d]: %.2f s, loss_total: %.3f, CER = %f, WER = %f' % (
      step / batch_per_epoch, time_total, train_loss_lr / cntt, CER, WER))
      time_total = 0
      cntt = 0
      train_loss_lr = 0