コード例 #1
0
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
    """Multi-scale inference worker that sends per-sample metrics to a queue.

    For every sample, the ``'material'`` scores of each image in
    ``img_resized_list`` are accumulated on the CPU, each divided by
    ``len(args.imgSize)``. The per-pixel argmax is then converted to numpy.
    The result of ``get_metrics(prediction, sample)`` is pushed with
    ``result_queue.put_nowait``.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, seg_size=...)``; its output is
            indexed by task key.
        loader: iterable yielding one-element lists of samples.
        args: needs ``nr_classes`` (dict keyed by task) and ``imgSize``.
        dev_id: device passed to ``async_copy_to``.
        result_queue: queue receiving one metrics object per sample.
    """
    task_keys = ('material',)
    segmentation_module.eval()

    for sample_list in loader:
        # TODO(LYC):: support batch size > 1
        sample = sample_list[0]
        sample_np = as_numpy(sample)
        seg_size = sample_np['seg_object'].shape[0:2]

        with torch.no_grad():
            accumulated = {
                key: torch.zeros(1, args.nr_classes[key], *seg_size)
                for key in task_keys
            }

            for scaled_img in sample['img_resized_list']:
                inputs = async_copy_to({"img": scaled_img.unsqueeze(0)}, dev_id)
                outputs = segmentation_module(inputs, seg_size=seg_size)
                for key in task_keys:
                    share = outputs[key].cpu() / len(args.imgSize)
                    accumulated[key] = accumulated[key] + share

            labels = {}
            for key in task_keys:
                best = torch.max(accumulated[key].cpu(), dim=1)[1]
                labels[key] = best.squeeze(0)
            labels = as_numpy(labels)

        # Compute the metrics here and hand them to the master process.
        result_queue.put_nowait(get_metrics(labels, sample_np))
コード例 #2
0
def test(segmentation_module, loader, args):
    """Run multi-scale inference on each test image and store the result.

    Scores from every resized copy in ``batch_data['img_data']`` are
    accumulated on the GPU, each divided by ``len(args.imgSize)``. The
    per-pixel argmax is taken and the label map is passed to
    ``precompute_result``. It is also visualized when ``args.visualize``
    is set.
    """
    segmentation_module.eval()

    for wrapped in loader:
        sample = wrapped[0]
        original = as_numpy(sample['img_ori'])
        scaled_images = sample['img_data']

        with torch.no_grad():
            seg_size = (original.shape[0], original.shape[1])
            scores = Variable(
                torch.zeros(1, args.num_class, seg_size[0], seg_size[1])).cuda()

            for scaled in scaled_images:
                # Forward everything except the raw image and its metadata.
                inputs = dict(sample)
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, args.gpu_id)

                logits = segmentation_module(inputs, segSize=seg_size)
                scores = scores + logits / len(args.imgSize)

            labels = torch.max(scores.data.cpu(), dim=1)[1]
            labels = as_numpy(labels.squeeze(0))

        precompute_result(sample['info'], labels, args)
        if args.visualize:
            visualize_result((sample['img_ori'], sample['info']), labels, args)
コード例 #3
0
def test(segmentation_module, loader, args):
    """Segment each image at multiple scales and visualize label + confidence.

    Scores of all resized copies in ``batch_data['img_data']`` are summed on
    the CPU, each divided by ``len(args.imgSize)``. The per-pixel max score
    and its class index are both passed to ``visualize_result``. A ``tqdm``
    bar tracks progress.
    """
    segmentation_module.eval()

    progress = tqdm(total=len(loader))
    for wrapped in loader:
        sample = wrapped[0]
        original = sample['img_ori']
        seg_size = (original.shape[0], original.shape[1])
        scaled_images = sample['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, args.num_class, seg_size[0], seg_size[1])

            for scaled in scaled_images:
                inputs = dict(sample)
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')

                logits = segmentation_module(inputs, segSize=seg_size)
                scores += (logits.cpu() / len(args.imgSize))

            best_prob, best_label = torch.max(scores, dim=1)
            label_map = as_numpy(best_label.squeeze(0).cpu())
            prob_map = as_numpy(best_prob.squeeze(0).cpu())

        visualize_result((sample['img_ori'], sample['info']), label_map, prob_map, args)

        progress.update(1)
コード例 #4
0
ファイル: vkitti_eval.py プロジェクト: ysymyth/3D-SDN
def evaluate(segmentation_module, loader, args):
    """Evaluate multi-scale predictions against ground-truth labels.

    For each sample, the model scores are averaged over all resized copies
    in ``batch_data['img_data']`` (accumulated on the GPU) and the per-pixel
    argmax is taken. Pixel accuracy and per-class intersection/union are
    accumulated. Results are visualized when ``args.visualize`` is set and
    precomputed when ``args.precompute`` is set. A timestamped accuracy line
    is printed per iteration, followed by per-class IoU and a summary.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    segmentation_module.eval()

    for step, wrapped in enumerate(loader):
        sample = wrapped[0]
        label = as_numpy(sample['seg_label'][0])
        scaled_images = sample['img_data']

        with torch.no_grad():
            seg_size = (label.shape[0], label.shape[1])
            scores = Variable(
                torch.zeros(1, args.num_class, seg_size[0], seg_size[1])).cuda()

            for scaled in scaled_images:
                inputs = dict(sample)
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, args.gpu_id)

                logits = segmentation_module(inputs, segSize=seg_size)
                scores = scores + logits / len(args.imgSize)

            preds = as_numpy(torch.max(scores.data.cpu(), dim=1)[1].squeeze(0))

        acc, pix = accuracy(preds, label)
        intersection, union = intersectionAndUnion(preds, label, args.num_class)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('[{}] iter {}, accuracy: {}'.format(stamp, step, acc))

        if args.visualize:
            visualize_result((sample['img_ori'], label, sample['info']), preds, args)
        if args.precompute:
            precompute_result(sample['info'], preds, args)

    # The epsilon avoids division by zero for classes that never occur.
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for cls_idx, cls_iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(cls_idx, cls_iou))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        iou.mean(), acc_meter.average() * 100))
コード例 #5
0
def get_seg_model_loss(segmentation_module, gpu, content_img):
    """Build a ``SegLoss`` whose target is the module's output on ``content_img``.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=...)``.
        gpu: device passed to ``async_copy_to``.
        content_img: image tensor with a leading batch dimension of 1
            (it is ``squeeze(0)``-ed to derive the output size).

    Returns:
        ``SegLoss(target_seg)`` for the segmentation of a clone of
        ``content_img``.
    """
    # Output size = first two dims after dropping the batch dim. Read the
    # tensor shape directly instead of copying the image to host memory
    # (twice) just to inspect its shape.
    # NOTE(review): for an NCHW image (1, C, H, W) this yields (C, H) rather
    # than (H, W) -- confirm the layout content_img actually uses.
    squeezed_shape = content_img.squeeze(0).shape
    segSize = (squeezed_shape[0], squeezed_shape[1])

    # Clone so the segmentation target does not alias the caller's tensor.
    feed_dict = {'img_data': content_img.clone()}
    feed_dict = async_copy_to(feed_dict, gpu)
    target_seg = segmentation_module(feed_dict, segSize=segSize)
    return SegLoss(target_seg)
コード例 #6
0
def evaluate(segmentation_module, loader, cfg, gpu, model_name,
             paper_arxiv_id):
    """Run multi-scale evaluation and submit predictions to an ADE20KEvaluator.

    For each sample, scores over all resized copies are averaged on ``gpu``
    and the per-pixel argmax is compared with the label. Local meters
    (accuracy, IoU, wall time) are updated and the flattened
    prediction/label pair is added to the evaluator. The loop stops early
    once ``evaluator.cache_exists`` is true. ``evaluator.save()`` is called
    at the end.

    NOTE(review): the local meters are never reported or returned.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()
    evaluator = ADE20KEvaluator(model_name=model_name,
                                paper_arxiv_id=paper_arxiv_id)

    segmentation_module.eval()

    progress = tqdm(total=len(loader))
    for wrapped in loader:
        sample = wrapped[0]
        label = as_numpy(sample['seg_label'][0])
        scaled_images = sample['img_data']

        # Synchronize so the timer measures only this sample's GPU work.
        torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            seg_size = (label.shape[0], label.shape[1])
            scores = async_copy_to(
                torch.zeros(1, cfg.DATASET.num_class, seg_size[0], seg_size[1]),
                gpu)

            for scaled in scaled_images:
                inputs = dict(sample)
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, gpu)

                logits = segmentation_module(inputs, segSize=seg_size)
                scores = scores + logits / len(cfg.DATASET.imgSizes)

            pred = as_numpy(torch.max(scores, dim=1)[1].squeeze(0).cpu())

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - start)

        acc, pix = accuracy(pred, label)
        intersection, union = intersectionAndUnion(pred, label,
                                                   cfg.DATASET.num_class)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)

        evaluator.add(outputs=pred.flatten(), targets=label.flatten())
        if evaluator.cache_exists:
            break

        progress.update(1)
    evaluator.save()
コード例 #7
0
        def closure():
            """Compute the weighted style/content(/segmentation) loss and backprop it.

            Clamps ``input_img`` into [0, 1], runs ``model`` to refresh the
            ``.loss`` attributes of the style/content loss modules, and adds a
            segmentation term when ``seg_weight != 0``. Calls ``backward``
            with ``retain_graph=True`` and records the per-term losses in
            ``loss_vs_run``. Every 50 runs it prints the losses and saves the
            current image.

            Returns:
                The total loss tensor.
            """
            # Keep the optimised image inside the valid [0, 1] pixel range.
            input_img.data.clamp_(0, 1)

            optimizer.zero_grad()
            # The forward pass populates ``.loss`` on the hooked loss modules.
            model(input_img)
            style_score = 0
            content_score = 0
            seg_score = 0

            for sl in style_losses:
                style_score += sl.loss
            for cl in content_losses:
                content_score += cl.loss
            style_score *= style_weight
            content_score *= content_weight

            if seg_weight != 0:
                # Size = first two dims after dropping the batch dim, read from
                # the tensor shape instead of copying the image to host memory
                # twice on every optimizer step.
                # NOTE(review): for an NCHW image this is (C, H), not (H, W);
                # confirm the layout the segmentation module expects.
                img_shape = input_img.squeeze(0).shape
                segSize = (img_shape[0], img_shape[1])
                feed_dict = {'img_data': input_img}
                feed_dict = async_copy_to(feed_dict, gpu)
                input_seg = segmentation_module(feed_dict, segSize=segSize)
                seg_score = seg_loss.forward(input_seg)
                seg_score *= seg_weight

            # seg_score stays 0 when the segmentation term is disabled.
            loss = style_score + content_score + seg_score

            loss.backward(retain_graph=True)

            loss_vs_run['style'].append(style_score.item())
            loss_vs_run['content'].append(content_score.item())
            if seg_weight != 0:
                loss_vs_run['segmentation'].append(seg_score.item())

            if run[0] % 50 == 0:
                # ``run`` is a one-element list used as a mutable counter;
                # print the count itself rather than the list.
                print("run {}:".format(run[0]))
                if seg_weight != 0:
                    print(
                        'Style Loss : {:4f} Content Loss: {:4f} Segmentation Loss: {:4f}'
                        .format(style_score.item(), content_score.item(),
                                seg_score.item()))
                else:
                    print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                        style_score.item(), content_score.item()))

                print()
                plt.clf()
                imshow(input_img, title='Output Image')
                plt.savefig(img_savepath +
                            'transferred/%d.png' % int(run[0] / 10))

            run[0] += 1

            return loss
コード例 #8
0
def evaluate(segmentation_module, loader_val, args):
    """Evaluate on ``loader_val``, visualize every sample, and print a summary.

    Each sample's ``"image"`` gets a batch dimension and is moved to the GPU.
    The module is run with ``epoch=0``, and its score output is reduced with
    a per-pixel argmax. The argmax is compared against ``batch_data["mask"][0]``.
    Prints per-class IoU, mean IoU, pixel-weighted accuracy and the average
    per-sample inference time (forward pass + argmax, GPU-synchronized).
    """
    import time

    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()
    pbar = tqdm(total=len(loader_val))
    for batch_data in loader_val:
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data["mask"][0])
        batch_data["image"] = batch_data["image"].unsqueeze(0).cuda()

        # Synchronize before and after so the timer covers only this
        # sample's forward pass (previously time_meter was never updated and
        # the printed inference time was meaningless).
        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            feed_dict = batch_data.copy()

            # forward pass; the module's returned loss is not needed here.
            scores, edge, att, _ = segmentation_module(feed_dict, epoch=0, segSize=segSize)
            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())
        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label, args.num_class)
        intersection_meter.update(intersection)
        union_meter.update(union)
        # Weight by the labelled-pixel count, like the other evaluation loops.
        acc_meter.update(acc, pix)

        # NOTE(review): visualization always runs here; args.visualize is not
        # consulted (it was hard-wired to True).
        visualize_result(
            (batch_data['image'], seg_label, batch_data["name"]),
            pred, edge, att, args)

        pbar.update(1)

    # summary (epsilon avoids division by zero for absent classes)
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'
          .format(iou.mean(), acc_meter.average()*100, time_meter.average()))
def _hflip_nchw(x):
    """Reverse the last axis of a 4-D array or tensor (horizontal flip).

    torch tensors do not support negative-step slicing, so they go through
    ``torch.flip``. Anything else (e.g. numpy arrays) is sliced with ``::-1``.
    """
    if torch.is_tensor(x):
        return torch.flip(x, dims=[3])
    return x[:, :, :, ::-1]


def evaluate(segmentation_module, loader, args):
    """Evaluate with horizontal-flip test-time augmentation, saving predictions.

    For every sample, the model is run on the input and on its horizontally
    flipped copy. The flipped scores are flipped back and summed with the
    first pass, and the per-pixel argmax is taken. The prediction and label
    are cropped to the size of the original image listed in
    ``args.val_list_file`` (path relative to ``args.root_dataset``). The
    prediction is saved as an 8-bit PNG under ``./tmp/`` and accuracy/IoU are
    accumulated with 255 passed as the ignore label.

    Returns:
        ``(iou, iou.mean())``: per-class IoU array and its mean.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    eval_model = segmentation_module
    eval_model.eval()

    # One line per validation sample; the first space-separated field is the
    # image path. ``with`` closes the file (the Python-2 ``file()`` call did not).
    with open(args.val_list_file) as list_file:
        f = list_file.readlines()

    for i in range(len(loader)):
        batch_data = next(loader)[0]

        # process data
        seg_label = as_numpy(batch_data['seg_label'])

        with torch.no_grad():
            segSize = (seg_label.shape[1], seg_label.shape[2])

            # forward pass on the input and on its horizontal flip
            pred = eval_model(batch_data, segSize=segSize)
            batch_data['data'] = _hflip_nchw(batch_data['data'])
            pred += _hflip_nchw(eval_model(batch_data, segSize=segSize))
            _, preds = torch.max(pred.data.cpu(), dim=1)
            preds = as_numpy(preds.squeeze(0))

        # Per-sample post-processing and metrics. These previously sat outside
        # the loop, so only the final batch was ever saved and scored.
        rel_path = f[i].strip().split(' ')[0]
        img = misc.imread(args.root_dataset + '/' + rel_path)
        # Crop the (presumably padded) prediction and label to the image size.
        preds = preds[:img.shape[0], :img.shape[1]]
        seg_label = seg_label[:, :img.shape[0], :img.shape[1]]
        misc.imsave('./tmp/' + rel_path.split('/')[-1].replace('.jpg', '.png'),
                    preds.astype(np.uint8))

        # calculate accuracy
        acc, pix = accuracy(preds, seg_label, 255)
        intersection, union = intersectionAndUnion(preds, seg_label, args.num_class, 255)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)
        # Mean IoU over the classes present in this sample (union != 0).
        mean_iou = (intersection / (union + 1e-10))[union != 0].mean()
        print('[{}] iter {}, accuracy: {:.5f}, mIoU: {:.5f}'
              .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i, acc, mean_iou))

    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('[{}] class [{}], IoU: {}'.format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i, _iou))

    print('[{}] [Eval Summary]:'.format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'
          .format(iou.mean(), acc_meter.average()*100))

    return iou, iou.mean()
def evaluate(segmentation_module, loader, args):
    """Evaluate single-scale predictions, passing 255 as the ignore label.

    ``loader`` is consumed with ``next()`` exactly ``len(loader)`` times. For
    each sample, the per-pixel argmax of the model output is compared with
    ``batch_data['seg_label']``. A timestamped accuracy/mIoU line is printed
    per iteration, followed by per-class IoU and a summary.

    Returns:
        ``(iou, iou.mean())``: per-class IoU array and its mean.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    # With several GPUs the module is presumably a DataParallel-style wrapper;
    # evaluate the wrapped model directly.
    if args.num_gpus > 1:
        eval_model = segmentation_module.module
    else:
        eval_model = segmentation_module
    eval_model.eval()

    for i in range(len(loader)):
        batch_data = next(loader)[0]

        # process data
        seg_label = as_numpy(batch_data['seg_label'])

        with torch.no_grad():
            segSize = (seg_label.shape[1], seg_label.shape[2])

            # forward pass (the model output is used directly; no
            # pre-allocated score buffer is needed)
            pred = eval_model(batch_data, segSize=segSize)
            _, preds = torch.max(pred.data.cpu(), dim=1)
            preds = as_numpy(preds.squeeze(0))

        # calculate accuracy
        acc, pix = accuracy(preds, seg_label, 255)
        intersection, union = intersectionAndUnion(preds, seg_label,
                                                   args.num_class, 255)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)
        # Mean IoU over the classes present in this sample (union != 0).
        mean_iou = (intersection / (union + 1e-10))[union != 0].mean()
        print('[{}] iter {}, accuracy: {:.5f}, mIoU: {:.5f}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i, acc,
            mean_iou))

    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('[{}] class [{}], IoU: {}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i, _iou))

    print('[{}] [Eval Summary]:'.format(
        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        iou.mean(),
        acc_meter.average() * 100))

    return iou, iou.mean()
コード例 #11
0
def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue,
             crop_size=(520, 520), overlap=1.0 / 3.0):
    """Sliding-window, multi-scale evaluation worker reporting to a queue.

    Each resized copy in ``batch_data['img_data']`` is scored with
    ``predict_sliding`` and bilinearly resized to the label size. The scores
    are averaged over ``len(cfg.DATASET.imgSizes)`` and reduced with a
    per-pixel argmax. ``(acc, pix, intersection, union)`` is then put on
    ``result_queue``. When ``cfg.VAL.visualize`` is set, results are written
    under ``os.path.join(cfg.DIR, 'result')``.

    Args:
        crop_size: sliding-window size forwarded to ``predict_sliding``
            (default: the previously hard-coded ``(520, 520)``).
        overlap: window overlap fraction forwarded to ``predict_sliding``
            (default: the previously hard-coded ``1/3``).
    """
    segmentation_module.eval()

    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0],
                                 segSize[1])
            scores = async_copy_to(scores, gpu_id)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu_id)

                # forward pass over overlapping windows, then resize the
                # scores back to the label resolution
                scores_tmp = predict_sliding(segmentation_module,
                                             feed_dict, crop_size,
                                             cfg.DATASET.num_class,
                                             overlap=overlap)
                scores_tmp = nn.functional.interpolate(scores_tmp,
                                                       size=segSize,
                                                       mode='bilinear',
                                                       align_corners=False)
                scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        # calculate accuracy and SEND THEM TO MASTER
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label,
                                                   cfg.DATASET.num_class)
        result_queue.put_nowait((acc, pix, intersection, union))

        # visualization
        if cfg.VAL.visualize:
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']), pred,
                os.path.join(cfg.DIR, 'result'))
コード例 #12
0
def evaluate(sm, loader_val, args):
    """Run single-scale inference over ``loader_val``, timing each forward pass.

    ``batch_data["image"]`` gets a batch dimension and is moved to the GPU.
    ``sm`` is called with ``epoch=0, segSize=True`` and its output is reduced
    with a per-pixel argmax. The elapsed time is recorded in a local meter
    (NOTE(review): the meter is never reported). When ``args.visualize`` is
    set, the prediction is visualized together with ``batch_data['orig']``.
    """
    time_meter = AverageMeter()

    sm.eval()

    pbar = tqdm(total=len(loader_val))
    for batch_data in loader_val:
        batch_data = batch_data[0]
        batch_data["image"] = batch_data["image"].unsqueeze(0).cuda()
        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            feed_dict = batch_data.copy()

            # forward pass
            p1 = sm(feed_dict, epoch=0, segSize=True)

            _, pred = torch.max(p1, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

            time_meter.update(time.perf_counter() - tic)
        if args.visualize:
            # Visualize the argmax prediction computed above (the previous
            # code passed an undefined name and raised NameError here).
            visualize_result(
                    batch_data['orig'],
                    pred, args)

        torch.cuda.synchronize()

        pbar.update(1)
コード例 #13
0
def test(segmentation_module, loader, gpu):
    """Segment every image from ``loader`` at multiple scales and visualize it.

    Scores of all resized copies are averaged on ``gpu`` over
    ``len(cfg.DATASET.imgSizes)`` and reduced with a per-pixel argmax. The
    result is passed to ``visualize_result``. Reads the module-level ``cfg``.
    """
    segmentation_module.eval()

    progress = tqdm(total=len(loader))
    for wrapped in loader:
        sample = wrapped[0]
        original = sample['img_ori']
        seg_size = (original.shape[0], original.shape[1])
        scaled_images = sample['img_data']

        with torch.no_grad():
            scores = async_copy_to(
                torch.zeros(1, cfg.DATASET.num_class, seg_size[0], seg_size[1]),
                gpu)

            for scaled in scaled_images:
                inputs = dict(sample)
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, gpu)

                logits = segmentation_module(inputs, segSize=seg_size)
                scores = scores + logits / len(cfg.DATASET.imgSizes)

            pred = as_numpy(torch.max(scores, dim=1)[1].squeeze(0).cpu())

        visualize_result((sample['img_ori'], sample['info']), pred, cfg)

        progress.update(1)
コード例 #14
0
def test(segmentation_module, image_path, gpu):
    """Segment the single image at ``image_path`` and visualize the result.

    The image is loaded with ``load_image``. Scores of all its resized copies
    are summed on ``gpu`` and reduced with a per-pixel argmax. The label map
    is passed to ``visualize_result``. Reads the module-level ``cfg``.
    """
    segmentation_module.eval()

    sample = load_image(image_path)
    original = sample['img_ori']
    seg_size = (original.shape[0], original.shape[1])
    scaled_images = sample['img_data']

    with torch.no_grad():
        scores = async_copy_to(
            torch.zeros(1, cfg.DATASET.num_class, seg_size[0], seg_size[1]),
            gpu)

        for scaled in scaled_images:
            inputs = dict(sample)
            inputs['img_data'] = scaled
            inputs.pop('img_ori')
            inputs.pop('info')
            inputs = async_copy_to(inputs, gpu)
            # Scores are summed rather than averaged over scales; the argmax
            # below is unaffected by a constant scale factor.
            scores = scores + segmentation_module(inputs, segSize=seg_size)

        pred = as_numpy(torch.max(scores, dim=1)[1].squeeze(0).cpu())

    visualize_result((sample['img_ori'], sample['info']), pred, cfg)
コード例 #15
0
def test(segmentation_module, loader, args):
    """Segment every image from ``loader`` at multiple scales and visualize it.

    Scores are accumulated on the CPU, each scale's output divided by
    ``len(args.imgSize)``, then reduced with a per-pixel argmax. The label
    map goes to ``visualize_result`` and a timestamped progress line is
    printed.
    """
    segmentation_module.eval()

    for step, wrapped in enumerate(loader):
        sample = wrapped[0]
        height = sample['img_ori'].shape[0]
        width = sample['img_ori'].shape[1]
        seg_size = (height, width)
        scaled_images = sample['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, args.num_class, height, width)

            for scaled in scaled_images:
                inputs = dict(sample)
                inputs['img_data'] = scaled
                del inputs['img_ori'], inputs['info']
                inputs = async_copy_to(inputs, args.gpu_id)

                logits = segmentation_module(inputs, segSize=seg_size)
                scores = scores + logits.cpu() / len(args.imgSize)

            labels = as_numpy(torch.max(scores, dim=1)[1].squeeze(0))

        visualize_result((sample['img_ori'], sample['info']), labels, args)

        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('[{}] iter {}'.format(stamp, step))
コード例 #16
0
def segment_this_img(f):
    """Segment the image file ``f`` with multi-scale inference.

    The image is read as RGB and its channel order reversed (i.e. RGB ->
    BGR, presumably the order the model expects -- confirm). For each short
    side in ``args.imgSize``, the image is resized with both sides rounded by
    ``round2nearest_multiple`` to ``args.padding_constant``. It is converted
    to a CHW float tensor, passed through ``transform`` and given a batch
    dimension. Scores are averaged over ``len(args.imgSize)`` on the CPU.

    Returns:
        numpy label map (per-pixel argmax) at the original image size.
    """
    img = imread(f, mode='RGB')
    # imread returned RGB; reversing the channel axis yields BGR.
    img = img[:, :, ::-1]
    ori_height, ori_width, _ = img.shape

    img_resized_list = []
    for this_short_size in args.imgSize:
        scale = this_short_size / float(min(ori_height, ori_width))
        target_height = int(ori_height * scale)
        target_width = int(ori_width * scale)
        target_height = round2nearest_multiple(target_height,
                                               args.padding_constant)
        target_width = round2nearest_multiple(target_width,
                                              args.padding_constant)
        img_resized = cv2.resize(img.copy(), (target_width, target_height))
        img_resized = img_resized.astype(np.float32)
        img_resized = img_resized.transpose((2, 0, 1))  # HWC -> CHW
        img_resized = transform(torch.from_numpy(img_resized))
        img_resized = torch.unsqueeze(img_resized, 0)
        img_resized_list.append(img_resized)

    segSize = (ori_height, ori_width)
    with torch.no_grad():
        pred = torch.zeros(1, args.num_class, segSize[0], segSize[1])
        for timg in img_resized_list:
            feed_dict = {'img_data': timg.cuda()}
            feed_dict = async_copy_to(feed_dict, args.gpu_id)
            # forward pass
            pred_tmp = segmentation_module(feed_dict, segSize=segSize)
            pred = pred + pred_tmp.cpu() / len(args.imgSize)
        _, preds = torch.max(pred, dim=1)
        preds = as_numpy(preds.squeeze(0))
    return preds
コード例 #17
0
ファイル: segment.py プロジェクト: ShirleyZYJ/Guidance-System
def callback(data):
    """Segment an incoming ROS image message and publish two label images.

    The message is converted to a BGR image, transformed and run through
    ``segmentation_module`` in half precision on device 0. Two images are
    published, both carrying ``data.header``:

    * on ``pub``: a mono8 image of the per-pixel argmax class indices;
    * on ``pub_label``: a 32SC1 image where every label other than 1 is 0.
    """
    img1 = bridge.imgmsg_to_cv2(data, "bgr8")
    seg_size = (img1.shape[0], img1.shape[1])
    seg = np.zeros((img1.shape[0], img1.shape[1], 1), dtype=np.uint8)

    img = img1.astype(np.float32)
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    img = img_transform(torch.from_numpy(img))
    img = torch.unsqueeze(img, 0)

    # Inference only: without no_grad every callback built (and held memory
    # for) an autograd graph that is never backpropagated.
    with torch.no_grad():
        feed_dict = async_copy_to({"img_data": img.half()}, 0)
        pred = segmentation_module(feed_dict, segSize=seg_size)
        _, ind = torch.max(pred, dim=1)
    ind = as_numpy(ind.squeeze().cpu())

    seg[:, :, 0] = ind

    im = bridge.cv2_to_imgmsg(seg, "mono8")
    # Binary mask: keep label 1, zero everything else.
    seg[seg != 1] = 0

    im_label = bridge.cv2_to_imgmsg(np.int32(seg), "32SC1")

    im.header = data.header
    im_label.header = data.header
    pub.publish(im)
    pub_label.publish(im_label)
コード例 #18
0
def test(segmentation_module, loader, gpu):
    """Segment each image at multiple scales and save it with ``save_img``.

    Scores of all resized copies are averaged on ``gpu`` over
    ``len(DATASET_CONFIG["imgSizes"])`` and reduced with a per-pixel argmax.
    The label map is saved alongside the original image.
    """
    segmentation_module.eval()

    for wrapped in loader:
        sample = wrapped[0]
        original = sample['img_ori']
        seg_size = (original.shape[0], original.shape[1])
        scaled_images = sample['img_data']

        with torch.no_grad():
            blank = torch.zeros(1, DATASET_CONFIG["num_class"], seg_size[0],
                                seg_size[1])
            scores = async_copy_to(blank, gpu)

            for scaled in scaled_images:
                inputs = dict(sample)
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, gpu)

                logits = segmentation_module(inputs, segSize=seg_size)
                scores = scores + logits / len(DATASET_CONFIG["imgSizes"])

            pred = as_numpy(torch.max(scores, dim=1)[1].squeeze(0).cpu())
        save_img(sample['img_ori'], pred)
コード例 #19
0
def test(segmentation_module, loader, args):
    """Run multi-scale CPU-accumulated inference over ``loader`` and visualize.

    Each scale's output is moved to the CPU and divided by
    ``len(args.imgSize)`` before being summed. The per-pixel argmax is
    visualized and a timestamped iteration line is printed.
    """
    segmentation_module.eval()

    for step, wrapped in enumerate(loader):
        sample = wrapped[0]
        seg_size = (sample['img_ori'].shape[0], sample['img_ori'].shape[1])
        scaled_images = sample['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, args.num_class, seg_size[0], seg_size[1])

            for scaled in scaled_images:
                inputs = dict(sample)
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, args.gpu_id)

                logits = segmentation_module(inputs, segSize=seg_size)
                scores = scores + logits.cpu() / len(args.imgSize)

            best_labels = torch.max(scores, dim=1)[1]
            preds = as_numpy(best_labels.squeeze(0))

        visualize_result((sample['img_ori'], sample['info']), preds, args)

        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('[{}] iter {}'.format(now, step))
コード例 #20
0
def process_predict_good(scores, colors, names, idx_map, is_silent):
    """Turn raw class scores into a color image, optionally printing class ratios.

    A numpy-based replacement for ``colorEncode``. It takes the argmax over
    dim 1 of ``scores`` and drops the leading dim. The labels are remapped
    through ``idx_map`` (indexed by the predicted label), cast to int32 and
    colorized with ``rock_the_colorencoding(pred, colors)``.

    Unless ``is_silent`` is true, every class covering more than 0.1% of the
    pixels is printed, largest first, under the name ``names[label + 1]``.

    Returns:
        the colorized prediction.
    """
    raw_labels = as_numpy(torch.max(scores, dim=1)[1].squeeze(0).cpu())
    # Remapped labels are what downstream distance inference uses.
    label_map = np.int32(idx_map[raw_labels])
    colored = rock_the_colorencoding(label_map, colors)

    if not is_silent:
        total = label_map.size
        labels, counts = np.unique(label_map, return_counts=True)
        for position in np.argsort(counts)[::-1]:
            class_name = names[labels[position] + 1]
            share = counts[position] / total * 100
            if share > 0.1:
                print("  {}: {:.2f}%".format(class_name, share))

    return colored
コード例 #21
0
def test(segmentation_module, loader, gpu, gpu_flag, args, progress):
    """Segment every image at multiple scales, visualize it, and report progress.

    When ``gpu_flag`` is true, the score buffer and inputs are copied to
    ``gpu``. A ``RuntimeError`` from a forward pass is printed and that scale
    is skipped. The message warns that CUDA out-of-memory produces a wrong
    segmentation and suggests processing the image on the CPU. After each
    image, ``progress.setValue`` receives the completed percentage.
    Reads the module-level ``cfg``.
    """
    segmentation_module.eval()
    bar = tqdm(total=len(loader))
    done = 0
    for wrapped in loader:
        sample = wrapped[0]
        seg_size = (sample['img_ori'].shape[0], sample['img_ori'].shape[1])
        scaled_images = sample['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, cfg.DATASET.num_class, seg_size[0],
                                 seg_size[1])
            if gpu_flag:
                scores = async_copy_to(scores, gpu)

            for scaled in scaled_images:
                inputs = dict(sample)
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                if gpu_flag:
                    inputs = async_copy_to(inputs, gpu)

                # Best effort: report the failure and continue with the
                # remaining scales.
                try:
                    logits = segmentation_module(inputs, segSize=seg_size)
                    scores = scores + logits / len(cfg.DATASET.imgSizes)
                except RuntimeError as err:
                    print(
                        '出现运行错误,假如出现CUDA OUT OF MEMORY则为爆显存,会输出错误分割结果,请尝试用CPU处理该图片。错误信息:',
                        err)

            pred = torch.max(scores, dim=1)[1].squeeze(0)
            if gpu_flag:
                pred = pred.cpu()
            pred = as_numpy(pred)

        visualize_result((sample['img_ori'], sample['info']), pred,
                         cfg, args)
        done += 1
        progress.setValue(int(done / len(loader) * 100))
        bar.update(1)
コード例 #22
0
    def run_inference(self, loader):
        """Segment each loaded image and publish a side-by-side visualization.

        Scores are averaged over all resized copies on ``self.gpu``. An
        overlay is built whose channel 1 (green) is the sum of the
        ``self.DRIVEABLE`` class scores and whose channel 0 (red) is the
        ``self.PERSON`` score. The overlay is scaled by 255 and cast to uint8
        (assumes scores lie in [0, 1] -- TODO confirm). It is concatenated to
        the right of the original image and published on ``self.seg_pub`` as
        rgb8, stamped with ``self.time_ori``. Latency since ``self.time_ori``
        is logged at the end.
        """
        rospy.loginfo("Processing image...")
        tic = rospy.get_rostime()  # NOTE(review): not used below

        self.segmentation_module.eval()

        progress = tqdm(total=len(loader))
        for wrapped in loader:
            sample = wrapped[0]
            height, width = sample['img_ori'].shape[:2]
            seg_size = (height, width)
            overlay = np.zeros((height, width, 3))
            scaled_images = sample['img_data']

            with torch.no_grad():
                blank = torch.zeros(1, self.cfg.DATASET.num_class, height, width)
                scores = async_copy_to(blank, self.gpu)

                for scaled in scaled_images:
                    inputs = dict(sample)
                    inputs['img_data'] = scaled
                    inputs.pop('img_ori')
                    inputs.pop('info')
                    inputs = async_copy_to(inputs, self.gpu)

                    logits = self.segmentation_module(inputs, segSize=seg_size)
                    scores = scores + logits / len(self.cfg.DATASET.imgSizes)

                class_scores = as_numpy(scores.squeeze(0).cpu())

            # Green: drivable classes combined; red: person.
            overlay[:, :, 1] = np.sum(class_scores[self.DRIVEABLE], axis=0)
            overlay[:, :, 0] = class_scores[self.PERSON, :, :]
            overlay_u8 = (overlay * 255).astype('uint8')

            side_by_side = np.concatenate((sample['img_ori'], overlay_u8), axis=1)
            msg = self.bridge.cv2_to_imgmsg(side_by_side, encoding='rgb8')
            msg.header.frame_id = self.frame_id
            msg.header.stamp = self.time_ori
            self.seg_pub.publish(msg)

            progress.update(1)

        latency = (rospy.get_rostime() - self.time_ori).to_sec()
        rospy.loginfo('Image latency of %.03f seconds.' % latency)
コード例 #23
0
ファイル: train.py プロジェクト: rexxxx1234/SAUNet-demo
def eval(loader_val, segmentation_module, args, crit):
    """Validate ``segmentation_module`` and report per-class IoU and loss.

    Each sample is forwarded once; the arg-max prediction is visualized with
    ``visualize_result`` and accumulated into intersection/union meters.

    Args:
        loader_val: iterable whose items are indexed with ``[0]`` to obtain a
            sample dict holding ``"image"`` and ``"mask"``.
        segmentation_module: model returning ``(scores, loss)`` when called
            as ``segmentation_module(feed_dict, epoch=0, segSize=...)``.
        args: namespace providing ``num_class`` (and whatever
            ``visualize_result`` reads).
        crit: unused here; kept so existing call sites keep working.

    Returns:
        ``(iou[1:], mean_loss)``: IoU of every class except index 0 and the
        average validation loss.
    """
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    loss_meter = AverageMeter()

    segmentation_module.eval()
    for batch_data in loader_val:
        batch_data = batch_data[0]

        seg_label = as_numpy(batch_data["mask"][0])
        torch.cuda.synchronize()
        batch_data["image"] = batch_data["image"].unsqueeze(0).cuda()

        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            feed_dict = batch_data.copy()

            # Single-scale forward pass: the model output is used as-is, no
            # zero-initialised accumulator or per-batch tensor dumps needed.
            scores, loss = segmentation_module(feed_dict, epoch=0, segSize=segSize)
            loss_meter.update(loss)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

            visualize_result(batch_data["image"].cpu().numpy(), seg_label, pred, args)

        torch.cuda.synchronize()
        # calculate accuracy
        intersection, union = intersectionAndUnion(pred, seg_label, args.num_class)
        intersection_meter.update(intersection)
        union_meter.update(union)

    # summary (class 0 is not reported)
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou[1:], start=1):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))
    print('loss: {:.4f}'.format(loss_meter.average()))
    return iou[1:], loss_meter.average()
コード例 #24
0
ファイル: handseg.py プロジェクト: junwenkwan/hand-seg-tpv
def hand_segmentation(frame, segmentation_module, save=False, save_path='numpy.txt'):
    """Segment a single frame and return its per-pixel arg-max class index.

    Args:
        frame: input image; transformed with ``img_transform`` and sized via
            ``as_numpy(frame).shape``.
        segmentation_module: model called as
            ``segmentation_module(tensor, segSize=(H, W))``.
        save: when true, also write the prediction as comma-separated ints.
        save_path: destination file used when ``save`` is true (defaults to
            the previously hard-coded ``'numpy.txt'``).

    Returns:
        numpy array of class indices, presumably ``(H, W)`` once the batch
        axis is squeezed — TODO confirm against the model's output shape.
    """
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        # Convert to torch.Tensor with a leading batch axis.
        frame_tensor = torch.unsqueeze(img_transform(frame), 0).cuda()

        # Output size matches the input frame; computed once.
        frame_shape = as_numpy(frame).shape
        segSize = (frame_shape[0], frame_shape[1])

        # Forward pass, then arg-max over the class axis.
        pred_tmp = segmentation_module(frame_tensor, segSize=segSize)
        _, pred = torch.max(pred_tmp, dim=1)

    # Convert to numpy.ndarray
    pred = as_numpy(pred.squeeze(0).cpu())

    if save:
        np.savetxt(save_path, pred.astype(int), fmt='%i', delimiter=",")
    return pred
コード例 #25
0
def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue):
    """Evaluation worker: score each sample and report stats to the master.

    Every element of ``loader`` is indexed with ``[0]`` to obtain the sample
    dict. The model is run once per (resized image, reference images) pair
    taken from ``img_data`` / ``img_refs``; each output is divided by
    ``len(cfg.DATASET.imgSizes)`` and summed. The arg-max label map is scored
    against ``seg_label`` and ``(acc, pix, intersection, union)`` is put on
    ``result_queue``. With ``cfg.VAL.visualize`` set, a visualization is
    written under ``<cfg.DIR>/result``.
    """
    segmentation_module.eval()

    for wrapped in loader:
        sample = wrapped[0]
        label = as_numpy(sample['seg_label'][0])
        img_list = sample['img_data']
        ref_list = sample['img_refs']

        with torch.no_grad():
            out_size = (label.shape[0], label.shape[1])
            total = torch.zeros(1, cfg.DATASET.num_class, out_size[0],
                                out_size[1])
            total = async_copy_to(total, gpu_id)

            for img, refs in zip(img_list, ref_list):
                inputs = sample.copy()
                inputs['img_data'] = img
                inputs['img_refs'] = refs
                # drop entries that are not model inputs
                for unused_key in ('img_ori', 'info'):
                    del inputs[unused_key]
                inputs = async_copy_to(inputs, gpu_id)

                out = segmentation_module(inputs, segSize=out_size)
                total = total + out / len(cfg.DATASET.imgSizes)

            _, pred = torch.max(total, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        # per-sample metrics go back to the master process
        acc, pix = accuracy(pred, label)
        inter, union = intersectionAndUnion(pred, label,
                                            cfg.DATASET.num_class)
        result_queue.put_nowait((acc, pix, inter, union))

        if cfg.VAL.visualize:
            visualize_result((sample['img_ori'], label, sample['info']),
                             pred, os.path.join(cfg.DIR, 'result'))
コード例 #26
0
ファイル: eval_multipro.py プロジェクト: kashyap7x/QGN
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
    """Quadtree-aware evaluation worker; pushes per-sample stats to a queue.

    For each sample (``loader`` items indexed with ``[0]``), the model is run
    on every resized image in ``img_data``. When ``args.eval_mode == 'gt'``
    the matching entry of ``quadtree`` is passed as ``qtree``. Outputs are
    divided by ``len(args.imgSize)`` and summed. The arg-max prediction is
    scored against ``seg_label`` and ``(acc, pix, intersection, union)`` is
    put on ``result_queue``; ``args.visualize`` toggles visualization.
    """
    segmentation_module.eval()

    for i, batch_data in enumerate(loader):
        # process data
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])

        img_resized_list = batch_data['img_data']
        quadtree_resized_list = batch_data['quadtree']

        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            # Allocate the accumulator on ``dev_id``, the same device the
            # inputs are copied to. ``Variable(...).cuda()`` (deprecated)
            # used the *current* CUDA device, which only matches ``dev_id``
            # if the caller already ran ``torch.cuda.set_device(dev_id)``.
            pred = torch.zeros(1, args.num_class, segSize[0], segSize[1])
            pred = async_copy_to(pred, dev_id)

            for scale, img in enumerate(img_resized_list):
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                if args.eval_mode == 'gt':
                    feed_dict['qtree'] = quadtree_resized_list[scale]
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, dev_id)

                # forward pass
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                pred = pred + pred_tmp / len(args.imgSize)

            _, preds = torch.max(pred.cpu(), dim=1)
            preds = as_numpy(preds.squeeze(0))

        # calculate accuracy and SEND THEM TO MASTER
        acc, pix = accuracy(preds, seg_label)
        intersection, union = intersectionAndUnion(preds, seg_label,
                                                   args.num_class)
        result_queue.put_nowait((acc, pix, intersection, union))

        # visualization
        if args.visualize:
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']), preds,
                args)
コード例 #27
0
def _crf_refine(scores, img_ori, n_iter=10):
    """Refine per-class scores with a dense CRF and return the label map.

    Args:
        scores: tensor indexed as ``(1, C, H, W)``. Its negative log is used as
            the unary energy, so the scores are assumed positive (e.g. softmax
            probabilities); a zero score gives an infinite energy.
        img_ori: original image for the bilateral term, ``(H, W, 3)``.
            Presumably uint8, as pydensecrf requires — confirm.
        n_iter: number of mean-field inference iterations.

    Returns:
        numpy array ``(H, W)`` with the refined class index per pixel.
    """
    probs = scores.cpu().numpy()[0]  # (C, H, W)
    n_labels, height, width = probs.shape
    # pydensecrf expects unaries as (n_labels, H*W) in row-major pixel order,
    # i.e. the same order as the (H, W, 3) image given to the bilateral term.
    # Flattening a (C, W, H) transpose misaligned the two whenever H != W.
    unary = np.ascontiguousarray(-np.log(probs).reshape(n_labels, -1))
    img = np.ascontiguousarray(img_ori)

    crf = dcrf.DenseCRF2D(width, height, n_labels)
    crf.setUnaryEnergy(unary)
    crf.addPairwiseBilateral(sxy=5, srgb=3, rgbim=img, compat=1)

    q = crf.inference(n_iter)
    return np.argmax(q, axis=0).reshape(height, width)


def test(segmentation_module, loader, gpu):
    """Run inference over ``loader`` and visualize every prediction.

    Scores from every resized copy in ``batch_data['img_data']`` are summed
    (not averaged). When the module-level ``DEBUG_CRF`` flag is set, the
    summed scores are refined with a dense CRF (``_crf_refine``). Otherwise
    the arg-max class is used directly. Relies on the module-level ``cfg``
    for the class count and visualization settings.
    """
    segmentation_module.eval()
    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        segSize = (batch_data['img_ori'].shape[0],
                   batch_data['img_ori'].shape[1])
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0],
                                 segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)
                # forward pass; scales are summed, not averaged
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + pred_tmp

            if DEBUG_CRF:
                pred = _crf_refine(scores, batch_data['img_ori'])
            else:
                _, pred = torch.max(scores, dim=1)
                pred = as_numpy(pred.squeeze(0).cpu())

        # visualization
        visualize_result(
            (batch_data['img_ori'], batch_data['info'], batch_data['gt_mask']),
            pred, cfg)

        pbar.update(1)
コード例 #28
0
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
    """Multi-scale evaluation worker pushing per-sample stats to a queue.

    The class count and averaging divisor come from the module-level ``cfg``
    (``cfg.MODEL.NUM_CLASSES`` and ``len(cfg.TRAIN.SCALES)``). Each sample's
    arg-max prediction is scored against ``seg_label`` and
    ``(acc, pix, intersection, union)`` is put on ``result_queue``.
    ``args.visualize`` toggles visualization.
    """
    segmentation_module.eval()

    for wrapped in loader:
        sample = wrapped[0]
        label = as_numpy(sample['seg_label'][0])
        resized_imgs = sample['img_data']

        with torch.no_grad():
            size = (label.shape[0], label.shape[1])
            total = async_copy_to(
                torch.zeros(1, cfg.MODEL.NUM_CLASSES, size[0], size[1]),
                dev_id)

            for img in resized_imgs:
                inputs = sample.copy()
                inputs['img_data'] = img
                for unused_key in ('img_ori', 'info'):
                    del inputs[unused_key]
                inputs = async_copy_to(inputs, dev_id)

                out = segmentation_module(inputs, segSize=size)
                total = total + out / len(cfg.TRAIN.SCALES)

            label_idx = torch.max(total, dim=1)[1]
            pred = as_numpy(label_idx.squeeze(0).cpu())

        # per-sample metrics go back to the master process
        acc, pix = accuracy(pred, label)
        inter, union = intersectionAndUnion(pred, label,
                                            cfg.MODEL.NUM_CLASSES)
        result_queue.put_nowait((acc, pix, inter, union))

        if args.visualize:
            visualize_result((sample['img_ori'], label, sample['info']),
                             pred, args)
コード例 #29
0
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
    """Multi-scale evaluation worker that accumulates scores on the CPU.

    Each resized input in ``img_data`` is run on ``dev_id``. The output is
    moved to the CPU, divided by ``len(args.imgSize)`` and summed. The
    arg-max prediction is scored against ``seg_label`` and
    ``(acc, pix, intersection, union)`` is put on ``result_queue``.
    """
    segmentation_module.eval()

    for wrapped in loader:
        sample = wrapped[0]
        label = as_numpy(sample['seg_label'][0])

        resized_imgs = sample['img_data']

        with torch.no_grad():
            size = (label.shape[0], label.shape[1])
            total = torch.zeros(1, args.num_class, size[0], size[1])

            for img in resized_imgs:
                inputs = sample.copy()
                inputs['img_data'] = img
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, dev_id)

                out = segmentation_module(inputs, segSize=size)
                total = total + out.cpu() / len(args.imgSize)

            label_idx = torch.max(total.cpu(), dim=1)[1]
            preds = as_numpy(label_idx.squeeze(0))

        # per-sample metrics go back to the master process
        acc, pix = accuracy(preds, label)
        inter, union = intersectionAndUnion(preds, label, args.num_class)
        result_queue.put_nowait((acc, pix, inter, union))

        if args.visualize:
            visualize_result((sample['img_ori'], label, sample['info']),
                             preds, args)
コード例 #30
0
def eval(loader_val, segmentation_module, crit, num_class=4):
    """Validate ``segmentation_module``, print per-class IoU, return mean loss.

    Args:
        loader_val: iterable whose items are indexed with ``[0]`` to obtain a
            sample dict holding ``"image"`` and ``"mask"``.
        segmentation_module: model returning ``(scores, loss)`` when called
            as ``segmentation_module(feed_dict, epoch=0, segSize=...)``.
        crit: unused here; kept so existing call sites keep working.
        num_class: number of classes for the IoU histogram. Previously
            hard-coded to 4, which stays the default.

    Returns:
        The average validation loss.
    """
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    loss_meter = AverageMeter()

    segmentation_module.eval()
    for batch_data in loader_val:
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data["mask"])
        torch.cuda.synchronize()
        batch_data["image"] = batch_data["image"].unsqueeze(0).cuda()
        batch_data["mask"] = batch_data["mask"].cuda()
        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            feed_dict = batch_data.copy()

            # forward pass (the module also computes the loss)
            scores, loss = segmentation_module(feed_dict,
                                               epoch=0,
                                               segSize=segSize)
            loss_meter.update(loss)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())
        torch.cuda.synchronize()
        # calculate accuracy
        intersection, union = intersectionAndUnion(pred, seg_label, num_class)
        intersection_meter.update(intersection)
        union_meter.update(union)

    # summary (class 0 is not reported)
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou[1:], start=1):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))
    print('loss: {:.4f}'.format(loss_meter.average()))
    return loss_meter.average()
コード例 #31
0
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
    """Unified-parsing evaluation worker (scene/object/material/part).

    For each sample, predictions are averaged over the inputs in
    ``img_resized_list`` (divisor ``len(args.imgSize)``). Scene, object,
    material and part maps are then reduced to arg-max labels and converted
    to numpy. ``get_metrics(pred, data_np)`` is put on ``result_queue``.
    """
    segmentation_module.eval()
    dense_keys = ('object', 'material')

    for wrapped in loader:
        sample_t = wrapped[0]  # TODO(LYC):: support batch size > 1
        sample_np = as_numpy(sample_t)
        seg_size = sample_np['seg_object'].shape[0:2]

        with torch.no_grad():
            # accumulators, inserted in the order object, material, part, scene
            pred_ms = {key: torch.zeros(1, args.nr_classes[key], *seg_size)
                       for key in dense_keys}
            pred_ms['part'] = [
                torch.zeros(1, len(broden_dataset.object_part[obj]), *seg_size)
                for obj in broden_dataset.object_with_part
            ]
            pred_ms['scene'] = torch.zeros(1, args.nr_classes['scene'])

            for img in sample_t['img_resized_list']:
                inputs = async_copy_to({"img": img.unsqueeze(0)}, dev_id)
                out = segmentation_module(inputs, seg_size=seg_size)
                for key in ('scene', 'object', 'material'):
                    pred_ms[key] = pred_ms[key] + out[key].cpu() / len(args.imgSize)
                for j, _obj in enumerate(broden_dataset.object_with_part):
                    pred_ms['part'][j] += out['part'][j].cpu() / len(args.imgSize)

            # reduce scores to label indices
            pred_ms['scene'] = torch.argmax(pred_ms['scene'].squeeze(0))
            for key in dense_keys:
                pred_ms[key] = torch.max(pred_ms[key].cpu(), dim=1)[1].squeeze(0)
            part_preds = pred_ms['part']
            for j, _obj in enumerate(broden_dataset.object_with_part):
                part_preds[j] = torch.max(part_preds[j].cpu(), dim=1)[1].squeeze(0)
            pred_ms = as_numpy(pred_ms)

        # calculate accuracy and SEND THEM TO MASTER
        result_queue.put_nowait(get_metrics(pred_ms, sample_np))
コード例 #32
0
ファイル: inference.py プロジェクト: riven314/pyqt_speed_test
def process_predict(scores, colors, names, idx_map, is_silent):
    """Turn raw model scores into a colorized label image.

    Args:
        scores: model output; arg-max is taken over dim 1 and the leading
            batch axis is squeezed away.
        colors, names, is_silent: forwarded unchanged to ``visualize_result``.
        idx_map: lookup array that regroups label indices (``idx_map[pred]``).

    Returns:
        The colorized prediction produced by ``visualize_result``.
    """
    label_idx = torch.max(scores, dim=1)[1]
    label_map = as_numpy(label_idx.squeeze(0).cpu())  # (height, width)
    # regroup label indices before colorizing
    return visualize_result(idx_map[label_map], colors, names, is_silent)
コード例 #33
0
def cam_test(segmentation_module, cap, args):
    """Live person-cutout demo on a video capture.

    Each frame is segmented at scales ``[300, 400, 500]``. Pixels predicted
    as class 12 (presumably "person" in the label set — confirm) are copied
    onto the module-level ``bg_image``, and the composite is shown with
    ``cv2.imshow``. The loop ends when ``q`` is pressed, when the capture
    closes, or when ``cap.read()`` fails to return a frame.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        cap: an opened ``cv2.VideoCapture``.
        args: namespace providing ``num_class``, ``gpu`` and ``imgSize``.
    """
    segmentation_module.eval()

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of stream / read failure: ``frame`` is None here.
            break
        # OpenCV frames are BGR; the model pre-processing receives RGB.
        image = frame[:, :, ::-1]
        height, width, _ = image.shape
        segSize = (height, width)

        with torch.no_grad():
            scores = torch.zeros(1, args.num_class, segSize[0], segSize[1])
            scores = async_copy_to(scores, args.gpu)

            img_resized_list = image_pre_process(image, [300, 400, 500])
            for img in img_resized_list:
                feed_dict = async_copy_to({'img_data': img}, args.gpu)

                # forward pass
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + pred_tmp / len(args.imgSize)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        # visualization: paste person pixels onto the background. Copy from
        # the BGR ``frame`` (not the RGB ``image``) because cv2.imshow
        # expects BGR. Assumes bg_image is BGR as well (e.g. from
        # cv2.imread) — confirm.
        person_mask = pred == 12
        person_mask = np.tile(person_mask[:, :, np.newaxis], (1, 1, 3))
        viz_frame = bg_image.copy()
        viz_frame[person_mask] = frame[person_mask]
        cv2.imshow("VIZ", viz_frame)
        if cv2.waitKey(25) & 0xFF == ord('q'):
            break
コード例 #34
0
ファイル: test.py プロジェクト: CSAILVision/unifiedparsing
def test(segmentation_module, loader, args):
    """Run multi-scale unified-parsing inference and visualize each sample.

    For every sample, scene / object / material / part predictions are
    averaged over the inputs in ``data['img_data']`` (divisor
    ``len(args.imgSize)``). Object, material and part maps are reduced to
    arg-max labels. The scene prediction is kept as a score vector: only the
    batch axis is squeezed, with no arg-max. Results go to
    ``visualize_result`` and a timestamped progress line is printed.
    """
    segmentation_module.eval()

    for i, data in enumerate(loader):
        # process data: take the first element of each loader item
        data = data[0]
        seg_size = data['img_ori'].shape[0:2]

        with torch.no_grad():
            # zero-initialised score accumulators for each prediction head
            pred_ms = {}
            for k in ['object', 'material']:
                pred_ms[k] = torch.zeros(1, args.nr_classes[k], *seg_size)
            pred_ms['part'] = []
            # one part-score tensor per object that has parts
            for idx_part, object_label in enumerate(broden_dataset.object_with_part):
                n_part = len(broden_dataset.object_part[object_label])
                pred_ms['part'].append(torch.zeros(1, n_part, *seg_size))
            pred_ms['scene'] = torch.zeros(1, args.nr_classes['scene'])

            for img in data['img_data']:
                # forward pass
                # NOTE(review): unlike the evaluation path, img is not
                # unsqueezed here — presumably the test loader already adds
                # the batch axis; confirm.
                feed_dict = async_copy_to({"img": img}, args.gpu_id)
                pred = segmentation_module(feed_dict, seg_size=seg_size)
                for k in ['scene', 'object', 'material']:
                    pred_ms[k] = pred_ms[k] + pred[k].cpu() / len(args.imgSize)
                for idx_part, object_label in enumerate(broden_dataset.object_with_part):
                    pred_ms['part'][idx_part] += pred['part'][idx_part].cpu() / len(args.imgSize)

            # scene keeps its scores; dense heads are reduced to label maps
            pred_ms['scene'] = pred_ms['scene'].squeeze(0)
            for k in ['object', 'material']:
                _, p_max = torch.max(pred_ms[k].cpu(), dim=1)
                pred_ms[k] = p_max.squeeze(0)
            for idx_part, object_label in enumerate(broden_dataset.object_with_part):
                _, p_max = torch.max(pred_ms['part'][idx_part].cpu(), dim=1)
                pred_ms['part'][idx_part] = p_max.squeeze(0)
            pred_ms = as_numpy(pred_ms)

        visualize_result(data, pred_ms, args)

        print('[{}] iter {}'
              .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i))