Example #1
def main():
    # run timestamp used as the plot name below (assumed; dt is not defined elsewhere in this snippet)
    dt = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M")
    aug = aug_aocr(args)
    datasets = data.fetch_data(args, args.datasets, batch_size=args.batch_size_per_gpu,
                               batch_size_val=args.batch_size_per_gpu_val, k_fold=1, split_val=0.1,
                               pre_process=None, aug=aug)

    for idx, (train_set, val_set) in enumerate(datasets):
        losses = []
        lev_dises, str_accus = [], []
        print("\n =============== Cross Validation: %s/%s ================ " %
                  (idx + 1, len(datasets)))
        # Prepare Network
        encoder = att_model.Attn_CNN(backbone_require_grad=True)
        decoder = att_model.AttnDecoder(args)
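        # apply RNN-specific and generic weight initializers to every submodule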
        encoder.apply(init.init_rnn).apply(init.init_others)
        decoder.apply(init.init_rnn).apply(init.init_others)
        criterion = nn.NLLLoss()
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        torch.backends.cudnn.benchmark = True
        if args.finetune:
            encoder, decoder = util.load_latest_model(args, [encoder, decoder],
                                                      prefix=["encoder", "decoder"], strict=False)
        
        # Prepare loss function and optimizer
        encoder_optimizer = AdaBound(encoder.parameters(), lr=args.learning_rate,
                                     final_lr=args.learning_rate * 10, weight_decay=args.weight_decay)
        decoder_optimizer = AdaBound(decoder.parameters(), lr=args.learning_rate,
                                     final_lr=args.learning_rate * 10, weight_decay=args.weight_decay)

        for epoch in range(args.epoch_num):
            loss = fit(args, encoder, decoder, train_set, encoder_optimizer,
                       decoder_optimizer, criterion, is_train=True)
            losses.append(loss)
            train_losses = [np.asarray(losses)]
            if val_set is not None:
                lev_dis, str_accu = fit(args, encoder, decoder, val_set, encoder_optimizer,
                                        decoder_optimizer, criterion, is_train=False)
                lev_dises.append(lev_dis)
                str_accus.append(str_accu)
                val_scores = [np.asarray(lev_dises), np.asarray(str_accus)]
            if epoch % 5 == 0:
                util.save_model(args, args.curr_epoch, encoder.state_dict(), prefix="encoder",
                                keep_latest=20)
                util.save_model(args, args.curr_epoch, decoder.state_dict(), prefix="decoder",
                                keep_latest=20)
            if epoch > 4:
                vb.plot_multi_loss_distribution(
                    multi_line_data=[train_losses, val_scores],
                    multi_line_labels=[["NLL Loss"], ["Levenshtein", "String-Level"]],
                    save_path=args.loss_log, window=5, name=dt,
                    bound=[None, {"low": 0.0, "high": 100.0}],
                    titles=["Train Loss", "Validation Score"]
                )
Example #2
def test():
    aug = aug_test(args)
    dataset = data.fetch_detection_data(args,
                                        sources=args.test_sources,
                                        k_fold=1,
                                        auxiliary_info=args.test_aux,
                                        aug=aug,
                                        batch_size=1 / torch.cuda.device_count(),
                                        shuffle=False)[0][0]
    net = model.SSD(cfg,
                    connect_loc_to_conf=args.loc_to_conf,
                    fix_size=args.fix_size,
                    conf_incep=args.conf_incep,
                    loc_incep=args.loc_incep,
                    nms_thres=args.nms_threshold,
                    loc_preconv=args.loc_preconv,
                    conf_preconv=args.conf_preconv,
                    FPN=args.feature_pyramid_net,
                    SA=args.self_attention,
                    in_wid=args.inner_filters,
                    m_factor=args.inner_m_factor)
    net = torch.nn.DataParallel(net, device_ids=[0], output_device=0).cuda()
    net = util.load_latest_model(args,
                                 net,
                                 prefix=args.model_prefix_finetune,
                                 strict=True)
    detector = model.Detect(num_classes=2,
                            bkg_label=0,
                            top_k=args.detector_top_k,
                            conf_thresh=args.detector_conf_threshold,
                            nms_thresh=args.detector_nms_threshold)
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(dataset):
            images = images.cuda()
            print(images.shape)
            targets = [ann.cuda() for ann in targets]
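            # width/height aspect ratio; non-square inputs are printed for inspection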
            ratios = images.size(3) / images.size(2)
            if ratios != 1.0:
                print(ratios)
            out = net(images, is_train=False)
            loc_data, conf_data, prior_data = out
            #prior_data = prior_data.to("cuda:%d" % (device_id))
            det_result = detector(loc_data, conf_data, prior_data)
            eval_result = evaluate(images,
                                   det_result.data,
                                   targets,
                                   batch_idx,
                                   args.threshold,
                                   visualize=True,
                                   post_combine=True)
            print(eval_result)
Example #3
def test():
    args.train = False
    dataset = data.fetch_probaV_data(args,
                                     sources=args.test_sources,
                                     shuffle=False,
                                     batch_size=1 / torch.cuda.device_count(),
                                     auxiliary_info=[2, 2])[0][0]
    if args.which_model.lower() == "carn":
        net = model.CARN(args.n_selected_img,
                         args.filters,
                         3,
                         s_MSE=True,
                         trellis=args.trellis)
    elif args.which_model.lower() == "rdn":
        net = model.RDN(args.n_selected_img,
                        3,
                        3,
                        filters=args.filters,
                        s_MSE=True,
                        group=args.n_selected_img,
                        trellis=args.trellis)
    elif args.which_model.lower() == "basic":
        net = model.ProbaV_basic(inchannel=args.n_selected_img)
    else:
        print("args.which_model or -wm should be one of [carn, rdn, basic], "
              "your -wm %s is illegal, and switched to 'basic' automatically" %
              (args.which_model.lower()))
        net = model.ProbaV_basic(inchannel=args.n_selected_img)
    net = torch.nn.DataParallel(net, device_ids=[0], output_device=0).cuda()
    torch.backends.cudnn.benchmark = True
    net = util.load_latest_model(args,
                                 net,
                                 prefix=args.model_prefix_finetune,
                                 strict=True)
    with torch.no_grad():
        for batch_idx, (images, blend_target, unblend_target,
                        norm) in enumerate(dataset):
            print(batch_idx)
            images, blend_target = images.cuda(), blend_target.cuda()
            prediction, mae, s_mse = net(images, blend_target, train=False)
            pred = vb.plot_tensor(args, prediction, margin=0)
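            # rescale from the 16-bit value range to 8-bit before writing the image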
            cv2.imwrite(
                os.path.expanduser("~/Pictures/result/%s.jpg" %
                                   str(batch_idx).zfill(4)),
                pred / 65536 * 255)
Example #4
def main():
    test_folder = os.path.expanduser(
        "~/Pictures/dataset/reid/eval_lp/HHR_Body")
    diff_person_choose = 4
    diff_view = 8

    net = model.Encoder(args)
    net = torch.nn.DataParallel(net).cuda()
    net.eval()
    net = util.load_latest_model(args, net, prefix=args.model_prefix, nth=1)

    #test_data, label_mapping = data.fetch_dataset(args, verbose=False, for_test=True)
    person_id = os.listdir(test_folder)
    gt_label, views = [], []
    for p_id in sorted(person_id):
        views += random.sample(glob.glob(os.path.join(test_folder, p_id, "*")),
                               diff_view)

    test_imgs = load_img(views, augment())
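    # ground-truth labels: each person index repeated diff_view times, e.g. [0, 0, ..., 0, 1, 1, ...]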
    gt_label = torch.arange(len(person_id)).unsqueeze(-1).repeat(
        1, diff_view).view(-1).int()

    with torch.no_grad():
        test(net, test_imgs, gt_label)
Example #5
def test_rotation(opt):
    result_dir = os.path.join(args.path, args.code_name,
                              "result+" + "-".join(opt.model_prefix_list))
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    # Load
    assert len(opt.model_prefix_list) <= torch.cuda.device_count(), \
        "the number of models must not exceed the number of available CUDA devices"
    nets = []
    for _, prefix in enumerate(opt.model_prefix_list):
        net = model.SSD(cfg,
                        connect_loc_to_conf=True,
                        fix_size=False,
                        incep_conf=True,
                        incep_loc=True)
        device_id = opt.device_id if len(opt.model_prefix_list) == 1 else _
        net = net.to("cuda:%d" % (device_id))
        net_dict = net.state_dict()
        weight_dict = util.load_latest_model(args,
                                             net,
                                             prefix=prefix,
                                             return_state_dict=True,
                                             nth=opt.nth_best_model)
        loading_fail_signal = False
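        # zero-initialize any model parameter that is absent from the checkpoint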
        for i, key in enumerate(net_dict.keys()):
            if "module." + key not in weight_dict:
                net_dict[key] = torch.zeros(net_dict[key].shape)
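        # checkpoint keys carry the DataParallel "module." prefix; key[7:] strips it below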
        for key in weight_dict.keys():
            if key[7:] in net_dict:
                if net_dict[key[7:]].shape == weight_dict[key].shape:
                    net_dict[key[7:]] = weight_dict[key]
                else:
                    print(
                        "Key: %s from the checkpoint has shape %s, which does not match the model's shape %s"
                        % (key[7:], str(weight_dict[key].shape),
                           str(net_dict[key[7:]].shape)))
                    loading_fail_signal = True
            else:
                print("Key: %s does not exist in net_dict" % (key[7:]))
        if loading_fail_signal:
            raise RuntimeError(
                'Shape mismatch occurred; remove "%s" from your -mpl settings.' %
                (prefix))

        net.load_state_dict(net_dict)
        net.eval()
        nets.append(net.half())
        print("Above model loaded with out a problem")
    detector = model.Detect(num_classes=2,
                            bkg_label=0,
                            top_k=opt.detector_top_k,
                            conf_thresh=opt.detector_conf_threshold,
                            nms_thresh=opt.detector_nms_threshold)

    # Enumerate test folder
    root_path = os.path.expanduser(opt.test_dataset_root)
    if not os.path.exists(root_path):
        raise FileNotFoundError(
            "%s does not exist, please check your -tdr/--test_dataset_root settings"
            % (root_path))
    img_list = glob.glob(root_path + "/*.%s" % (opt.extension))
    precisions, recalls = [], []
    for i, img_file in enumerate(sorted(img_list)):
        start = time.time()
        name = img_file[img_file.rfind("/") + 1:-4]
        img = cv2.imread(img_file)
        height_ori, width_ori = img.shape[0], img.shape[1]

        # rotation detection is disabled here; keep the rotation fixed at 0
        transform_det = {"rotation": 0}
        # Resize the longer side to a certain length
        if height_ori >= width_ori:
            resize_aug = augmenters.Sequential([
                augmenters.Resize(size={
                    "height": opt.test_size,
                    "width": "keep-aspect-ratio"
                })
            ])
        else:
            resize_aug = augmenters.Sequential([
                augmenters.Resize(size={
                    "height": "keep-aspect-ratio",
                    "width": opt.test_size
                })
            ])
        resize_aug = resize_aug.to_deterministic()
        image = resize_aug.augment_image(img)
        h_re, w_re = image.shape[0], image.shape[1]
        # Pad the image into a square image
        pad_aug = augmenters.Sequential(
            augmenters.PadToFixedSize(width=opt.test_size,
                                      height=opt.test_size,
                                      pad_cval=255,
                                      position="center"))
        pad_aug = pad_aug.to_deterministic()
        image = pad_aug.augment_image(image)
        h_final, w_final = image.shape[0], image.shape[1]

        # Prepare image tensor and test
        image_t = torch.Tensor(util.normalize_image(args, image)).unsqueeze(0)
        image_t = image_t.permute(0, 3, 1, 2)
        #visualize_bbox(args, cfg, image, [torch.Tensor(rot_coord).cuda()], net.prior, height_final/width_final)

        text_boxes = []
        for _, net in enumerate(nets):
            device_id = opt.device_id if len(nets) == 1 else _
            image_t = image_t.to("cuda:%d" % (device_id)).half()
            out = net(image_t, is_train=False)
            loc_data, conf_data, prior_data = out
            prior_data = prior_data.to("cuda:%d" % (device_id))
            det_result = detector(loc_data, conf_data, prior_data)
            # Extract the predicted bboxes
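            # keep class-1 (text) detections with confidence >= 0.1; columns 1: hold the box coordinates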
            idx = det_result.data[0, 1, :, 0] >= 0.1
            text_boxes.append(det_result.data[0, 1, idx, 1:])
        text_boxes = torch.cat(text_boxes, dim=0)
        text_boxes = combine_boxes(text_boxes, img=image_t)
        pred = [[float(coor) for coor in area] for area in text_boxes]
        BBox = [
            imgaug.augmentables.bbs.BoundingBox(box[0] * w_final,
                                                box[1] * h_final,
                                                box[2] * w_final,
                                                box[3] * h_final)
            for box in pred
        ]
        BBoxes = imgaug.augmentables.bbs.BoundingBoxesOnImage(
            BBox, shape=image.shape)
        return_aug = augment_back(transform_det, height_ori, width_ori,
                                  (h_final - h_re) / 2, (w_final - w_re) / 2)
        return_aug = return_aug.to_deterministic()
        img_ori = return_aug.augment_image(image)
        bbox = return_aug.augment_bounding_boxes([BBoxes])[0]

        f = open(os.path.join(result_dir, name + ".txt"), "w")
        pred_final = []
        for box in bbox.bounding_boxes:
            x1, y1, x2, y2 = int(round(box.x1)), int(round(box.y1)), int(
                round(box.x2)), int(round(box.y2))
            pred_final.append([x1, y1, x2, y2])
            #box_tensors.append(torch.tensor([x1, y1, x2, y2]))
            # 4-point to 8-point: x1, y1, x2, y1, x2, y2, x1, y2
            f.write("%d,%d,%d,%d,%d,%d,%d,%d\n" %
                    (x1, y1, x2, y1, x2, y2, x1, y2))
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 105, 65), 2)
        #accu, precision, recall = measure(torch.Tensor(pred_final).cuda(), torch.Tensor(gt_coords).cuda(),
        #width=img.shape[1], height=img.shape[0])
        img_save_directory = os.path.join(
            args.path, args.code_name,
            "val+" + "-".join(opt.model_prefix_list))
        if not os.path.exists(img_save_directory):
            os.mkdir(img_save_directory)
        _imgh, _imgw, _imgc = img.shape
        _imgh = _imgh * opt.test_size / _imgw
        img = cv2.resize(img, (opt.test_size, int(_imgh)))
        #cv2.imwrite(os.path.join(img_save_directory, name + ".jpg"), img)
        cv2.imwrite(os.path.join(img_save_directory, "%04d.jpg" % i), img)
        f.close()
        print("%d th image cost %.2f seconds" % (i, time.time() - start))
Example #6

def init_weight(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_normal_(m.weight)
        #torch.nn.init.xavier_uniform_(m.bias)
    if type(m) == nn.Conv2d:
        torch.nn.init.xavier_normal_(m.weight)
        #torch.nn.init.kaiming_uniform_(m.bias)

if __name__ == "__main__":
    args = util.get_args(presets.PRESET)
    with torch.cuda.device(2):
        net = model.CifarNet_Vanilla()
        if args.finetune:
            net = util.load_latest_model(args, net)
        else:
            #net.apply(init_weight)
            keras_model = get_keras_model()
            model_path = os.path.join(os.getcwd(), 'test', 'models', "cifar10_cnn.h5")
            net = weight_transfer.initialize_with_keras_hdf5(keras_model, map_dict, net, model_path)
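            # persist the weights transferred from the Keras model so later runs can reload them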
            omth_util.save_model(args, args.curr_epoch, net.state_dict())
        #net.to(args.device)
        net.cuda()
        #summary(net, input_size=(3, 32, 32), device=device)

        #train_set = fetch_data(args, [("data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5")])
        #test_set = fetch_data(args, ["test_batch"])

        transform = transforms.Compose([
            transforms.ToTensor(),
Example #7
def main():
    aug = aug_temp(args)
    dt = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M")
    datasets = data.fetch_detection_data(args,
                                         sources=args.train_sources,
                                         k_fold=1,
                                         batch_size=args.batch_size_per_gpu,
                                         batch_size_val=1,
                                         auxiliary_info=args.train_aux,
                                         split_val=0.1,
                                         aug=aug)
    for idx, (train_set, val_set) in enumerate(datasets):
        loc_loss, conf_loss = [], []
        accuracy, precision, recall, f1_score = [], [], [], []
        print("\n =============== Cross Validation: %s/%s ================ " %
              (idx + 1, len(datasets)))
        net = model.SSD(cfg,
                        connect_loc_to_conf=args.loc_to_conf,
                        fix_size=args.fix_size,
                        conf_incep=args.conf_incep,
                        loc_incep=args.loc_incep,
                        nms_thres=args.nms_threshold,
                        loc_preconv=args.loc_preconv,
                        conf_preconv=args.conf_preconv,
                        FPN=args.feature_pyramid_net,
                        SA=args.self_attention,
                        in_wid=args.inner_filters,
                        m_factor=args.inner_m_factor)
        net = torch.nn.DataParallel(net,
                                    device_ids=args.gpu_id,
                                    output_device=args.output_gpu_id).cuda()
        detector = model.Detect(num_classes=2,
                                bkg_label=0,
                                top_k=800,
                                conf_thresh=0.05,
                                nms_thresh=0.3)
        #detector = None
        # Input dimension of bbox is different in each step
        torch.backends.cudnn.benchmark = True
        if args.fix_size:
            net.module.prior = net.module.prior.cuda()
        if args.finetune:
            net = util.load_latest_model(args,
                                         net,
                                         prefix=args.model_prefix_finetune,
                                         strict=True)
        # AdaBound optimizer: behaves like Adam early in training and approaches SGD as its step bounds converge to final_lr
        optimizer = AdaBound(
            net.parameters(),
            lr=args.learning_rate,
            final_lr=20 * args.learning_rate,
            weight_decay=args.weight_decay,
        )

        for epoch in range(args.epoch_num):
            loc_avg, conf_avg = fit(args,
                                    cfg,
                                    net,
                                    detector,
                                    train_set,
                                    optimizer,
                                    is_train=True)
            loc_loss.append(loc_avg)
            conf_loss.append(conf_avg)
            train_losses = [np.asarray(loc_loss), np.asarray(conf_loss)]
            if val_set is not None:
                accu, pre, rec, f1 = fit(args,
                                         cfg,
                                         net,
                                         detector,
                                         val_set,
                                         optimizer,
                                         is_train=False)
                accuracy.append(accu)
                precision.append(pre)
                recall.append(rec)
                f1_score.append(f1)
                val_losses = [
                    np.asarray(accuracy),
                    np.asarray(precision),
                    np.asarray(recall),
                    np.asarray(f1_score)
                ]
            if epoch != 0 and epoch % 10 == 0:
                util.save_model(args,
                                args.curr_epoch,
                                net.state_dict(),
                                prefix=args.model_prefix,
                                keep_latest=3)
            if epoch > 5:
                vb.plot_multi_loss_distribution(
                    multi_line_data=[train_losses, val_losses],
                    multi_line_labels=[["location", "confidence"],
                                       [
                                           "Accuracy", "Precision", "Recall",
                                           "F1-Score"
                                       ]],
                    save_path=args.loss_log,
                    window=5,
                    name=dt,
                    bound=[{
                        "low": 0.0,
                        "high": 3.0
                    }, {
                        "low": 0.0,
                        "high": 1.0
                    }],
                    titles=["Train Loss", "Validation Score"])
        # Clean the data for next cross validation
        del net, optimizer
        args.curr_epoch = 0
Example #8
def main():
    dt = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M")
    datasets = data.fetch_probaV_data(args,
                                      sources=args.train_sources,
                                      k_fold=args.cross_val,
                                      split_val=0.1,
                                      batch_size=args.batch_size_per_gpu,
                                      auxiliary_info=[2, 2])
    for idx, (train_set, val_set) in enumerate(datasets):
        Loss, Measure = [], []
        val_Loss, val_Measure = [], []
        print("\n =============== Cross Validation: %s/%s ================ " %
              (idx + 1, len(datasets)))
        if args.which_model.lower() == "carn":
            net = model.CARN(args.n_selected_img,
                             args.filters,
                             3,
                             s_MSE=args.s_MSE,
                             trellis=args.trellis)
        elif args.which_model.lower() == "rdn":
            net = model.RDN(args.n_selected_img,
                            3,
                            3,
                            filters=args.filters,
                            s_MSE=args.s_MSE,
                            group=args.n_selected_img,
                            trellis=args.trellis)
        elif args.which_model.lower() == "meta_rdn":
            net = RDN_Meta(args.n_selected_img,
                           filters=args.filters,
                           scale=3,
                           s_MSE=args.s_MSE,
                           group=args.n_selected_img,
                           trellis=args.trellis)
        elif args.which_model.lower() == "basic":
            net = model.ProbaV_basic(inchannel=args.n_selected_img)
        else:
            print(
                "args.which_model or -wm should be one of [carn, rdn, meta_rdn, basic]; "
                "your -wm value %s is invalid, falling back to 'basic' automatically"
                % (args.which_model.lower()))
            net = model.ProbaV_basic(inchannel=args.n_selected_img)
        net.apply(init_cnn)
        if args.half_precision:
            net.half()
        net = torch.nn.DataParallel(net,
                                    device_ids=args.gpu_id,
                                    output_device=args.output_gpu_id).cuda()
        torch.backends.cudnn.benchmark = True
        if args.finetune:
            net = util.load_latest_model(args,
                                         net,
                                         prefix=args.model_prefix_finetune,
                                         strict=True)
        optimizer = AdaBound(net.parameters(),
                             lr=args.learning_rate,
                             final_lr=10 * args.learning_rate,
                             weight_decay=args.weight_decay)
        #criterion = ListedLoss(type="l1", reduction="mean")
        #criterion = torch.nn.DataParallel(criterion, device_ids=args.gpu_id, output_device=args.output_gpu_id).cuda()
        measure = MultiMeasure(type="l2",
                               reduction="mean",
                               half_precision=args.half_precision)
        #measure = torch.nn.DataParallel(measure, device_ids=args.gpu_id, output_device=args.output_gpu_id).cuda()
        for epoch in range(args.epoch_num):
            _l, _m = fit(args,
                         net,
                         train_set,
                         optimizer,
                         measure,
                         is_train=True)
            Loss.append(_l)
            Measure.append(_m)
            if val_set is not None:
                _vl, _vm = val(args, net, val_set, optimizer, measure)
                val_Loss.append(_vl)
                val_Measure.append(_vm)

            if (epoch + 1) % 10 == 0:
                util.save_model(args,
                                args.curr_epoch,
                                net.state_dict(),
                                prefix=args.model_prefix,
                                keep_latest=10)
            if (epoch + 1) > 5:
                vb.plot_multi_loss_distribution(
                    multi_line_data=[
                        to_array(Loss) + to_array(val_Loss),
                        to_array(Measure) + to_array(val_Measure)
                    ],
                    multi_line_labels=[[
                        "train_mae", "train_smse", "val_mae", "val_smse"
                    ], [
                        "train_PSNR",
                        "train_L1",
                        "val_PSNR",
                        "val_L1",
                    ]],
                    save_path=args.loss_log,
                    window=3,
                    name=dt + "cv_%d" % (idx + 1),
                    bound=[{
                        "low": 0.0,
                        "high": 15
                    }, {
                        "low": 10,
                        "high": 50
                    }],
                    titles=["Loss", "Measure"])
        # Clean the data for next cross validation
        del net, optimizer, measure
        args.curr_epoch = 0
Example #9
def test_rotation(opt):
    result_dir = os.path.join(args.path, args.code_name,
                              "result+" + "-".join(opt.model_prefix_list))
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    # Load
    if type(opt.model_prefix_list) is str:
        opt.model_prefix_list = [opt.model_prefix_list]
    assert len(opt.model_prefix_list) <= torch.cuda.device_count(), \
        "the number of models must not exceed the number of available CUDA devices"
    nets = []
    for _, prefix in enumerate(opt.model_prefix_list):
        #net = model.SSD(cfg, connect_loc_to_conf=True, fix_size=False,
        #conf_incep=True, loc_incep=True)
        net = model.SSD(cfg,
                        connect_loc_to_conf=opt.loc_to_conf,
                        fix_size=False,
                        conf_incep=opt.conf_incep,
                        loc_incep=opt.loc_incep,
                        loc_preconv=opt.loc_preconv,
                        conf_preconv=opt.conf_preconv,
                        FPN=opt.feature_pyramid_net,
                        SA=opt.self_attention,
                        in_wid=opt.inner_filters,
                        m_factor=opt.inner_m_factor)
        device_id = opt.device_id if len(opt.model_prefix_list) == 1 else _
        net = net.to("cuda:%d" % (device_id))
        net_dict = net.state_dict()
        weight_dict = util.load_latest_model(args,
                                             net,
                                             prefix=prefix,
                                             return_state_dict=True,
                                             nth=opt.nth_best_model)
        loading_fail_signal = False
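        # zero-initialize any model parameter that is absent from the checkpoint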
        for i, key in enumerate(net_dict.keys()):
            if "module." + key not in weight_dict:
                net_dict[key] = torch.zeros(net_dict[key].shape)
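        # checkpoint keys carry the DataParallel "module." prefix; key[7:] strips it below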
        for key in weight_dict.keys():
            if key[7:] in net_dict:
                if net_dict[key[7:]].shape == weight_dict[key].shape:
                    net_dict[key[7:]] = weight_dict[key]
                else:
                    print(
                        "Key: %s from the checkpoint has shape %s, which does not match the model's shape %s"
                        % (key[7:], str(weight_dict[key].shape),
                           str(net_dict[key[7:]].shape)))
                    loading_fail_signal = True
            else:
                print("Key: %s does not exist in net_dict" % (key[7:]))
        if loading_fail_signal:
            raise RuntimeError(
                'Shape mismatch occurred; remove "%s" from your -mpl settings.' %
                (prefix))

        net.load_state_dict(net_dict)
        net.eval()
        nets.append(net)
        print("Above model loaded with out a problem")
    detector = model.Detect(num_classes=2,
                            bkg_label=0,
                            top_k=opt.detector_top_k,
                            conf_thresh=opt.detector_conf_threshold,
                            nms_thresh=opt.detector_nms_threshold)

    # Enumerate test folder
    root_path = os.path.expanduser(opt.test_dataset_root)
    if not os.path.exists(root_path):
        raise FileNotFoundError(
            "%s does not exist, please check your -tdr/--test_dataset_root settings"
            % (root_path))
    img_list = glob.glob(root_path + "/*.%s" % (opt.extension))
    precisions, recalls = [], []
    for i, img_file in enumerate(sorted(img_list)):
        start = time.time()
        name = img_file[img_file.rfind("/") + 1:-4]
        img = cv2.imread(img_file)
        height_ori, width_ori = img.shape[0], img.shape[1]

        do_rotation = False
        if do_rotation:
            # detect rotation for returning the image back
            img, transform_det = estimate_angle(img, args, None, None, None)
            transform_det["rotation"] = 0
            if transform_det["rotation"] != 0:
                rot_aug = augmenters.Affine(rotate=transform_det["rotation"],
                                            cval=args.aug_bg_color)
            else:
                rot_aug = None

            # Perform Augmentation
            if rot_aug:
                rot_aug = augmenters.Sequential(
                    augmenters.Affine(rotate=transform_det["rotation"],
                                      cval=args.aug_bg_color))
                image = rot_aug.augment_image(img)
            else:
                image = img
        else:
            image = img
        # Resize the longer side to a certain length
        # NOTE: `square` (the target side length) is not defined in this snippet; assumed to be opt.test_size
        square = opt.test_size
        if height_ori >= width_ori:
            resize_aug = augmenters.Sequential([
                augmenters.Resize(size={
                    "height": square,
                    "width": "keep-aspect-ratio"
                })
            ])
        else:
            resize_aug = augmenters.Sequential([
                augmenters.Resize(size={
                    "height": "keep-aspect-ratio",
                    "width": square
                })
            ])
        resize_aug = resize_aug.to_deterministic()
        image = resize_aug.augment_image(image)
        h_re, w_re = image.shape[0], image.shape[1]
        # Pad the image into a square image
        pad_aug = augmenters.Sequential(
            augmenters.PadToFixedSize(width=square,
                                      height=square,
                                      pad_cval=255,
                                      position="center"))
        pad_aug = pad_aug.to_deterministic()
        image = pad_aug.augment_image(image)
        h_final, w_final = image.shape[0], image.shape[1]

        # Prepare image tensor and test
        image_t = torch.Tensor(util.normalize_image(args, image)).unsqueeze(0)
        image_t = image_t.permute(0, 3, 1, 2)
        # visualize_bbox(args, cfg, image, [torch.Tensor(rot_coord).cuda()], net.prior, height_final/width_final)

        text_boxes = []
        for _, net in enumerate(nets):
            device_id = opt.device_id if len(nets) == 1 else _
            image_t = image_t.to("cuda:%d" % (device_id))
            out = net(image_t, is_train=False)
            loc_data, conf_data, prior_data = out
            prior_data = prior_data.to("cuda:%d" % (device_id))
            det_result = detector(loc_data, conf_data, prior_data)
            # Extract the predicted bboxes
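            # keep class-1 (text) detections with confidence >= 0.1; columns 1: hold the box coordinates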
            idx = det_result.data[0, 1, :, 0] >= 0.1
            text_boxes.append(det_result.data[0, 1, idx, 1:])
        text_boxes = torch.cat(text_boxes, dim=0)
        text_boxes = combine_boxes(text_boxes, img=image_t)
        pred = [[float(coor) for coor in area] for area in text_boxes]
        BBox = [
            imgaug.augmentables.bbs.BoundingBox(box[0] * w_final,
                                                box[1] * h_final,
                                                box[2] * w_final,
                                                box[3] * h_final)
            for box in pred
        ]
        BBoxes = imgaug.augmentables.bbs.BoundingBoxesOnImage(
            BBox, shape=image.shape)
        return_aug = augment_back(height_ori, width_ori, (h_final - h_re) / 2,
                                  (w_final - w_re) / 2)
        return_aug = return_aug.to_deterministic()
        img_ori = return_aug.augment_image(image)
        bbox = return_aug.augment_bounding_boxes([BBoxes])[0]
        # print_box(blue_boxes=pred, idx=i, img=vb.plot_tensor(args, image_t, margin=0),
        # save_dir=args.val_log)

        f = open(os.path.join(result_dir, name + ".txt"), "w")
        gt_box_file = os.path.join(opt.test_dataset_root,
                                   name + "." + opt.ground_truth_extension)
        coords = tb_data.parse_file(os.path.expanduser(gt_box_file))
        gt_coords = []
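        # collapse each 8-point ground-truth polygon into an axis-aligned box (min/max of x and y)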
        for coord in coords:
            x1, x2 = min(coord[::2]), max(coord[::2])
            y1, y2 = min(coord[1::2]), max(coord[1::2])
            gt_coords.append([x1, y1, x2, y2])
            cv2.rectangle(img, (x1, y1), (x2, y2), (70, 67, 238), 2)
        pred_final = []
        for box in bbox.bounding_boxes:
            x1, y1, x2, y2 = int(round(box.x1)), int(round(box.y1)), int(
                round(box.x2)), int(round(box.y2))
            pred_final.append([x1, y1, x2, y2])
            # box_tensors.append(torch.tensor([x1, y1, x2, y2]))
            # 4-point to 8-point: x1, y1, x2, y1, x2, y2, x1, y2
            f.write("%d,%d,%d,%d,%d,%d,%d,%d\n" %
                    (x1, y1, x2, y1, x2, y2, x1, y2))
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 105, 65), 2)
        accu, precision, recall = measure(torch.Tensor(pred_final).cuda(),
                                          torch.Tensor(gt_coords).cuda(),
                                          width=img.shape[1],
                                          height=img.shape[0])
        precisions.append(precision)
        recalls.append(recall)
        img_save_directory = os.path.join(
            args.path, args.code_name,
            "val+" + "-".join(opt.model_prefix_list))
        if not os.path.exists(img_save_directory):
            os.mkdir(img_save_directory)
        name = "%s_%.2f_%.2f" % (str(i).zfill(4), precision, recall)
        cv2.imwrite(os.path.join(img_save_directory, name + ".jpg"), img)
        f.close()
        if opt.verbose:
            print(
                "%d-th image took %.2f seconds, precision: %.2f, recall: %.2f"
                % (i, time.time() - start, precision, recall))
    print("Precision: %.2f, Recall: %.2f" % (avg(precisions), avg(recalls)))