Example #1
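# NOTE (assumed): standard-library and third-party imports this example relies on.
# Project-specific helpers (cfg, create_saver, create_logger, create_summary,
# COCOSEGMCMM/KINSSEGMCMM, COCO_eval_segm_cmm/KINS_eval_segm_cmm, exkp, get_pose_resdcn,
# load_model, _tranpose_and_gather_feature, _neg_loss, _reg_loss, norm_reg_loss,
# ctsegm_inmodal_code_decode, transform_preds) come from the surrounding repository
# and are not shown here.
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
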
def main():
    best_mAP = 0.0
    saver = create_saver(cfg.local_rank, save_dir=cfg.ckpt_dir)
    logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir)
    summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir)

    print_log = logger.info
    print_log(cfg)

    torch.manual_seed(319)
    torch.backends.cudnn.benchmark = True  # disable this if OOM at beginning of training

    num_gpus = torch.cuda.device_count()
    if cfg.dist:
        cfg.device = torch.device('cuda:%d' % cfg.local_rank)
        torch.cuda.set_device(cfg.local_rank)
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=num_gpus,
                                rank=cfg.local_rank)
    else:
        cfg.device = torch.device('cuda:%d' % cfg.device_id)

    print_log('Setting up data...')
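    # Load the learned shape dictionary; the regressed shape codes are coefficients over
    # this basis (see the commented-out matmul reconstruction inside train()).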
    dictionary = np.load(cfg.dictionary_file)
    Dataset = COCOSEGMCMM if cfg.dataset == 'coco' else KINSSEGMCMM
    train_dataset = Dataset(cfg.data_dir,
                            cfg.dictionary_file,
                            'train',
                            split_ratio=cfg.split_ratio,
                            img_size=cfg.img_size)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=num_gpus, rank=cfg.local_rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size // num_gpus if cfg.dist else cfg.batch_size,
        shuffle=not cfg.dist,
        num_workers=cfg.num_workers,
        pin_memory=False,
        drop_last=True,
        sampler=train_sampler if cfg.dist else None)

    Dataset_eval = COCO_eval_segm_cmm if cfg.dataset == 'coco' else KINS_eval_segm_cmm
    val_dataset = Dataset_eval(cfg.data_dir,
                               cfg.dictionary_file,
                               'val',
                               test_scales=[1.],
                               test_flip=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=False,
                                             collate_fn=val_dataset.collate_fn)

    print_log('Creating model...')
    if 'hourglass' in cfg.arch:
        # model = get_hourglass[cfg.arch]
        model = exkp(n=5,
                     nstack=2,
                     dims=[256, 256, 384, 384, 384, 512],
                     modules=[2, 2, 2, 2, 2, 4],
                     dictionary=torch.from_numpy(dictionary.astype(
                         np.float32)).to(cfg.device))
    elif 'resdcn' in cfg.arch:
        model = get_pose_resdcn(num_layers=int(cfg.arch.split('_')[-1]),
                                head_conv=64,
                                num_classes=train_dataset.num_classes)
    else:
        raise NotImplementedError

    if cfg.dist:
        # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.to(cfg.device)
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[
                cfg.local_rank,
            ], output_device=cfg.local_rank)
    else:
        # Single-process mode: place DataParallel on the same device the model is moved to.
        model = nn.DataParallel(model,
                                device_ids=[cfg.device_id]).to(cfg.device)

    if cfg.pretrain_checkpoint is not None and os.path.isfile(
            cfg.pretrain_checkpoint):
        print_log('Loading pretrained model from ' + cfg.pretrain_checkpoint)
        model = load_model(model, cfg.pretrain_checkpoint, cfg.device_id)
        torch.cuda.empty_cache()

    optimizer = torch.optim.Adam(model.parameters(), cfg.lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.lr_step,
                                                        gamma=0.1)

    def train(epoch):
        print_log('\n Epoch: %d' % epoch)
        model.train()
        tic = time.perf_counter()
        for batch_idx, batch in enumerate(train_loader):
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=cfg.device,
                                           non_blocking=True)

            dict_tensor = torch.from_numpy(dictionary.astype(np.float32)).to(
                cfg.device, non_blocking=True)
            dict_tensor.requires_grad = False

            outputs = model(batch['image'])
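            # With the hourglass backbone, the model returns one
            # (hmap, regs, w_h_, codes_1, codes_2, codes_3, offsets) tuple per stack;
            # zip(*outputs) regroups them into per-head lists across stacks.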
            hmap, regs, w_h_, codes_1, codes_2, codes_3, offsets = zip(
                *outputs)

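            # Gather the dense head outputs at the ground-truth object-center indices,
            # so the regression losses below only see annotated objects.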
            regs = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in regs
            ]
            w_h_ = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
            ]
            c_1 = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in codes_1
            ]
            c_2 = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in codes_2
            ]
            c_3 = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in codes_3
            ]
            offsets = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in offsets
            ]

            # shapes_1 = [torch.matmul(c, dict_tensor) for c in c_1]
            # shapes_2 = [torch.matmul(c, dict_tensor) for c in c_2]
            # shapes_3 = [torch.matmul(c, dict_tensor) for c in c_3]

            hmap_loss = _neg_loss(hmap, batch['hmap'])
            # occ_loss = _neg_loss(occ_map, batch['occ_map'], ex=4.0)
            reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
            w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])
            offsets_loss = _reg_loss(offsets, batch['offsets'],
                                     batch['ind_masks'])
            codes_loss = (
                norm_reg_loss(c_1, batch['codes'], batch['ind_masks']) +
                norm_reg_loss(c_2, batch['codes'], batch['ind_masks']) +
                norm_reg_loss(c_3, batch['codes'], batch['ind_masks'])) / 3.

            # cmm_loss = (contour_mapping_loss(c_1, shapes_1, batch['shapes'], batch['ind_masks'], roll=False)
            #             + contour_mapping_loss(c_2, shapes_2, batch['shapes'], batch['ind_masks'], roll=False)
            #             + contour_mapping_loss(c_3, shapes_3, batch['shapes'], batch['ind_masks'], roll=False)) / 3.
            # cmm_loss = (_reg_loss(shapes_1, batch['shapes'], batch['ind_masks'])
            #             + _reg_loss(shapes_2, batch['shapes'], batch['ind_masks'])
            #             + _reg_loss(shapes_3, batch['shapes'], batch['ind_masks'])) / 3.

            # loss = 1. * hmap_loss + 1 * reg_loss + 0.1 * w_h_loss + cfg.cmm_loss_weight * cmm_loss \
            #        + cfg.code_loss_weight * codes_loss + 0.1 * offsets_loss
            loss = 1 * hmap_loss + 1 * reg_loss + 0.1 * w_h_loss \
                   + cfg.code_loss_weight * codes_loss + 0.1 * offsets_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % cfg.log_interval == 0:
                duration = time.perf_counter() - tic
                tic = time.perf_counter()
                print_log(
                    '[%d/%d-%d/%d] ' %
                    (epoch, cfg.num_epochs, batch_idx, len(train_loader)) +
                    'Loss: hmap = %.3f reg = %.3f w_h = %.3f code = %.3f offsets = %.3f'
                    % (hmap_loss.item(), reg_loss.item(), w_h_loss.item(),
                       codes_loss.item(), offsets_loss.item()) +
                    ' (%d samples/sec)' %
                    (cfg.batch_size * cfg.log_interval / duration))

                step = len(train_loader) * epoch + batch_idx
                summary_writer.add_scalar('hmap_loss', hmap_loss.item(), step)
                # summary_writer.add_scalar('occ_loss', occ_loss.item(), step)
                summary_writer.add_scalar('reg_loss', reg_loss.item(), step)
                summary_writer.add_scalar('w_h_loss', w_h_loss.item(), step)
                summary_writer.add_scalar('offset_loss', offsets_loss.item(),
                                          step)
                summary_writer.add_scalar('code_loss', codes_loss.item(), step)
                # summary_writer.add_scalar('cmm_loss', cmm_loss.item(), step)
        return

    def val_map(epoch):
        print_log('\n Val@Epoch: %d' % epoch)
        model.eval()
        torch.cuda.empty_cache()
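        # Keep at most `max_per_image` detections per image (the usual COCO evaluation limit).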
        max_per_image = 100

        results = {}
        input_scales = {}
        speed_list = []
        with torch.no_grad():
            for inputs in val_loader:
                img_id, inputs = inputs[0]
                start_image_time = time.time()
                segmentations = []
                for scale in inputs:
                    inputs[scale]['image'] = inputs[scale]['image'].to(
                        cfg.device)
                    # if scale == 1. and img_id not in input_scales.keys():  # keep track of the input image Sizes
                    #     _, _, input_h, input_w = inputs[scale]['image'].shape
                    #     input_scales[img_id] = {'h': input_h, 'w': input_w}

                    # dict_tensor = torch.from_numpy(dictionary.astype(np.float32)).to(cfg.device, non_blocking=True)
                    # dict_tensor.requires_grad = False
                    hmap, regs, w_h_, _, _, codes, offsets = model(
                        inputs[scale]['image'])[-1]
                    # hmap, regs, w_h_, _, _, codes = model(inputs[scale]['image'])[-1]
                    output = [hmap, regs, w_h_, codes, offsets]

                    segms = ctsegm_inmodal_code_decode(
                        *output,
                        torch.from_numpy(dictionary.astype(np.float32)).to(
                            cfg.device),
                        K=cfg.test_topk)
                    # segms = ctsegm_shift_code_decode(*output,
                    #                                  torch.from_numpy(dictionary.astype(np.float32)).to(cfg.device),
                    #                                  K=cfg.test_topk)
                    segms = segms.detach().cpu().numpy().reshape(
                        1, -1, segms.shape[2])[0]

                    top_preds = {}
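                    # Map predicted vertices and box corners from feature-map coordinates
                    # back to the original image coordinates.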
                    for j in range(cfg.n_vertices):
                        segms[:, 2 * j:2 * j + 2] = transform_preds(
                            segms[:, 2 * j:2 * j + 2], inputs[scale]['center'],
                            inputs[scale]['scale'],
                            (inputs[scale]['fmap_w'], inputs[scale]['fmap_h']))
                    segms[:, cfg.n_vertices * 2:cfg.n_vertices * 2 +
                          2] = transform_preds(
                              segms[:,
                                    cfg.n_vertices * 2:cfg.n_vertices * 2 + 2],
                              inputs[scale]['center'], inputs[scale]['scale'],
                              (inputs[scale]['fmap_w'],
                               inputs[scale]['fmap_h']))
                    segms[:, cfg.n_vertices * 2 + 2:cfg.n_vertices * 2 +
                          4] = transform_preds(
                              segms[:, cfg.n_vertices * 2 +
                                    2:cfg.n_vertices * 2 + 4],
                              inputs[scale]['center'], inputs[scale]['scale'],
                              (inputs[scale]['fmap_w'],
                               inputs[scale]['fmap_h']))

                    clses = segms[:, -1]
                    for j in range(val_dataset.num_classes):
                        inds = (clses == j)
                        top_preds[j + 1] = segms[inds, :cfg.n_vertices * 2 +
                                                 5].astype(np.float32)
                        top_preds[j + 1][:, :cfg.n_vertices * 2 + 4] /= scale

                    segmentations.append(top_preds)

                end_image_time = time.time()
                segms_and_scores = {
                    j: np.concatenate([d[j] for d in segmentations], axis=0)
                    for j in range(1, val_dataset.num_classes + 1)
                }
                scores = np.hstack([
                    segms_and_scores[j][:, cfg.n_vertices * 2 + 4]
                    for j in range(1, val_dataset.num_classes + 1)
                ])
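                # If more detections than max_per_image remain, keep only the highest-scoring ones.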
                if len(scores) > max_per_image:
                    kth = len(scores) - max_per_image
                    thresh = np.partition(scores, kth)[kth]
                    for j in range(1, val_dataset.num_classes + 1):
                        keep_inds = (
                            segms_and_scores[j][:, cfg.n_vertices * 2 + 4] >=
                            thresh)
                        segms_and_scores[j] = segms_and_scores[j][keep_inds]

                results[img_id] = segms_and_scores
                speed_list.append(end_image_time - start_image_time)

        eval_results = val_dataset.run_eval(results, save_dir=cfg.ckpt_dir)
        print_log(eval_results)
        summary_writer.add_scalar('val_mAP/mAP', eval_results[0], epoch)
        print_log('Average speed on val set: {:.2f} images/s'.format(
            1. / np.mean(speed_list)))

        return eval_results[0]

    print_log('Starting training...')
    for epoch in range(1, cfg.num_epochs + 1):
        start = time.time()
        train_sampler.set_epoch(epoch)
        train(epoch)
        if (cfg.val_interval > 0
                and epoch % cfg.val_interval == 0) or epoch == 2:
            stat = val_map(epoch)
            if stat > best_mAP:
                print_log('Overall mAP {:.3f} is improving ...'.format(stat))
                print_log(saver.save(model.module.state_dict(), 'checkpoint'))
                best_mAP = stat

        lr_scheduler.step()  # step the scheduler after optimizer.step(), as required since PyTorch 1.1.0

        epoch_time = (time.time() - start) / 3600. / 24.
        print_log('ETA:{:.2f} Days'.format(
            (cfg.num_epochs - epoch) * epoch_time))

    summary_writer.close()
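
Example #2

# NOTE (assumed): standard-library and third-party imports this example relies on.
# Project-specific helpers (cfg, exkp, load_demo_model, ctsegm_cmm_decode,
# get_affine_transform, transform_preds, resample, turning_angle_resample,
# check_clockwise_polygon, get_connected_polygon_using_mask, fast_ista, bb,
# COLOR_WORLD, RGB_DICT and the COCO_*/DETRAC_* constants) come from the
# surrounding repository and are not shown here.
import random
import time

import cv2
import numpy as np
import torch
from pycocotools.coco import COCO
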
def main():
    cfg.device = torch.device('cuda')
    torch.backends.cudnn.benchmark = False

    max_per_image = 100
    num_classes = 80 if cfg.dataset == 'coco' else 4
    dictionary = np.load(cfg.dictionary_file)

    colors = COCO_COLORS if cfg.dataset == 'coco' else DETRAC_COLORS
    names = COCO_NAMES if cfg.dataset == 'coco' else DETRAC_NAMES
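    # Convert the per-class colors to 0-255 tuples for OpenCV drawing.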
    for j in range(len(names)):
        col_ = [c * 255 for c in colors[j]]
        colors[j] = tuple(col_)

    print('Creating model and recovering from checkpoint ...')
    if 'hourglass' in cfg.arch:
        model = exkp(n=5,
                     nstack=2,
                     dims=[256, 256, 384, 384, 384, 512],
                     modules=[2, 2, 2, 2, 2, 4],
                     dictionary=torch.from_numpy(dictionary.astype(
                         np.float32)).to(cfg.device))
    else:
        raise NotImplementedError

    model = load_demo_model(model, cfg.ckpt_dir)
    model = model.to(cfg.device)
    model.eval()

    # Loading COCO validation images
    annotation_file = '{}/annotations/instances_{}.json'.format(
        cfg.data_dir, cfg.data_type)
    coco = COCO(annotation_file)

    # Load all annotations
    cats = coco.loadCats(coco.getCatIds())
    nms = [cat['name'] for cat in cats]
    catIds = coco.getCatIds(catNms=nms)
    # imgIds = coco.getImgIds(catIds=catIds)
    imgIds = coco.getImgIds()
    # annIds = coco.getAnnIds(catIds=catIds)
    # all_anns = coco.loadAnns(ids=annIds)
    # print(len(imgIds), imgIds)

    for img_id in imgIds:
        annt_ids = coco.getAnnIds(imgIds=[img_id])
        annotations_per_img = coco.loadAnns(ids=annt_ids)
        # print('All annots: ', len(annotations_per_img), annotations_per_img)
        img = coco.loadImgs(img_id)[0]
        image_path = '%s/images/%s/%s' % (cfg.data_dir, cfg.data_type,
                                          img['file_name'])
        w_img = int(img['width'])
        h_img = int(img['height'])
        if w_img < 1 or h_img < 1:
            continue

        img_original = cv2.imread(image_path)
        img_connect = cv2.imread(image_path)
        img_recon = cv2.imread(image_path)
        print('Image id: ', img_id)

        for annt in annotations_per_img:
            if annt['iscrowd'] == 1 or not isinstance(annt['segmentation'], list):
                continue

            polygons = get_connected_polygon_using_mask(
                annt['segmentation'], (h_img, w_img),
                n_vertices=cfg.num_vertices,
                closing_max_kernel=60)
            gt_bbox = annt['bbox']
            gt_x1, gt_y1, gt_w, gt_h = gt_bbox
            contour = np.array(polygons).reshape((-1, 2))

            # Resample the contour to a fixed number of vertices
            if len(contour) > cfg.num_vertices:
                resampled_contour = resample(contour, num=cfg.num_vertices)
            else:
                resampled_contour = turning_angle_resample(
                    contour, cfg.num_vertices)

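            # Clip the resampled vertices so they stay inside the ground-truth bounding box.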
            resampled_contour[:, 0] = np.clip(resampled_contour[:, 0], gt_x1,
                                              gt_x1 + gt_w)
            resampled_contour[:, 1] = np.clip(resampled_contour[:, 1], gt_y1,
                                              gt_y1 + gt_h)

            clockwise_flag = check_clockwise_polygon(resampled_contour)
            if not clockwise_flag:
                fixed_contour = np.flip(resampled_contour, axis=0)
            else:
                fixed_contour = resampled_contour.copy()

            # Re-index the contour to start at the left-most vertex (argmin over the x-axis)
            idx = np.argmin(fixed_contour[:, 0])
            indexed_shape = np.concatenate(
                (fixed_contour[idx:, :], fixed_contour[:idx, :]), axis=0)

            x1, y1, x2, y2 = gt_x1, gt_y1, gt_x1 + gt_w, gt_y1 + gt_h

            # bbox_width, bbox_height = x2 - x1, y2 - y1
            # bbox = [x1, y1, bbox_width, bbox_height]
            bbox_center = np.array([(x1 + x2) / 2., (y1 + y2) / 2.])
            # bbox_center = np.mean(indexed_shape, axis=0)

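            # Express the contour relative to the bbox center, the same centering used by the
            # (commented-out) sparse-code reconstruction below.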
            centered_shape = indexed_shape - bbox_center

            # Draw the original (possibly multi-part) GT polygons on img_original and the
            # resampled single contour on img_connect, for side-by-side visualization
            for cnt in range(len(annt['segmentation'])):
                polys = np.array(annt['segmentation'][cnt]).reshape((-1, 2))
                cv2.polylines(img_original, [polys.astype(np.int32)],
                              True, (10, 10, 255),
                              thickness=2)

            cv2.polylines(img_connect, [indexed_shape.astype(np.int32)],
                          True, (150, 10, 255),
                          thickness=2)

            # learned_val_codes, _ = fast_ista(centered_shape.reshape((1, -1)), dictionary,
            #                          lmbda=0.1, max_iter=60)
            # recon_contour = np.matmul(learned_val_codes, dictionary).reshape((-1, 2))
            # recon_contour = recon_contour + bbox_center
            # cv2.polylines(img_recon, [recon_contour.astype(np.int32)], True, (10, 255, 10), thickness=2)

            # plot gt mean and std
            # image = cv2.imread(image_path)
            # # cv2.ellipse(image, center=(int(contour_mean[0]), int(contour_mean[1])),
            # #             axes=(int(contour_std[0]), int(contour_std[1])),
            # #             angle=0, startAngle=0, endAngle=360, color=(0, 255, 0),
            # #             thickness=2)
            # cv2.rectangle(image, pt1=(int(contour_mean[0] - contour_std[0] / 2.), int(contour_mean[1] - contour_std[1] / 2.)),
            #               pt2=(int(contour_mean[0] + contour_std[0] / 2.), int(contour_mean[1] + contour_std[1] / 2.)),
            #               color=(0, 255, 0), thickness=2)
            # cv2.polylines(image, [fixed_contour.astype(np.int32)], True, (0, 0, 255))
            # cv2.rectangle(image, pt1=(int(min(fixed_contour[:, 0])), int(min(fixed_contour[:, 1]))),
            #               pt2=(int(max(fixed_contour[:, 0])), int(max(fixed_contour[:, 1]))),
            #               color=(255, 0, 0), thickness=2)
            # cv2.imshow('GT segments', image)
            # if cv2.waitKey() & 0xFF == ord('q'):
            #     break

        image = cv2.imread(image_path)
        original_image = image.copy()
        height, width = image.shape[0:2]
        padding = 127 if 'hourglass' in cfg.arch else 31
        imgs = {}
        for scale in cfg.test_scales:
            new_height = int(height * scale)
            new_width = int(width * scale)

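            # Two sizing modes: warp to a fixed square of side cfg.img_size, or pad each side
            # up to the next multiple of (padding + 1), i.e. 128 for hourglass and 32 otherwise.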
            if cfg.img_size > 0:
                img_height, img_width = cfg.img_size, cfg.img_size
                center = np.array([new_width / 2., new_height / 2.],
                                  dtype=np.float32)
                scaled_size = max(height, width) * 1.0
                scaled_size = np.array([scaled_size, scaled_size],
                                       dtype=np.float32)
            else:
                img_height = (new_height | padding) + 1
                img_width = (new_width | padding) + 1
                center = np.array([new_width // 2, new_height // 2],
                                  dtype=np.float32)
                scaled_size = np.array([img_width, img_height],
                                       dtype=np.float32)

            img = cv2.resize(image, (new_width, new_height))
            trans_img = get_affine_transform(center, scaled_size, 0,
                                             [img_width, img_height])
            img = cv2.warpAffine(img, trans_img, (img_width, img_height))

            img = img.astype(np.float32) / 255.
            img -= np.array(
                COCO_MEAN if cfg.dataset == 'coco' else DETRAC_MEAN,
                dtype=np.float32)[None, None, :]
            img /= np.array(COCO_STD if cfg.dataset == 'coco' else DETRAC_STD,
                            dtype=np.float32)[None, None, :]
            img = img.transpose(
                2, 0, 1)[None, :, :, :]  # from [H, W, C] to [1, C, H, W]

            # if cfg.test_flip:
            #     img = np.concatenate((img, img[:, :, :, ::-1].copy()), axis=0)

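            # Cache the preprocessed tensor plus the geometry needed to map stride-4
            # feature-map predictions back to the original image coordinates.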
            imgs[scale] = {
                'image': torch.from_numpy(img).float(),
                'center': np.array(center),
                'scale': np.array(scaled_size),
                'fmap_h': np.array(img_height // 4),
                'fmap_w': np.array(img_width // 4)
            }

        with torch.no_grad():
            segmentations = []
            predicted_codes = []
            start_time = time.time()
            for scale in imgs:
                imgs[scale]['image'] = imgs[scale]['image'].to(cfg.device)

                output = model(imgs[scale]['image'])[-1]
                segms = ctsegm_cmm_decode(*output, K=cfg.test_topk)
                segms = segms.detach().cpu().numpy().reshape(
                    1, -1, segms.shape[2])[0]

                top_preds = {}
                code_preds = {}
                for j in range(cfg.num_vertices):
                    segms[:, 2 * j:2 * j + 2] = transform_preds(
                        segms[:, 2 * j:2 * j + 2], imgs[scale]['center'],
                        imgs[scale]['scale'],
                        (imgs[scale]['fmap_w'], imgs[scale]['fmap_h']))
                segms[:, cfg.num_vertices * 2:cfg.num_vertices * 2 +
                      2] = transform_preds(
                          segms[:,
                                cfg.num_vertices * 2:cfg.num_vertices * 2 + 2],
                          imgs[scale]['center'], imgs[scale]['scale'],
                          (imgs[scale]['fmap_w'], imgs[scale]['fmap_h']))
                segms[:, cfg.num_vertices * 2 + 2:cfg.num_vertices * 2 +
                      4] = transform_preds(
                          segms[:, cfg.num_vertices * 2 +
                                2:cfg.num_vertices * 2 + 4],
                          imgs[scale]['center'], imgs[scale]['scale'],
                          (imgs[scale]['fmap_w'], imgs[scale]['fmap_h']))

                clses = segms[:, -1]
                for j in range(num_classes):
                    inds = (clses == j)
                    top_preds[j + 1] = segms[inds, :cfg.num_vertices * 2 +
                                             5].astype(np.float32)
                    top_preds[j + 1][:, :cfg.num_vertices * 2 + 4] /= scale
                    # code_preds[j + 1] = codes_[inds, :]

                segmentations.append(top_preds)
                predicted_codes.append(code_preds)

            segms_and_scores = {
                j: np.concatenate([d[j] for d in segmentations], axis=0)
                for j in range(1, num_classes + 1)
            }  # a Dict label: segments
            # codes_and_scores = {j: np.concatenate([d[j] for d in predicted_codes], axis=0)
            #                     for j in range(1, num_classes + 1)}  # a Dict label: segments
            scores = np.hstack([
                segms_and_scores[j][:, cfg.num_vertices * 2 + 4]
                for j in range(1, num_classes + 1)
            ])

            if len(scores) > max_per_image:
                kth = len(scores) - max_per_image
                thresh = np.partition(scores, kth)[kth]
                for j in range(1, num_classes + 1):
                    keep_inds = (segms_and_scores[j][:, cfg.num_vertices * 2 +
                                                     4] >= thresh)
                    segms_and_scores[j] = segms_and_scores[j][keep_inds]
                    # codes_and_scores[j] = codes_and_scores[j][keep_inds]

            # Draw the surviving predictions on a copy of the original image with OpenCV
            output_image = original_image

            for lab in segms_and_scores:
                for idx in range(len(segms_and_scores[lab])):
                    res = segms_and_scores[lab][idx]
                    # c_ = codes_and_scores[lab][idx]
                    # for res in segms_and_scores[lab]:
                    contour, bbox, score = res[:-5], res[-5:-1], res[-1]
                    bbox[0] = np.clip(bbox[0], 0, w_img)
                    bbox[1] = np.clip(bbox[1], 0, h_img)
                    bbox[2] = np.clip(bbox[2], 0, w_img)
                    bbox[3] = np.clip(bbox[3], 0, h_img)
                    if score > cfg.detect_thres:
                        text = names[lab]  # + ' %.2f' % score
                        # label_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_COMPLEX, thickness=2, fontScale=0.5)
                        polygon = contour.reshape((-1, 2))

                        # use bb tools to draw predictions
                        color = random.choice(COLOR_WORLD)
                        bb.add(output_image, bbox[0], bbox[1], bbox[2],
                               bbox[3], text, color)
                        cv2.polylines(output_image, [polygon.astype(np.int32)],
                                      True,
                                      RGB_DICT[color],
                                      thickness=2)

                        # color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
                        # contour_mean = np.mean(polygon, axis=0)
                        # contour_std = np.std(polygon, axis=0)
                        # center_x, center_y = np.mean(polygon, axis=0).astype(np.int32)
                        # text_location = [bbox[0] + 1, bbox[1] + 1,
                        #                  bbox[1] + label_size[0][0] + 1,
                        #                  bbox[0] + label_size[0][1] + 1]
                        # cv2.rectangle(output_image, pt1=(int(bbox[0]), int(bbox[1])),
                        #               pt2=(int(bbox[2]), int(bbox[3])),
                        #               color=color, thickness=1)
                        # cv2.rectangle(output_image, pt1=(int(np.min(polygon[:, 0])), int(np.min(polygon[:, 1]))),
                        #               pt2=(int(np.max(polygon[:, 0])), int(np.max(polygon[:, 1]))),
                        #               color=(0, 255, 0), thickness=1)
                        # cv2.polylines(output_image, [polygon.astype(np.int32)], True, color, thickness=2)
                        # cv2.putText(output_image, text, org=(int(text_location[0]), int(text_location[3])),
                        #             fontFace=cv2.FONT_HERSHEY_COMPLEX, thickness=2, fontScale=0.5,
                        #             color=(255, 0, 0))
                        # cv2.putText(output_image, text, org=(int(bbox[0]), int(bbox[1])),
                        #             fontFace=cv2.FONT_HERSHEY_COMPLEX, thickness=1, fontScale=0.5,
                        #             color=color)

                        # show the histogram of the predicted codes
                        # fig = plt.figure()
                        # plt.plot(np.arange(cfg.n_codes), c_.reshape((-1,)), color='green',
                        #          marker='o', linestyle='dashed', linewidth=2, markersize=6)
                        # plt.ylabel('Value of each coefficient')
                        # plt.xlabel('All predicted {} coefficients'.format(cfg.n_codes))
                        # plt.title('Distribution of the predicted coefficients for {}'.format(text))
                        # plt.show()

            value = [255, 255, 255]
            img_original = cv2.copyMakeBorder(img_original, 0, 0, 0, 10,
                                              cv2.BORDER_CONSTANT, None, value)
            img_connect = cv2.copyMakeBorder(img_connect, 0, 0, 10, 10,
                                             cv2.BORDER_CONSTANT, None, value)
            output_image = cv2.copyMakeBorder(output_image, 0, 0, 10, 0,
                                              cv2.BORDER_CONSTANT, None, value)
            im_cat = np.concatenate((img_original, img_connect, output_image),
                                    axis=1)
            cv2.imshow('GT:Resample:Recons:Predict', im_cat)
            if cv2.waitKey() & 0xFF == ord('q'):
                break