Example #1
def run_visualize():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer
    import os  # used for the output paths below

    network = make_network(cfg).cuda()
    load_network(network,
                 cfg.model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    if args.vis_out:
        os.makedirs(args.vis_out, exist_ok=True)  # os.mkdir would fail if the directory already exists
    idx = 0
    #start = timeit.default_timer()
    for batch in tqdm.tqdm(data_loader):
        idx += 1
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        if args.vis_out:
            visualizer.visualize(
                output, batch,
                os.path.join(args.vis_out, "{:04d}_.png".format(idx)))
        else:
            visualizer.visualize(output, batch)
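These `run_*` snippets rely on module-level names such as `cfg` and `args` (typically `from lib.config import cfg, args` in these repos' `run.py`) and are usually dispatched by name from a small entry point. A minimal sketch of that dispatch pattern, assuming the conventional layout:

from lib.config import cfg, args

def main():
    # args.type selects the function to run, e.g. 'visualize' or 'evaluate'.
    globals()['run_' + args.type]()

if __name__ == '__main__':
    main()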
Example #2
def run_demo():
    from lib.datasets import make_data_loader
    from lib.visualizers import make_visualizer
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network
    import glob
    from PIL import Image
    import numpy as np
    import os

    torch.manual_seed(0)
    meta = np.load(os.path.join(cfg.demo_path, 'meta.npy'),
                   allow_pickle=True).item()
    demo_images = glob.glob(cfg.demo_path + '/*jpg')

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    visualizer = make_visualizer(cfg)

    # ImageNet channel statistics used for input normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    for demo_path in demo_images:
        demo_image = np.array(Image.open(demo_path)).astype(np.float32)
        inp = (((demo_image / 255.) - mean) / std).transpose(2, 0, 1).astype(np.float32)
        inp = torch.Tensor(inp[None]).cuda()
        with torch.no_grad():
            output = network(inp)
        visualizer.visualize_demo(output, inp, meta)
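To display what the network actually consumes, the normalization above can be inverted. A small helper sketch (`denormalize` is an illustrative name, not part of the source):

import numpy as np

def denormalize(inp, mean, std):
    # inp: float32 array of shape (3, H, W), as built in run_demo above.
    img = inp.transpose(1, 2, 0) * std + mean  # undo (x / 255 - mean) / std
    return (np.clip(img, 0., 1.) * 255).astype(np.uint8)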
Example #3
def run_visualize():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    num = -1
    for batch in tqdm.tqdm(data_loader):
        num = num + 1
        name = '%06d.jpg' % num
        #high_resolution = '/mnt/SSD/jzwang/code/clean-pvnet/demo28/' + name
        #high_resolution = Image.open(high_resolution)
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        #visualizer.visualize(output, batch, name, high_resolution)
        visualizer.visualize(output, batch, name)
Example #4
def run_visualize():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    if DEBUG:
        print('===================================Visualizing=================================')
        import pprint
        print('args:', args)
        print('cfg:')
        pprint.pprint(cfg)

    network = make_network(cfg).cuda()
    load_network(network,
                 cfg.model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        visualizer.visualize(output, batch)
Example #5
def run_visualize():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer
    from lib.networks.renderer import make_renderer

    cfg.perturb = 0

    network = make_network(cfg).cuda()
    load_network(network,
                 cfg.trained_model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    network.train()  # note: left in train mode here; eval() is the usual choice for inference

    data_loader = make_data_loader(cfg, is_train=False)
    renderer = make_renderer(cfg, network)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = renderer.render(batch)
            visualizer.visualize(output, batch)
Example #6
def run_evaluate():
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    if DEBUG:
        print('-------------------------------Evaluating---------------------------------')

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    for batch in tqdm.tqdm(data_loader):
        inp = batch['inp'].cuda()
        with torch.no_grad():
            output = network(inp)
        evaluator.evaluate(output, batch)
    evaluator.summarize()
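The `evaluate`/`summarize` pair follows a simple accumulate-then-aggregate protocol: one call per batch, one final aggregation. A toy stand-in for that interface (illustrative only; the real evaluators come from `make_evaluator`):

class CountingEvaluator:
    def __init__(self):
        self.results = []

    def evaluate(self, output, batch):
        # A real evaluator compares output against the labels in batch;
        # this stand-in only records that a batch was processed.
        self.results.append(1)

    def summarize(self):
        summary = {'num_batches': len(self.results)}
        print(summary)
        return summary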
Example #7
def run_network():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    import time

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    total_time = 0
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            torch.cuda.synchronize()  # flush pending GPU work before starting the timer
            start = time.time()
            network(batch['inp'])
            torch.cuda.synchronize()  # wait for the forward pass to finish
            total_time += time.time() - start
    print(total_time / len(data_loader))  # average seconds per batch
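The `synchronize` calls matter because CUDA kernels launch asynchronously; without them `time.time()` would mostly measure launch overhead. The same measurement can be made with CUDA events, sketched here:

import torch

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

start_evt.record()
# network(batch['inp'])  # the forward pass being timed
end_evt.record()
torch.cuda.synchronize()  # wait until both events are recorded
elapsed_ms = start_evt.elapsed_time(end_evt)  # milliseconds on the GPU timeline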
Example #8
def run_evaluate():
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network
    from lib.train import make_trainer

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    trainer = make_trainer(cfg, network)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    if 'Coco' in cfg.train.dataset:
        trainer.val_coco(data_loader)
    else:
        evaluator = make_evaluator(cfg)
        for batch in tqdm.tqdm(data_loader):
            inp = batch['inp'].cuda()
            with torch.no_grad():
                output = network(inp)
            evaluator.evaluate(output, batch)
        evaluator.summarize()
Example #9
def run_visualize():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    index = 0
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        # pdb.set_trace()
        visualizer.visualize(output, batch, index)

        index = index + 1
Example #10
    def __init__(self):
        super(Network, self).__init__()

        det_meta = cfg.det_meta
        self.dla_ct = get_pose_net(det_meta.num_layers, det_meta.heads)
        self.pvnet = get_res_pvnet(cfg.heads['vote_dim'], cfg.heads['seg_dim'])

        net_utils.load_network(self.dla_ct, cfg.det_model)
        net_utils.load_network(self.pvnet, cfg.kpt_model)
Example #11
def run_visualize():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer
    from pycocotools.coco import COCO
    import os

    network = make_network(cfg).cuda()
    print("model dir:{}".format(cfg.model_dir))
    load_network(network,
                 cfg.model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)
    ann_file = 'data/NICE1/NICE1/coco_int/test/annotations/NICE_test.json'  # test-split annotations; this run uses the test set
    coco = COCO(ann_file)
    for batch in tqdm.tqdm(data_loader):  # tqdm renders the progress bar
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        # print("batch  ->{}".format(batch)) #batch是对的,里面还有img_id等相关信息
        # print("batch['meta']  ->{}".format(batch['meta'])) #batch['meta']是对的
        img_info = batch['meta']
        img_id = img_info['img_id']  #这个img_id用来根据id读取img名称从而确定输出图片名称的
        img_scale = img_info['scale']  # 这个img_scale是用来读取尺寸从而改变图像尺寸的
        # print("img id ->{}".format(img_id))
        # print("img scale ->{}".format(img_scale))
        img_name = coco.loadImgs(int(img_id))[0]['file_name']
        img_name, _ = os.path.splitext(img_name)
        # ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=0)  # two options: fetch the annotations, or loadImgs directly for the file name
        # anno = coco.loadAnns(ann_ids)
        visualizer.visualize(output, batch, img_name)
        tmp_file = open(
            '/home/tianhao.lu/code/Deep_snake/snake/Result/Contour/Output_check.log',
            'w',  # note: 'w' truncates the log each iteration, so only the last batch survives
            encoding='utf8')
        tmp_file.writelines("Output -> :" + str(output) + "\n")
        tmp_file.writelines("batch -> :" + str(batch) + "\n")
        # for tmp_data in train_loader:
        #     tmp_file.writelines("one train_loader data type:" + str(type(tmp_data)) + "\n")
        #     for key in tmp_data:
        #         tmp_file.writelines("one train_loader data key:" + str(key) + "\n")
        #         tmp_file.writelines("one train_loader data len:" + str(len(tmp_data[key])) + "\n")
        #     # tmp_file.writelines("one train_loader data:" + str(tmp_data) + "\n")
        #     break
        tmp_file.writelines(
            str("*************************************************************** \n"
                ))
        tmp_file.close()
Example #12
def run_visualize():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer
    from lib.networks.renderer import make_renderer
    import os

    cfg.perturb = 0
    cfg.render_name = args.render_name

    network = make_network(cfg)#.cuda()
    load_network(network,
                 cfg.trained_model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    network.train()  # note: left in train mode here; eval() is the usual choice for inference

    data_loader = make_data_loader(cfg, is_train=False)
    n_devices = 2

    expand_k = ['index', 'bounds', 'R', 'Th',
                'center', 'rot', 'trans', 'i',
                'cam_ind', 'feature', 'coord', 'out_sh']
    renderer = torch.nn.DataParallel(make_renderer(cfg, network)).cuda()
    #renderer.set_save_mem(True)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(data_loader):
        batch_cuda = {}
        for k in batch:
            if k != 'meta':
                # Hack for DataParallel, which splits dim 0 across GPUs:
                # ray-level tensors are flattened so the rays get chunked
                # over devices, while the per-sample metadata in expand_k
                # is replicated once per device.
                if k not in expand_k:
                    sh = batch[k].shape
                    if sh[-1] == 3:
                        batch_cuda[k] = batch[k].view(-1, 3)
                    else:
                        batch_cuda[k] = batch[k].view(-1)
                else:
                    sh = batch[k].shape
                    batch_cuda[k] = batch[k].expand(n_devices, *sh[1:])

                batch_cuda[k] = batch_cuda[k].cuda()
            else:
                batch_cuda[k] = batch[k]


        with torch.no_grad():
            output = renderer(batch_cuda)
            visualizer.visualize(output, batch)
Example #13
def demo():
    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    network.eval()

    dataset = Dataset()
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(dataset):
        batch['inp'] = torch.FloatTensor(batch['inp'])[None].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
            # print('Output:{} \n Output.len:{}'.format(output, len(output)))  # the output here is a tensor
        visualizer.visualize(output, batch)
Example #14
    def __init__(self,
                 num_layers,
                 heads,
                 head_conv=256,
                 down_ratio=4,
                 det_dir=''):
        super(Network, self).__init__()

        self.dla = DLASeg('dla{}'.format(num_layers),
                          heads,
                          pretrained=True,
                          down_ratio=down_ratio,
                          final_kernel=1,
                          last_level=5,
                          head_conv=head_conv)
        self.gcn = Evolution()

        net_utils.load_network(self.dla, det_dir)
Example #15
def test(cfg, network):
    trainer = make_trainer(cfg, network)
    val_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    epoch = load_network(network,
                         cfg.model_dir,
                         resume=cfg.resume,
                         epoch=cfg.test.epoch)
    trainer.val(epoch, val_loader, evaluator)
Example #16
    def __init__(self,
                 num_layers,
                 heads,
                 head_conv=256,
                 down_ratio=4,
                 det_dir=''):
        super(Network, self).__init__()

        self.dla = DLASeg('dla{}'.format(num_layers),
                          heads,
                          pretrained=True,
                          down_ratio=down_ratio,
                          final_kernel=1,
                          last_level=5,
                          head_conv=head_conv)
        self.cp = ComponentDetection()
        self.gcn = Evolution()

        det_dir = os.path.join(os.path.dirname(cfg.model_dir), cfg.det_model)
        net_utils.load_network(self, det_dir, strict=False)
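Passing `strict=False` implies a `state_dict` load that tolerates missing and unexpected keys, which is what lets a detector-only checkpoint initialize part of this composite module. A sketch of that behavior under standard PyTorch semantics (the checkpoint layout is an assumption, and the real `load_network` also resolves epoch-numbered files):

import torch

def load_network_sketch(net, path, strict=True):
    pretrained = torch.load(path, map_location='cpu')
    state = pretrained.get('net', pretrained)  # tolerate raw state_dicts
    net.load_state_dict(state, strict=strict)  # strict=False skips mismatched keys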
Example #17
def run_evaluate():
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    import numpy as np
    np.random.seed(1000)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)

    count = 1200  # evaluate at most 1200 batches
    i = 0
    for batch in tqdm.tqdm(data_loader):
        if i == count:
            break
        inp = batch['inp'].cuda()
        # save input
        #         print(batch['img_id'])
        #         import pickle
        #         with open('/mbrdi/sqnap1_colomirror/gupansh/input_cat.pkl','wb') as fp:
        #             pickle.dump(batch['inp'], fp)
        #         input()
        seg_gt = batch['mask'].cuda()  # loaded but not used below
        with torch.no_grad():
            output = network(inp)
        evaluator.evaluate(output, batch)

        i += 1
    evaluator.summarize()
Example #18
def run_analyze():
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.analyzers import make_analyzer

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    cfg.train.num_workers = 0  # single-process data loading
    data_loader = make_data_loader(cfg, is_train=False)
    analyzer = make_analyzer(cfg)
    for batch in tqdm.tqdm(data_loader):
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        analyzer.analyze(output, batch)
Example #19
def run_evaluate():
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network
    from lib.evaluators.custom.monitor import MetricMonitor
    from lib.visualizers import make_visualizer
    import os

    torch.manual_seed(0)

    monitor = MetricMonitor()
    network = make_network(cfg).cuda()
    epoch = load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()
    print("Trainable parameters: {}".format(
        sum(p.numel() for p in network.parameters())))

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    visualizer = make_visualizer(cfg)
    idx = 0

    if args.vis_out:
        os.makedirs(args.vis_out, exist_ok=True)  # mkdir would fail on an existing directory

    for batch in tqdm.tqdm(data_loader):
        idx += 1
        inp = batch['inp'].cuda()
        with torch.no_grad():
            output = network(inp)
        evaluator.evaluate(output, batch)

        if args.vis_out:
            err = evaluator.data["obj_drilltip_trans_3d"][-1]
            visualizer.visualize(
                output, batch,
                os.path.join(
                    args.vis_out,
                    "tiperr{:.4f}_idx{:04d}.png".format(err.item(), idx)))
    result = evaluator.summarize()
    monitor.add('val', epoch, result)
    monitor.save_metrics("metrics.pkl")
    monitor.plot_histogram("evaluation.html", plotly=True)
Example #20
def train(cfg, network):
    trainer = make_trainer(cfg, network)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    #evaluator = make_evaluator(cfg)

    if cfg.network_full_init:
        # initialize weights from the checkpoint, then restart the epoch count
        load_network(network,
                     cfg.model_dir,
                     resume=cfg.resume,
                     epoch=cfg.test.epoch)
        begin_epoch = 0
    else:
        begin_epoch = load_model(network,
                                 optimizer,
                                 scheduler,
                                 recorder,
                                 cfg.model_dir,
                                 resume=cfg.resume)

    # set_lr_scheduler(cfg, scheduler)
    if DEBUG:
        print('------------------Loading training set-------------------')
    train_loader = make_data_loader(cfg, is_train=True)
    if DEBUG:
        print('Loading training set done...')
        print('---------------------------------------------------------')
    #val_loader = make_data_loader(cfg, is_train=False)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        trainer.train(epoch, train_loader, optimizer, recorder)
        scheduler.step()

        if (epoch + 1) % cfg.save_ep == 0:
            save_model(network, optimizer, scheduler, recorder, epoch,
                       cfg.model_dir)

        #if (epoch + 1) % cfg.eval_ep == 0:
        #    trainer.val(epoch, val_loader, evaluator, recorder)

    return network
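`load_network` restores weights only, while `load_model`/`save_model` round-trip the full training state; the `begin_epoch = 0` branch above relies on exactly that difference. A sketch of the checkpoint layout such a `save_model` implies (field names are assumptions, not verified against `lib.utils.net_utils`):

import os
import torch

def save_model(net, optim, scheduler, recorder, epoch, model_dir):
    # Bundle everything needed to resume training, keyed by epoch.
    os.makedirs(model_dir, exist_ok=True)
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch,
    }, os.path.join(model_dir, '{}.pth'.format(epoch)))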
Example #21
def inference():
    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    network.eval()

    with open(os.path.join(cfg.results_dir,'cfg.json'),'w') as fid:
        json.dump(cfg,fid)

    dataset = Dataset()
    visualizer = make_visualizer(cfg)
    infer_time_lst = []
    for batch in tqdm.tqdm(dataset):
        batch['inp'] = torch.FloatTensor(batch['inp'])[None].cuda()
        net_time_s = time.time()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        net_used_time = time.time()-net_time_s

        org_img = batch['org_img']
        rz_img = batch['rz_img']
        rz_ratio = batch['rz_ratio']
        img_name = batch['image_name']
        center = batch['meta']['center']
        scale = batch['meta']['scale']
        h, w = batch['inp'].size(2), batch['inp'].size(3)

        if DEBUG:
            print('------------------img_name={}-------------------------'.format(img_name))
            print('org_img.shape:', org_img.shape)
            print('rz_img.shape:',  rz_img.shape)
            print('input-size:({}, {})'.format(h,w))
        
        if cfg.rescore_map_flag:
            rs_thresh = 0.6
            detections = output['detection'].detach().cpu().numpy()
            polys = output['py'][-1].detach().cpu().numpy()
            rs_hm = torch.sigmoid(output['rs_hm']).detach().cpu().numpy()
            if 0:
                print('output.keys:', output.keys())

            rescores = rescoring_polygons(polys, rs_hm)
            conf_keep = np.where(rescores > rs_thresh)[0]
            
            detections = detections[conf_keep]
            pys = [polys[k] * snake_config.down_ratio for k in conf_keep]
            rescores = rescores[conf_keep]
            # note: ex_pts is never assigned on this branch, so the
            # post-processing below assumes rescore_map_flag is off
            
            rs_hm_path = os.path.join(cfg.vis_dir, (img_name[:-4] + '_rs.png'))
            import matplotlib.pyplot as plt
            plt.imshow(rs_hm[0, 0, ...])
            plt.savefig(rs_hm_path)
            plt.close()  # avoid leaking figures across iterations
            if 0:
                print('detections.shape:', detections.shape)
                print('pys.num:', len(pys))
                print('rs_hm.shape:', rs_hm.shape)
                x = rs_hm[0,0,...]
                import matplotlib.pyplot as plt 
                plt.imshow(x)
                for k in range(len(pys)):
                    plt.plot(pys[k][:,0], pys[k][:, 1])
                plt.savefig('{}.png'.format(img_name[:-4]))
                plt.close()
                np.save('rs_hm.npy', x)
                np.save('pys.npy', np.array(pys))
                exit()
        else:
            detections = output['detection'].detach().cpu().numpy()
            detections[:,:4] = detections[:, :4] * snake_config.down_ratio
            bboxes = detections[:, :4]
            scores = detections[:, 4]
            labels = detections[:, 5].astype(int)
            ex_pts = output['ex'].detach().cpu().numpy()
            ex_pts = ex_pts * snake_config.down_ratio
            #pys = output['py'][-1].detach().cpu().numpy() * snake_config.down_ratio
            iter_ply_output_lst = [x.detach().cpu().numpy()* snake_config.down_ratio for x in output['py']]
            pys = iter_ply_output_lst[-1]

            if cfg.vis_intermediate_output != 'none':
                if cfg.vis_intermediate_output == 'htp':
                    xmin,ymin,xmax,ymax = bboxes[:,0::4], bboxes[:,1::4], bboxes[:, 2::4], bboxes[:,3::4]
                    pys = np.hstack((xmin,ymin, xmin,ymax,xmax,ymax,xmax,ymin))
                    pys = pys.reshape(pys.shape[0],4,2)
                elif cfg.vis_intermediate_output == 'otp':
                    pys = ex_pts
                elif cfg.vis_intermediate_output == 'clm_1':
                    pys = iter_ply_output_lst[0]
                elif cfg.vis_intermediate_output == 'clm_2':
                    pys = iter_ply_output_lst[1]
                else:
                    raise ValueError('Not supported type:', cfg.vis_intermediate_output)
                cfg.poly_cls_branch = False


            final_contour_feat = output['final_feat'].detach().cpu().numpy()
            if cfg.poly_cls_branch:
                pys_cls = output['py_cls'][-1].detach().cpu().numpy()
                text_poly_scores = pys_cls[:, 1]
                rem_ids = np.where(text_poly_scores > cfg.poly_conf_thresh)[0]
                detections = detections[rem_ids]
                pys = pys[rem_ids]
                text_poly_scores = text_poly_scores[rem_ids]
                ex_pts = ex_pts[rem_ids]
                final_contour_feat = final_contour_feat[rem_ids]
                if DEBUG:
                    print('py_cls_scores:', text_poly_scores)

            if DEBUG:
                print('dets_num:', len(pys))

        if len(pys) == 0:
            all_boundaries, poly_scores = [], []
        else:
            trans_output_inv = data_utils.get_affine_transform(center, scale, 0, [w, h], inv=1)
            all_boundaries   = [data_utils.affine_transform(py_, trans_output_inv) for py_ in pys]
            bboxes_tmp = [data_utils.affine_transform(det[:4].reshape(-1,2), trans_output_inv).flatten() for det in detections]
            ex_pts_tmp = [data_utils.affine_transform(ep, trans_output_inv) for ep in ex_pts]
            detections = np.hstack((np.array(bboxes_tmp), detections[:,4:]))
            ex_pts = np.array(ex_pts_tmp)

            pp_time_s = time.time()
            #sorting detections by scores
            if cfg.poly_cls_branch:
                detections, ex_points, all_boundaries, final_contour_feat, poly_scores \
                  = sorting_det_results(detections, ex_pts, all_boundaries, final_contour_feat, text_poly_scores)
            else:
                detections, ex_points, all_boundaries = sorting_det_results(detections, ex_pts, all_boundaries)
            
            if len(all_boundaries) != 0:
                detections[:,:4] /= rz_ratio
                ex_points /= rz_ratio
                all_boundaries = [poly/rz_ratio for poly in all_boundaries]
                
            if 0:
                import matplotlib.pyplot as plt
                nms_polygons,rem_inds = snake_poly_utils.poly_nms(all_boundaries)
                print('nms_polygons.num:', len(nms_polygons))
                plt.subplot(1,2,1)
                plt = plot_poly(org_img, all_boundaries,scores=scores)
                plt.subplot(1,2,2)
                plt = plot_poly(org_img, nms_polygons)
                plt.savefig('a.png')
                exit()
            
            #nms
            all_boundaries, rem_inds = snake_poly_utils.poly_nms(all_boundaries)
            detections = detections[rem_inds]
            ex_points = ex_points[rem_inds]
            final_contour_feat = final_contour_feat[rem_inds]
            if cfg.poly_cls_branch:
                poly_scores = poly_scores[rem_inds]
            pp_used_time = time.time() - pp_time_s
            infer_time_lst.append([net_used_time, pp_used_time])
            if DEBUG:
                print('infer_time:',[net_used_time, pp_used_time])

            if 0:
                vis_tmp_results(org_img, detections, ex_points, all_boundaries, final_contour_feat, poly_scores, output, indx=img_name[:-4])

        #--------------------------------saving results-------------------------------#
        if cfg.testing_set == 'mlt':
            det_file = os.path.join(cfg.det_dir, ('res_'+img_name[3:-4]+'.txt'))
            saving_mot_det_results(det_file, all_boundaries, testing_set=cfg.testing_set, img=org_img)
        elif cfg.testing_set == 'ic15':
            det_file = os.path.join(cfg.det_dir, ('res_'+img_name[:-4]+'.txt'))
            saving_mot_det_results(det_file, all_boundaries, testing_set=cfg.testing_set, img=org_img)
        elif cfg.testing_set == 'msra':
            det_file = os.path.join(cfg.det_dir, ('res_'+img_name[:-4]+'.txt'))
            saving_mot_det_results(det_file, all_boundaries, testing_set=cfg.testing_set, img=org_img)
        else: #for arbitrary-shape datasets, e.g., CTW,TOT,ART
            det_file = os.path.join(cfg.det_dir, (img_name[:-4]+'.txt'))
            saving_det_results(det_file, all_boundaries, img=org_img)
        
        continue  # skip the visualization code below; delete this line to enable it
        #------------------------visualizing results---------------------------------#
        ## ~~~~~~ vis-v0 ~~~~~~~ ##
        vis_file = os.path.join(cfg.vis_dir,(img_name[:-4]+'.png'))
        if cfg.testing_set == 'ctw':
            gt_file = os.path.join(cfg.gts_dir, (img_name[:-4]+'.txt'))
            gt_polys = load_ctw_gt_label(gt_file)
        elif cfg.testing_set == 'tot':
            gt_file = os.path.join(cfg.gts_dir, ('poly_gt_'+img_name[:-4]+'.mat'))
            gt_polys = load_tot_gt_label(gt_file)
        elif cfg.testing_set == 'art':
            gt_polys = None
        elif cfg.testing_set == 'msra':
            gt_file = os.path.join(cfg.gts_dir, ('gt_'+img_name[:-4]+'.txt'))
            gt_polys = load_msra_gt_label(gt_file)
        else:
            raise ValueError('Not supported dataset ({}) for visualizing'.format(cfg.testing_set))
        plt = vis_dets_gts(org_img, all_boundaries, gt_polys)
        plt.savefig(vis_file,dpi=600,format='png')
        plt.close()
        ### ~~~~~~~~~ vis-v1 ~~~~~~~~~~~ ###
        # if cfg.poly_cls_branch:
        #     visualizing_det_results(org_img,all_boundaries,vis_file, scores=detections[:,4],poly_scores=poly_scores)
        # else:
        #     visualizing_det_results(org_img,all_boundaries,vis_file, scores=detections[:,4])
        ## vis-v2
        #hm_vis_dir = os.path.join(cfg.vis_dir, ('../vis_hm_on_img_dir'))
        #if not os.path.exists(hm_vis_dir):
        #    os.makedirs(hm_vis_dir)
        #visualizer.visualize(output, batch, os.path.join(hm_vis_dir,(img_name[:-4]+'.png')))

    np.save('infer_time.npy', np.array(infer_time_lst))
Example #22
def inference():
    network = make_network(cfg).cuda()
    load_network(network,
                 cfg.model_dir,
                 resume=cfg.resume,
                 epoch=cfg.test.epoch)
    network.eval()

    with open(os.path.join(cfg.results_dir, 'cfg.json'), 'w') as fid:
        json.dump(cfg, fid)

    dataset = Dataset()
    visualizer = make_visualizer(cfg)
    infer_time_lst = []
    for batch in tqdm.tqdm(dataset):
        batch['inp'] = torch.FloatTensor(batch['inp'])[None].cuda()
        net_time_s = time.time()
        with torch.no_grad():
            output = network(batch['inp'], batch)
        net_used_time = time.time() - net_time_s

        org_img = batch['org_img']
        rz_img = batch['rz_img']
        rz_ratio = batch['rz_ratio']
        img_name = batch['image_name']
        center = batch['meta']['center']
        scale = batch['meta']['scale']
        h, w = batch['inp'].size(2), batch['inp'].size(3)

        detections = output['detection'].detach().cpu().numpy()
        detections[:, :4] = detections[:, :4] * snake_config.down_ratio
        bboxes = detections[:, :4]
        scores = detections[:, 4]
        labels = detections[:, 5].astype(int)
        ex_pts = output['ex'].detach().cpu().numpy()
        ex_pts = ex_pts * snake_config.down_ratio
        #pys = output['py'][-1].detach().cpu().numpy() * snake_config.down_ratio
        iter_ply_output_lst = [
            x.detach().cpu().numpy() * snake_config.down_ratio
            for x in output['py']
        ]
        pys = iter_ply_output_lst[-1]

        final_contour_feat = output['final_feat'].detach().cpu().numpy()
        if cfg.poly_cls_branch:
            pys_cls = output['py_cls'][-1].detach().cpu().numpy()
            text_poly_scores = pys_cls[:, 1]
            rem_ids = np.where(text_poly_scores > cfg.poly_conf_thresh)[0]
            detections = detections[rem_ids]
            pys = pys[rem_ids]
            text_poly_scores = text_poly_scores[rem_ids]
            ex_pts = ex_pts[rem_ids]
            final_contour_feat = final_contour_feat[rem_ids]

        if len(pys) == 0:
            all_boundaries, poly_scores = [], []
        else:
            trans_output_inv = data_utils.get_affine_transform(center,
                                                               scale,
                                                               0, [w, h],
                                                               inv=1)
            all_boundaries = [
                data_utils.affine_transform(py_, trans_output_inv)
                for py_ in pys
            ]
            bboxes_tmp = [
                data_utils.affine_transform(det[:4].reshape(-1, 2),
                                            trans_output_inv).flatten()
                for det in detections
            ]
            ex_pts_tmp = [
                data_utils.affine_transform(ep, trans_output_inv)
                for ep in ex_pts
            ]
            detections = np.hstack((np.array(bboxes_tmp), detections[:, 4:]))
            ex_pts = np.array(ex_pts_tmp)

            pp_time_s = time.time()
            #sorting detections by scores
            if cfg.poly_cls_branch:
                detections, ex_points, all_boundaries, final_contour_feat, poly_scores \
                  = sorting_det_results(detections, ex_pts, all_boundaries, final_contour_feat, text_poly_scores)
            else:
                detections, ex_points, all_boundaries = sorting_det_results(
                    detections, ex_pts, all_boundaries)

            if cfg.rle_nms:
                tmp_polys = all_boundaries.copy()
                #all_boundaries, rem_inds = snake_poly_utils.poly_nms(tmp_polys)
                rem_inds = poly_rle_nms(tmp_polys,
                                        detections[:, -1], (h, w),
                                        nms_thresh=0.3)
                all_boundaries = [all_boundaries[idx] for idx in rem_inds]
            else:
                #nms
                all_boundaries, rem_inds = snake_poly_utils.poly_nms(
                    all_boundaries)
            detections = detections[rem_inds]
            ex_points = ex_points[rem_inds]
            final_contour_feat = final_contour_feat[rem_inds]
            if cfg.poly_cls_branch:
                poly_scores = poly_scores[rem_inds]
            pp_used_time = time.time() - pp_time_s
            infer_time_lst.append([net_used_time, pp_used_time])

            if len(all_boundaries) != 0:
                detections[:, :4] /= rz_ratio
                ex_points /= rz_ratio
                all_boundaries = [poly / rz_ratio for poly in all_boundaries]

        #--------------------------------saving results-------------------------------#
        det_file = os.path.join(cfg.det_dir, (img_name[:-4] + '.txt'))
        saving_det_results(det_file, all_boundaries, img=org_img)
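Both `inference()` variants map polygons through the same coordinate chain: predictions on the down-sampled feature map are scaled by `snake_config.down_ratio`, pushed through the inverse affine transform (`get_affine_transform(..., inv=1)`) into the resized image, then divided by `rz_ratio` to reach original-image pixels. Applying such a 2x3 affine matrix reduces to one matrix product, sketched here:

import numpy as np

def apply_affine(points, trans):
    # points: (N, 2) polygon vertices; trans: (2, 3) affine matrix.
    pts = np.concatenate([points, np.ones((len(points), 1))], axis=1)
    return pts @ trans.T  # (N, 2) transformed vertices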