Пример #1
0
def test(segmentation_module, loader, gpu):
    """Run multi-scale inference over ``loader`` and visualize each result.

    Each loader item is a one-element sequence wrapping a dict with
    ``'img_ori'`` (original image), ``'img_data'`` (list of resized model
    inputs) and ``'info'``. The model is run once per resized input; each
    score map is divided by ``len(cfg.DATASET.imgSizes)`` and accumulated,
    and the per-pixel arg-max over classes is passed to ``visualize_result``.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable of wrapped sample dicts (see above).
        gpu: device argument forwarded to ``async_copy_to``.
    """
    segmentation_module.eval()

    progress = tqdm(total=len(loader))
    for wrapped in loader:
        sample = wrapped[0]
        out_size = (sample['img_ori'].shape[0], sample['img_ori'].shape[1])
        scaled_inputs = sample['img_data']

        with torch.no_grad():
            blank = torch.zeros(1, cfg.DATASET.num_class,
                                out_size[0], out_size[1])
            total = async_copy_to(blank, gpu)

            for scaled in scaled_inputs:
                # Forward every field except the raw image and its metadata.
                inputs = sample.copy()
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, gpu)

                score_map = segmentation_module(inputs, segSize=out_size)
                total = total + score_map / len(cfg.DATASET.imgSizes)

            labels = as_numpy(torch.max(total, dim=1)[1].squeeze(0).cpu())

        visualize_result((sample['img_ori'], sample['info']), labels, cfg)
        progress.update(1)
def test(segmentation_module, image_path, gpu):
    """Segment a single image file and visualize the prediction.

    ``load_image(image_path)`` supplies a dict with ``'img_ori'``,
    ``'img_data'`` (list of resized inputs) and ``'info'``. The score maps of
    all resized inputs are summed (each divided by 1, i.e. not averaged) and
    the per-pixel class arg-max is handed to ``visualize_result``.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        image_path: path passed to ``load_image``.
        gpu: device argument forwarded to ``async_copy_to``.
    """
    segmentation_module.eval()

    sample = load_image(image_path)
    out_size = (sample['img_ori'].shape[0], sample['img_ori'].shape[1])
    scaled_inputs = sample['img_data']

    with torch.no_grad():
        total = async_copy_to(
            torch.zeros(1, cfg.DATASET.num_class, out_size[0], out_size[1]),
            gpu)

        for scaled in scaled_inputs:
            inputs = sample.copy()
            inputs['img_data'] = scaled
            inputs.pop('img_ori')
            inputs.pop('info')
            inputs = async_copy_to(inputs, gpu)
            score_map = segmentation_module(inputs, segSize=out_size)
            # Divisor intentionally 1 (plain sum); averaging by
            # len(cfg.DATASET.imgSizes) would not change the arg-max.
            total = total + score_map / 1

        winners = torch.max(total, dim=1)[1]
        labels = as_numpy(winners.squeeze(0).cpu())

    visualize_result((sample['img_ori'], sample['info']), labels, cfg)
Пример #3
0
def test(segmentation_module, loader, gpu):
    """Segment every image from ``loader`` and save the result via ``save_img``.

    Per sample, the model runs on each resized input in ``'img_data'``; the
    score maps are divided by ``len(DATASET_CONFIG["imgSizes"])`` and summed,
    and the class arg-max is saved together with ``'img_ori'``.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with ``'img_ori'``, ``'img_data'`` and ``'info'``.
        gpu: device argument forwarded to ``async_copy_to``.
    """
    segmentation_module.eval()

    for wrapped in loader:
        sample = wrapped[0]
        original = sample['img_ori']
        out_size = (original.shape[0], original.shape[1])
        scaled_inputs = sample['img_data']

        with torch.no_grad():
            total = torch.zeros(1, DATASET_CONFIG["num_class"], out_size[0],
                                out_size[1])
            total = async_copy_to(total, gpu)

            for scaled in scaled_inputs:
                inputs = sample.copy()
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, gpu)

                score_map = segmentation_module(inputs, segSize=out_size)
                total = total + score_map / len(DATASET_CONFIG["imgSizes"])

            best = torch.max(total, dim=1)[1].squeeze(0)
            labels = as_numpy(best.cpu())
        save_img(sample['img_ori'], labels)
Пример #4
0
    def run_inference(self, loader):
        """Segment each image from ``loader`` and publish a colour overlay.

        For every sample, the per-scale score maps are divided by
        ``len(self.cfg.DATASET.imgSizes)`` and summed. An RGB overlay is then
        built from the result. The green channel is the sum of the scores at
        the indices in ``self.DRIVEABLE``; the red channel is the score at
        ``self.PERSON``. The overlay is placed to the right of the original
        image and published on ``self.seg_pub`` as an ``rgb8`` message, with
        ``self.frame_id`` / ``self.time_ori`` in its header. After the loop,
        the latency since ``self.time_ori`` is logged.

        Args:
            loader: iterable whose items are one-element sequences wrapping
                a dict with 'img_ori', 'img_data' and 'info'.
        """
        rospy.loginfo("Processing image...")

        self.segmentation_module.eval()

        pbar = tqdm(total=len(loader))
        try:
            for batch_data in loader:
                batch_data = batch_data[0]
                h, w = batch_data['img_ori'].shape[:2]
                segSize = (h, w)
                new_img = np.zeros((h, w, 3))
                img_resized_list = batch_data['img_data']

                with torch.no_grad():
                    scores = torch.zeros(1, self.cfg.DATASET.num_class,
                                         segSize[0], segSize[1])
                    scores = async_copy_to(scores, self.gpu)

                    for img in img_resized_list:
                        feed_dict = batch_data.copy()
                        feed_dict['img_data'] = img
                        del feed_dict['img_ori']
                        del feed_dict['info']
                        feed_dict = async_copy_to(feed_dict, self.gpu)

                        # forward pass
                        pred_tmp = self.segmentation_module(feed_dict,
                                                            segSize=segSize)
                        scores = scores + pred_tmp / len(
                            self.cfg.DATASET.imgSizes)

                    # (num_class, H, W) array of accumulated scores.
                    nparr = as_numpy(scores.squeeze(0).cpu())

                # Drivable classes (summed) in the green channel.
                new_img[:, :, 1] = np.sum(nparr[self.DRIVEABLE], axis=0)
                # Person class in the red channel.
                new_img[:, :, 0] = nparr[self.PERSON, :, :]
                # Clip before scaling so values outside [0, 1] cannot wrap
                # around when cast to uint8.
                uint_img = (np.clip(new_img, 0.0, 1.0) * 255).astype('uint8')
                # Original image and overlay side by side.
                im_vis = np.concatenate((batch_data['img_ori'], uint_img),
                                        axis=1)
                img_msg = self.bridge.cv2_to_imgmsg(im_vis, encoding='rgb8')
                img_msg.header.frame_id = self.frame_id
                img_msg.header.stamp = self.time_ori
                self.seg_pub.publish(img_msg)

                pbar.update(1)
        finally:
            pbar.close()

        rospy.loginfo('Image latency of %.03f seconds.' %
                      ((rospy.get_rostime() - self.time_ori).to_sec()))
def evaluate(segmentation_module, loader, cfg, gpu, model_name,
             paper_arxiv_id):
    """Evaluate on ``loader`` and submit predictions to an ADE20KEvaluator.

    Per sample, the model runs on every resized input; the score maps are
    divided by ``len(cfg.DATASET.imgSizes)`` and summed, and the class
    arg-max is compared with ``seg_label``. Pixel accuracy, intersection,
    union and latency are accumulated in ``AverageMeter`` objects, and
    flattened predictions/targets are added to the evaluator. The loop stops
    as soon as ``evaluator.cache_exists`` is truthy; ``evaluator.save()`` is
    always called at the end.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'seg_label', 'img_data', 'img_ori' and 'info'.
        cfg: config providing ``DATASET.num_class`` and ``DATASET.imgSizes``.
        gpu: device argument forwarded to ``async_copy_to``.
        model_name: forwarded to ``ADE20KEvaluator``.
        paper_arxiv_id: forwarded to ``ADE20KEvaluator``.
    """
    pixel_acc = AverageMeter()
    inter_sum = AverageMeter()
    union_sum = AverageMeter()
    latency = AverageMeter()
    evaluator = ADE20KEvaluator(model_name=model_name,
                                paper_arxiv_id=paper_arxiv_id)

    segmentation_module.eval()

    progress = tqdm(total=len(loader))
    for wrapped in loader:
        sample = wrapped[0]
        label = as_numpy(sample['seg_label'][0])
        scaled_inputs = sample['img_data']

        # Synchronize on both sides of the timed region so queued CUDA work
        # is included in the measured latency.
        torch.cuda.synchronize()
        started = time.perf_counter()
        with torch.no_grad():
            out_size = (label.shape[0], label.shape[1])
            total = async_copy_to(
                torch.zeros(1, cfg.DATASET.num_class, out_size[0],
                            out_size[1]),
                gpu)

            for scaled in scaled_inputs:
                inputs = sample.copy()
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, gpu)

                score_map = segmentation_module(inputs, segSize=out_size)
                total = total + score_map / len(cfg.DATASET.imgSizes)

            prediction = as_numpy(
                torch.max(total, dim=1)[1].squeeze(0).cpu())
        torch.cuda.synchronize()
        latency.update(time.perf_counter() - started)

        acc, pix = accuracy(prediction, label)
        inter, union = intersectionAndUnion(prediction, label,
                                            cfg.DATASET.num_class)
        pixel_acc.update(acc, pix)
        inter_sum.update(inter)
        union_sum.update(union)

        evaluator.add(outputs=prediction.flatten(), targets=label.flatten())

        # Early exit: the progress bar is not advanced on this iteration.
        if evaluator.cache_exists:
            break

        progress.update(1)
    evaluator.save()
def test(segmentation_module, loader, gpu):
    """Segment every image in ``loader``, optionally refining with a dense CRF.

    Per sample, the model runs on each resized input in ``'img_data'`` and
    the score maps are averaged. When the module-level flag ``DEBUG_CRF`` is
    set, the averaged scores are treated as per-class probabilities. Their
    negative log becomes the unary energy of a pydensecrf ``DenseCRF2D``,
    which adds a bilateral term on ``'img_ori'``. Otherwise the plain class
    arg-max is used. The prediction is passed to ``visualize_result``
    together with ``'img_ori'``, ``'info'`` and ``'gt_mask'``.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'img_ori', 'img_data', 'info' and 'gt_mask'.
        gpu: device argument forwarded to ``async_copy_to``.
    """
    segmentation_module.eval()
    print(colors)
    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        segSize = (batch_data['img_ori'].shape[0],
                   batch_data['img_ori'].shape[1])
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0],
                                 segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)
                # forward pass
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                # Average over the scales actually run. A constant factor
                # changes neither the arg-max nor the CRF result, since
                # -log shifts every label's energy at a pixel equally.
                scores = scores + pred_tmp / len(img_resized_list)

            if DEBUG_CRF:
                # pydensecrf expects DenseCRF2D(width, height, n_labels) and
                # unary energies of shape (n_labels, height * width) in
                # row-major (H, W) pixel order. This matches the (H, W, 3)
                # image given to addPairwiseBilateral.
                probs = scores.squeeze(0).cpu().numpy()  # (C, H, W)
                n_labels, h, w = probs.shape
                # Clamp before the log so zero-probability labels get a
                # large finite energy instead of inf.
                unary = -np.log(np.clip(probs, 1e-8, None))
                unary = np.ascontiguousarray(unary.reshape(n_labels, -1),
                                             dtype=np.float32)
                rgb = np.ascontiguousarray(batch_data['img_ori'])
                d = dcrf.DenseCRF2D(w, h, n_labels)
                d.setUnaryEnergy(unary)
                d.addPairwiseBilateral(sxy=5, srgb=3, rgbim=rgb, compat=1)

                q = d.inference(10)
                pred = np.argmax(q, axis=0).reshape(h, w)
                print(np.unique(pred))

            else:
                _, pred = torch.max(scores, dim=1)
                pred = as_numpy(pred.squeeze(0).cpu())

        # visualization
        visualize_result(
            (batch_data['img_ori'], batch_data['info'], batch_data['gt_mask']),
            pred, cfg)

        pbar.update(1)
    pbar.close()
Пример #7
0
def cam_test(segmentation_module, cap, args):
    """Live demo: show camera pixels predicted as label 12 over ``bg_image``.

    Frames are read from ``cap`` until it is closed, a read fails, or ``q`` is
    pressed in the display window. Each frame has its channel axis reversed.
    It is then preprocessed at sizes 300/400/500 by ``image_pre_process``
    and run through the model at every size. The accumulated scores are
    arg-maxed per pixel. Pixels whose label is 12 are copied from the frame
    onto a copy of the module-level ``bg_image``, and the result is shown.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        cap: OpenCV ``VideoCapture``-like object (``isOpened`` / ``read``).
        args: namespace providing ``num_class``, ``gpu`` and ``imgSize``.
    """
    segmentation_module.eval()

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # read() returns (False, None) when no frame could be grabbed
            # (end of stream or device error); indexing None would raise.
            break
        image = frame[:, :, ::-1]  # reverse channel order
        height, width, _ = image.shape
        segSize = (height, width)

        with torch.no_grad():
            scores = torch.zeros(1, args.num_class, segSize[0], segSize[1])
            scores = async_copy_to(scores, args.gpu)

            img_resized_list = image_pre_process(image, [300, 400, 500])
            for img in img_resized_list:
                feed_dict = async_copy_to({'img_data': img}, args.gpu)

                # forward pass
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                # NOTE(review): divides by len(args.imgSize) although the
                # sizes above are hard-coded; the arg-max is unaffected.
                scores = scores + pred_tmp / len(args.imgSize)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        # visualization: label 12 is treated as "person" (presumably the
        # ADE20K index -- confirm). The mask is tiled to 3 channels to index
        # (H, W, 3) images.
        person_mask = np.tile((pred == 12)[:, :, np.newaxis], (1, 1, 3))
        # NOTE(review): assumes bg_image has the frame's (H, W, 3) shape.
        viz_frame = bg_image.copy()
        viz_frame[person_mask] = image[person_mask]
        cv2.imshow("VIZ", viz_frame)
        if cv2.waitKey(25) & 0xFF == ord('q'):
            break
def test(segmentation_module, loader, gpu):
    """Profiling helper: time one forward pass on the first sample.

    Only the first batch of ``loader`` and only its first resized input are
    processed. The function prints the input size, the host-to-device copy
    time, the output size and the inference time, then visualizes the
    resulting prediction.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'img_ori', 'img_data' and 'info'.
        gpu: device argument forwarded to ``async_copy_to``.
    """
    def _sync():
        # CUDA kernels and non-blocking copies run asynchronously; wait for
        # them so the wall-clock intervals measure the work, not the launch.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        segSize = (batch_data['img_ori'].shape[0],
                   batch_data['img_ori'].shape[1])
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0],
                                 segSize[1])
            scores = async_copy_to(scores, gpu)
            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                print("Size of feed_dict = " +
                      str(feed_dict['img_data'].size()))

                _sync()
                start_time = time.perf_counter()
                feed_dict = async_copy_to(feed_dict, gpu)
                _sync()
                end_time = time.perf_counter()
                print("Time to move inputs to gpu = " +
                      str(end_time - start_time))

                # forward pass
                start_time = time.perf_counter()
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                _sync()
                end_time = time.perf_counter()
                print("Pred size = " + str(pred_tmp.size()))
                print("Time to infer = " + str(end_time - start_time))
                scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)
                # Profile a single scale only.
                break

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        # visualization
        visualize_result((batch_data['img_ori'], batch_data['info']), pred,
                         cfg)

        pbar.update(1)
        # Profile a single batch only.
        break
    pbar.close()
Пример #9
0
def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue):
    """Sliding-window evaluation worker that reports metrics through a queue.

    For every sample and every resized input, ``predict_sliding`` produces a
    score map; it is bilinearly resized to the label size, divided by
    ``len(cfg.DATASET.imgSizes)`` and accumulated. The class arg-max is
    scored with ``accuracy`` / ``intersectionAndUnion``, and
    ``(acc, pix, intersection, union)`` is pushed to ``result_queue``.
    Predictions are visualized when ``cfg.VAL.visualize`` is set.

    Args:
        segmentation_module: model passed to ``predict_sliding``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'seg_label', 'img_data', 'img_ori' and 'info'.
        cfg: config providing DATASET.num_class, DATASET.imgSizes,
            VAL.visualize and DIR.
        gpu_id: device argument forwarded to ``async_copy_to``.
        result_queue: queue receiving one metrics tuple per sample.
    """
    segmentation_module.eval()

    for wrapped in loader:
        sample = wrapped[0]
        label = as_numpy(sample['seg_label'][0])
        scaled_inputs = sample['img_data']

        with torch.no_grad():
            out_size = (label.shape[0], label.shape[1])
            total = async_copy_to(
                torch.zeros(1, cfg.DATASET.num_class, out_size[0],
                            out_size[1]),
                gpu_id)

            for scaled in scaled_inputs:
                inputs = sample.copy()
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, gpu_id)

                # Tiled prediction (presumably 520x520 crops with 1/3
                # overlap), then resized to the label resolution.
                tiled = predict_sliding(segmentation_module, inputs,
                                        (520, 520), cfg.DATASET.num_class,
                                        overlap=1.0 / 3.0)
                resized = nn.functional.interpolate(tiled, size=out_size,
                                                    mode='bilinear',
                                                    align_corners=False)
                total = total + resized / len(cfg.DATASET.imgSizes)

            prediction = as_numpy(
                torch.max(total, dim=1)[1].squeeze(0).cpu())

        # Metrics go to the master process through the queue.
        acc, pix = accuracy(prediction, label)
        inter, union = intersectionAndUnion(prediction, label,
                                            cfg.DATASET.num_class)
        result_queue.put_nowait((acc, pix, inter, union))

        if cfg.VAL.visualize:
            visualize_result((sample['img_ori'], label, sample['info']),
                             prediction, os.path.join(cfg.DIR, 'result'))
Пример #10
0
def segment_img(net, data, seg_size, args, valid_masks=None, cutoff=0.2):
    """Return the per-class score tensor for one image.

    The scores of every resized input in ``data['img_data']`` are moved to
    the CPU, divided by ``len(args.imgSize)`` and summed. If ``valid_masks``
    is given, all other classes are zeroed and the remaining scores are
    renormalized per pixel. Scores below ``cutoff`` are then set to 0.

    Args:
        net: model called as ``net(feed_dict, segSize=seg_size)``.
        data: dict with 'img_data' (list of inputs), 'img_ori' and 'info';
            every key except 'img_ori' and 'info' is forwarded to ``net``.
        seg_size: (H, W) output size.
        args: namespace providing ``num_class`` and ``imgSize``.
        valid_masks: optional class indices to keep.
        cutoff: scores strictly below this value are zeroed.

    Returns:
        Detached tensor of shape (num_class, H, W), with size-1 dims squeezed.
    """
    img_resized_list = data['img_data']
    pred = torch.zeros(1, args.num_class, seg_size[0], seg_size[1])
    # Only a detached result is returned, so no autograd graph is needed;
    # building one would keep every scale's activations alive.
    with torch.no_grad():
        for img in img_resized_list:
            feed_dict = data.copy()
            feed_dict['img_data'] = img
            del feed_dict['img_ori']
            del feed_dict['info']
            feed_dict = async_copy_to(feed_dict, 0)

            # forward pass
            pred_tmp = net(feed_dict, segSize=seg_size)
            pred = pred + pred_tmp.cpu() / len(args.imgSize)

    if valid_masks is not None:
        mask = torch.zeros(1, args.num_class, seg_size[0], seg_size[1])
        mask[:, valid_masks, :, :] = 1
        pred *= mask
        # Renormalize over the class axis; keepdim gives a (1, 1, H, W)
        # denominator that broadcasts explicitly over classes.
        pred = pred / (pred.sum(dim=1, keepdim=True) + 1e-6)

    # cut off
    pred[pred < cutoff] = 0
    return pred.detach().squeeze()
Пример #11
0
def callback(data):
    """ROS image callback: segment one frame and publish two label images.

    The incoming message is converted to a BGR array, transformed, run
    through ``segmentation_module`` in half precision, and arg-maxed over
    classes. Two messages are published with the input's header:

    * ``pub``: the class-index map as ``mono8``;
    * ``pub_label``: an int32 (``32SC1``) map that keeps only pixels whose
      class index is 1 and zeroes the rest.

    Args:
        data: ROS image message.
    """
    img1 = bridge.imgmsg_to_cv2(data, "bgr8")
    # NOTE(review): class indices above 255 would not fit in mono8.
    seg = np.zeros((img1.shape[0], img1.shape[1], 1), dtype=np.uint8)
    seg_size = (img1.shape[0], img1.shape[1])

    img = img1.astype(np.float32)
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    img = img_transform(torch.from_numpy(img))
    img = torch.unsqueeze(img, 0)
    # Inference only: no autograd graph is needed for the arg-max.
    with torch.no_grad():
        feed_dict = async_copy_to({"img_data": img.half()}, 0)
        pred = segmentation_module(feed_dict, segSize=seg_size)
        _, ind = torch.max(pred, dim=1)
    # Drop only the batch axis so 1-pixel-high/wide frames keep their rank.
    ind = as_numpy(ind.squeeze(0).cpu())

    seg[:, :, 0] = ind

    im = bridge.cv2_to_imgmsg(seg, "mono8")
    # seg is modified in place after `im` was built; this assumes
    # cv2_to_imgmsg copied the buffer (verify for the cv_bridge in use).
    seg[seg != 1] = 0

    im_label = bridge.cv2_to_imgmsg(np.int32(seg), "32SC1")

    im.header = data.header
    im_label.header = data.header
    pub.publish(im)
    pub_label.publish(im_label)
Пример #12
0
def segment_this_img(f):
    """Segment the image file ``f`` at every short-side size in ``args.imgSize``.

    For each target short side, the image is rescaled and padded to a
    multiple of ``args.padding_constant``, then transformed to a
    (1, 3, h, w) tensor. The CPU-side scores of all scales are divided by
    ``len(args.imgSize)`` and summed.

    Args:
        f: path of the image to read.

    Returns:
        numpy array of shape (H, W) with the per-pixel class arg-max.
    """
    img = imread(f, mode='RGB')
    # imread(mode='RGB') returns RGB; reversing the last axis yields BGR,
    # presumably the channel order the model expects -- confirm.
    img = img[:, :, ::-1]
    ori_height, ori_width, _ = img.shape
    img_resized_list = []
    for this_short_size in args.imgSize:
        scale = this_short_size / float(min(ori_height, ori_width))
        target_height = int(ori_height * scale)
        target_width = int(ori_width * scale)
        target_height = round2nearest_multiple(target_height,
                                               args.padding_constant)
        target_width = round2nearest_multiple(target_width,
                                              args.padding_constant)
        img_resized = cv2.resize(img.copy(), (target_width, target_height))
        img_resized = img_resized.astype(np.float32)
        img_resized = img_resized.transpose((2, 0, 1))  # HWC -> CHW
        img_resized = transform(torch.from_numpy(img_resized))
        img_resized = torch.unsqueeze(img_resized, 0)
        img_resized_list.append(img_resized)

    segSize = (img.shape[0], img.shape[1])
    with torch.no_grad():
        pred = torch.zeros(1, args.num_class, segSize[0], segSize[1])
        for timg in img_resized_list:
            # async_copy_to moves the CPU tensor straight to args.gpu_id; an
            # extra .cuda() hop via the default device is unnecessary.
            feed_dict = async_copy_to({'img_data': timg}, args.gpu_id)
            # forward pass
            pred_tmp = segmentation_module(feed_dict, segSize=segSize)
            pred = pred + pred_tmp.cpu() / len(args.imgSize)
        _, preds = torch.max(pred, dim=1)
        preds = as_numpy(preds.squeeze(0))
    return preds
Пример #13
0
def test(segmentation_module, loader, args):
    """Segment every image in ``loader``, store results and optionally visualize.

    Per sample, the score maps of all resized inputs are divided by
    ``len(args.imgSize)`` and accumulated on ``args.gpu_id``. The class
    arg-max is passed to ``precompute_result``, and to ``visualize_result``
    when ``args.visualize`` is set.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'img_ori', 'img_data' and 'info'.
        args: namespace providing num_class, imgSize, gpu_id and visualize.
    """
    segmentation_module.eval()

    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        img_ori = as_numpy(batch_data['img_ori'])
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            segSize = (img_ori.shape[0], img_ori.shape[1])
            pred = torch.zeros(1, args.num_class, segSize[0], segSize[1])
            # Allocate the accumulator on the same device as the inputs.
            # Variable(pred).cuda() used the *current* CUDA device, which
            # mismatches args.gpu_id unless that device is the current one.
            pred = async_copy_to(pred, args.gpu_id)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, args.gpu_id)

                # forward pass
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                pred = pred + pred_tmp / len(args.imgSize)

            _, preds = torch.max(pred.cpu(), dim=1)
            preds = as_numpy(preds.squeeze(0))

        precompute_result(batch_data['info'], preds, args)
        if args.visualize:
            visualize_result(
                (batch_data['img_ori'], batch_data['info']),
                preds, args)
Пример #14
0
def test(segmentation_module, loader, args):
    """Segment and visualize every image in ``loader``, logging each step.

    Per sample, the score map of each resized input is moved to the CPU,
    divided by ``len(args.imgSize)`` and summed. The class arg-max is
    visualized, and a timestamped ``iter`` line is printed.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'img_ori', 'img_data' and 'info'.
        args: namespace providing num_class, imgSize and gpu_id.
    """
    segmentation_module.eval()

    for step, wrapped in enumerate(loader):
        sample = wrapped[0]
        out_size = (sample['img_ori'].shape[0], sample['img_ori'].shape[1])
        scaled_inputs = sample['img_data']

        with torch.no_grad():
            total = torch.zeros(1, args.num_class, out_size[0], out_size[1])

            for scaled in scaled_inputs:
                inputs = sample.copy()
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                inputs = async_copy_to(inputs, args.gpu_id)

                score_map = segmentation_module(inputs, segSize=out_size)
                total = total + score_map.cpu() / len(args.imgSize)

            labels = as_numpy(torch.max(total, dim=1)[1].squeeze(0))

        visualize_result((sample['img_ori'], sample['info']), labels, args)

        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('[{}] iter {}'.format(stamp, step))
Пример #15
0
def test(segmentation_module, loader, args):
    """Run multi-scale inference on ``loader``, visualizing each prediction.

    The CPU copies of all per-scale score maps are divided by
    ``len(args.imgSize)`` and summed. The class arg-max is handed to
    ``visualize_result``, and a ``[timestamp] iter N`` line is printed.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'img_ori', 'img_data' and 'info'.
        args: namespace providing num_class, imgSize and gpu_id.
    """
    segmentation_module.eval()

    for index, item in enumerate(loader):
        record = item[0]
        size_hw = (record['img_ori'].shape[0], record['img_ori'].shape[1])
        resized = record['img_data']

        with torch.no_grad():
            accumulated = torch.zeros(1, args.num_class, size_hw[0],
                                      size_hw[1])

            for one_scale in resized:
                batch = record.copy()
                batch['img_data'] = one_scale
                batch.pop('img_ori')
                batch.pop('info')
                batch = async_copy_to(batch, args.gpu_id)

                out = segmentation_module(batch, segSize=size_hw)
                accumulated = accumulated + out.cpu() / len(args.imgSize)

            best_class = torch.max(accumulated, dim=1)[1]
            best_class = as_numpy(best_class.squeeze(0))

        visualize_result((record['img_ori'], record['info']), best_class,
                         args)

        now = datetime.datetime.now()
        print('[{}] iter {}'.format(now.strftime("%Y-%m-%d %H:%M:%S"), index))
Пример #16
0
def predict(model, img_load, resizeNum, is_silent, gpu=0):
    """Accumulate a model's score maps over several resized copies of an image.

    Args:
        model: called as ``model(feed_dict, segSize=(H, W))``. The original
            note states its output is (1, 150, height, width).
        img_load: dict with 'img_ori' (original numpy image) and 'img_data'
            (list of resized inputs). Every key except 'img_ori' is
            forwarded to the model.
        resizeNum: divisor applied to each per-scale score map.
        is_silent: when False, the elapsed wall-clock time is printed.
        gpu: CUDA device index for the accumulator and the inputs.

    Returns:
        Tensor (1, cfg.DATASET.num_class, H, W) on ``cuda:gpu`` holding the
        sum of the per-scale scores divided by ``resizeNum``.
    """
    started = time.time()
    original = img_load['img_ori']
    out_size = (original.shape[0], original.shape[1])
    scaled_inputs = img_load['img_data']
    with torch.no_grad():
        device = torch.device("cuda", gpu)
        scores = torch.zeros(1, cfg.DATASET.num_class, out_size[0],
                             out_size[1], device=device)

        for scaled in scaled_inputs:
            inputs = img_load.copy()
            inputs['img_data'] = scaled
            inputs.pop('img_ori')
            inputs = async_copy_to(inputs, gpu)

            score_map = model(inputs, segSize=out_size)
            scores = scores + score_map / resizeNum
    elapsed = time.time() - started
    if not is_silent:
        print('model inference time: {}s'.format(elapsed))
    return scores
Пример #17
0
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
    """Material-segmentation evaluation worker reporting through a queue.

    For each sample, every image in ``'img_resized_list'`` is run through
    the model. The ``'material'`` output is moved to the CPU, divided by
    ``len(args.imgSize)`` and summed. After the arg-max over classes,
    ``get_metrics(prediction, sample)`` is pushed to ``result_queue``.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, seg_size=(H, W))`` and returning
            a dict keyed by head name.
        loader: iterable whose items are one-element sequences wrapping a
            sample with 'seg_object' and 'img_resized_list'.
        args: namespace providing ``nr_classes`` (dict) and ``imgSize``.
        dev_id: device argument forwarded to ``async_copy_to``.
        result_queue: queue receiving one metrics object per sample.
    """
    segmentation_module.eval()
    heads = ('material',)

    for wrapped in loader:
        sample = wrapped[0]  # TODO(LYC):: support batch size > 1
        sample_np = as_numpy(sample)
        out_size = sample_np['seg_object'].shape[0:2]

        with torch.no_grad():
            accumulated = {
                head: torch.zeros(1, args.nr_classes[head], *out_size)
                for head in heads
            }

            for img in sample['img_resized_list']:
                inputs = async_copy_to({"img": img.unsqueeze(0)}, dev_id)
                outputs = segmentation_module(inputs, seg_size=out_size)
                for head in heads:
                    accumulated[head] = (accumulated[head]
                                         + outputs[head].cpu()
                                         / len(args.imgSize))

            for head in heads:
                best = torch.max(accumulated[head].cpu(), dim=1)[1]
                accumulated[head] = best.squeeze(0)
            prediction = as_numpy(accumulated)

        # Metrics go to the master process through the queue.
        result_queue.put_nowait(get_metrics(prediction, sample_np))
Пример #18
0
def test(segmentation_module, loader, gpu, gpu_flag, args, progress):
    """Segment every image in ``loader`` on CPU or GPU, reporting progress.

    When ``gpu_flag`` is truthy, the accumulator and the inputs are moved to
    ``gpu``. A ``RuntimeError`` raised by one scale (e.g. CUDA out of memory)
    is printed and that scale is skipped, so the result may be wrong, as the
    printed message warns. After each image, ``progress.setValue`` receives
    the completed percentage and the tqdm bar advances.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: sized iterable whose items are one-element sequences
            wrapping a dict with 'img_ori', 'img_data' and 'info'.
        gpu: device argument forwarded to ``async_copy_to``.
        gpu_flag: whether to run on ``gpu``.
        args: forwarded to ``visualize_result``.
        progress: object with a ``setValue(int)`` method (0-100).
    """
    segmentation_module.eval()
    bar = tqdm(total=len(loader))
    for done, wrapped in enumerate(loader, start=1):
        sample = wrapped[0]
        out_size = (sample['img_ori'].shape[0], sample['img_ori'].shape[1])
        scaled_inputs = sample['img_data']

        with torch.no_grad():
            total = torch.zeros(1, cfg.DATASET.num_class, out_size[0],
                                out_size[1])
            if gpu_flag:
                total = async_copy_to(total, gpu)

            for scaled in scaled_inputs:
                inputs = sample.copy()
                inputs['img_data'] = scaled
                inputs.pop('img_ori')
                inputs.pop('info')
                if gpu_flag:
                    inputs = async_copy_to(inputs, gpu)

                # Best effort: a failing scale is reported and skipped.
                try:
                    score_map = segmentation_module(inputs, segSize=out_size)
                    total = total + score_map / len(cfg.DATASET.imgSizes)
                except RuntimeError as err:
                    print(
                        '出现运行错误,假如出现CUDA OUT OF MEMORY则为爆显存,会输出错误分割结果,请尝试用CPU处理该图片。错误信息:',
                        err)

            winners = torch.max(total, dim=1)[1].squeeze(0)
            labels = as_numpy(winners.cpu() if gpu_flag else winners)

        visualize_result((sample['img_ori'], sample['info']), labels,
                         cfg, args)
        progress.setValue(int(done / len(loader) * 100))
        bar.update(1)
Пример #19
0
def test(segmentation_module, loader, gpu):
    """Segment each image and dump features next to a per-image directory.

    For each sample, an output directory ``features/<chain>/<h_id>/<stem>``
    is created. Here ``chain`` and ``h_id`` are the 9th and 10th
    '/'-separated components of ``batch_data['info']``, and ``stem`` is the
    file name without extension. That path is stored in ``features.path``.
    After multi-scale inference, ``compute(features.feat_2048,
    features.feat_162, features.path)`` is saved as ``fts.pt`` and
    ``features.feat_2048`` as ``fts_2048.pt`` in that directory, and the
    prediction is visualized. ``features.feat_*`` are presumably filled by
    forward hooks during inference -- confirm.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'img_ori', 'img_data' and 'info' (an image path).
        gpu: device argument forwarded to ``async_copy_to``.
    """
    import os

    segmentation_module.eval()

    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        segSize = (batch_data['img_ori'].shape[0],
                   batch_data['img_ori'].shape[1])
        img_resized_list = batch_data['img_data']

        parts = batch_data['info'].split("/")
        chain = parts[8]
        h_id = parts[9]
        img_name = parts[-1].split(".")[0]
        path = os.path.join("features", chain, h_id, img_name)
        features.path = path
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(path, exist_ok=True)

        with torch.no_grad():
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0],
                                 segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())
        a = compute(features.feat_2048, features.feat_162, features.path)
        torch.save(a, features.path + "/fts.pt")
        torch.save(features.feat_2048, features.path + "/fts_2048.pt")

        # visualization
        visualize_result((batch_data['img_ori'], batch_data['info']), pred,
                         cfg)
Пример #20
0
def test(segmentation_module, loader, args):
    """Segment images, preferring the ``house_labelmap`` classes when plausible.

    Per sample, the per-scale score maps are divided by ``len(args.imgSize)``
    and summed. For each pixel, if any class in the module-level
    ``house_labelmap`` ranks among the 10 highest-scoring classes, the
    best-scoring of those classes is chosen; otherwise the overall arg-max is
    used. The result is passed to ``visualize_result``.

    Args:
        segmentation_module: model called as
            ``segmentation_module(feed_dict, segSize=(H, W))``.
        loader: iterable whose items are one-element sequences wrapping a
            dict with 'img_ori', 'img_data' and 'info'.
        args: namespace providing num_class, gpu and imgSize.
    """
    segmentation_module.eval()
    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        segSize = (batch_data['img_ori'].shape[0],
                   batch_data['img_ori'].shape[1])
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, args.num_class, segSize[0], segSize[1])
            scores = async_copy_to(scores, args.gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, args.gpu)

                # forward pass
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + pred_tmp / len(args.imgSize)

            numpy_scores = scores.cpu().numpy()
            pred = _prefer_candidate_labels(numpy_scores[0], house_labelmap)

        # visualization
        visualize_result((batch_data['img_ori'], batch_data['info']), pred,
                         args)

        pbar.update(1)
    pbar.close()


def _prefer_candidate_labels(class_scores, candidates, top_k=10):
    """Vectorized per-pixel label choice that favours ``candidates``.

    Replaces an interpreted loop over every pixel with whole-array numpy
    operations that give the same result.

    Args:
        class_scores: array of shape (C, H, W).
        candidates: sequence of class indices to prefer.
        top_k: a candidate is preferred only if it ranks among the
            ``top_k`` highest-scoring classes at that pixel.

    Returns:
        Integer array of shape (H, W).
    """
    overall_best = np.argmax(class_scores, axis=0)
    candidates = np.asarray(candidates)
    if candidates.size == 0:
        return overall_best
    # (top_k, H, W): indices of the top_k highest-scoring classes per pixel.
    top = np.argsort(class_scores, axis=0)[-top_k:]
    candidate_in_top = np.isin(top, candidates).any(axis=0)
    best_candidate = candidates[np.argmax(class_scores[candidates], axis=0)]
    return np.where(candidate_in_top, best_candidate, overall_best)
Пример #21
0
def evaluate(segmentation_module, loader, args):
    """Evaluate multi-scale predictions and print per-image and summary metrics.

    For each batch, the scores from every resized input in
    ``batch_data['img_data']`` are averaged (each divided by
    ``len(args.imgSize)``), arg-maxed into a label map, and compared against
    ``batch_data['seg_label']``.

    Args:
        segmentation_module: callable taking ``(feed_dict, segSize=...)``.
        loader: iterable whose items are sequences; element 0 is the batch dict.
        args: namespace providing ``num_class``, ``imgSize``, ``gpu_id``,
            ``visualize`` and ``precompute``.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    segmentation_module.eval()

    for i, batch_data in enumerate(loader):
        # process data
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])

        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            # Put the accumulator on the same device the inputs are copied to:
            # a bare `.cuda()` targets the *current* device, which need not be
            # args.gpu_id (`Variable` is a deprecated no-op wrapper).
            pred = torch.zeros(1, args.num_class, segSize[0], segSize[1])
            pred = async_copy_to(pred, args.gpu_id)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, args.gpu_id)

                # forward pass
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                pred = pred + pred_tmp / len(args.imgSize)

            _, preds = torch.max(pred.cpu(), dim=1)
            preds = as_numpy(preds.squeeze(0))

        # calculate accuracy
        acc, pix = accuracy(preds, seg_label)
        intersection, union = intersectionAndUnion(preds, seg_label,
                                                   args.num_class)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)
        print('[{}] iter {}, accuracy: {}'.format(
            datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i, acc))

        # visualization
        if args.visualize:
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']), preds,
                args)
        if args.precompute:
            precompute_result(batch_data['info'], preds, args)

    # small epsilon avoids division by zero for classes absent everywhere
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(i, _iou))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        iou.mean(),
        acc_meter.average() * 100))
Пример #22
0
def get_seg_model_loss(segmentation_module, gpu, content_img):
    """Build a ``SegLoss`` whose target is the module's segmentation of ``content_img``.

    Args:
        segmentation_module: callable taking ``(feed_dict, segSize=...)``.
        gpu: device passed to ``async_copy_to``.
        content_img: image tensor fed (as a clone) under ``feed_dict['img_data']``.

    Returns:
        ``SegLoss`` constructed from the segmentation output for ``content_img``.
    """
    # Read the size straight from the tensor; copying the image to host (twice)
    # only to inspect its shape is a needless device-to-host transfer.
    # NOTE(review): after squeeze(0) these are the first two remaining dims,
    # i.e. (C, H) if content_img is (1, C, H, W) — confirm this is intended.
    # Any code segmenting images compared against this target must use the same size.
    segSize = tuple(content_img.squeeze(0).shape[:2])
    feed_dict = {'img_data': content_img.clone()}
    feed_dict = async_copy_to(feed_dict, gpu)
    target_seg = segmentation_module(feed_dict, segSize=segSize)
    seg_loss = SegLoss(target_seg)
    return seg_loss
Пример #23
0
        def closure():
            """Optimizer closure: rebuild the weighted loss for ``input_img`` and backpropagate.

            Clamps ``input_img`` in place to [0, 1], zeroes gradients, runs
            ``model(input_img)`` and sums the ``.loss`` of every entry in
            ``style_losses`` / ``content_losses`` (scaled by ``style_weight`` /
            ``content_weight``). When ``seg_weight`` is non-zero, the term
            ``seg_loss.forward(...) * seg_weight`` is added. The components are
            appended to ``loss_vs_run``. Every 50 calls they are printed and a
            snapshot of ``input_img`` is saved under ``img_savepath``.

            Returns:
                The total loss (style + content + segmentation).
            """
            # correct the values of updated input image
            input_img.data.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)
            style_score = 0
            content_score = 0
            seg_score = 0

            for sl in style_losses:
                style_score += sl.loss
            for cl in content_losses:
                content_score += cl.loss
            style_score *= style_weight
            content_score *= content_weight

            if seg_weight != 0:
                # get seg score. The size is read from the tensor directly;
                # copying the image to host twice per step only to read its
                # shape forced needless device-to-host transfers.
                # NOTE(review): after squeeze(0) these are the first two
                # remaining dims, i.e. (C, H) for a (1, C, H, W) image — confirm;
                # presumably it must match the size of seg_loss's target.
                segSize = tuple(input_img.squeeze(0).shape[:2])
                feed_dict = {'img_data': input_img}
                feed_dict = async_copy_to(feed_dict, gpu)
                input_seg = segmentation_module(feed_dict, segSize=segSize)
                seg_score = seg_loss.forward(input_seg)
                seg_score *= seg_weight

            # seg_score stays 0 when seg_weight == 0, so one expression covers both cases
            loss = style_score + content_score + seg_score

            loss.backward(retain_graph=True)

            loss_vs_run['style'].append(style_score.item())
            loss_vs_run['content'].append(content_score.item())
            if seg_weight != 0:
                loss_vs_run['segmentation'].append(seg_score.item())

            if run[0] % 50 == 0:
                print("run {}:".format(run))
                if seg_weight != 0:
                    print(
                        'Style Loss : {:4f} Content Loss: {:4f} Segmentation Loss: {:4f}'
                        .format(style_score.item(), content_score.item(),
                                seg_score.item()))
                else:
                    print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                        style_score.item(), content_score.item()))

                print()
                plt.clf()
                imshow(input_img, title='Output Image')
                plt.savefig(img_savepath +
                            'transferred/%d.png' % int(run[0] / 10))

            run[0] += 1

            return loss
Пример #24
0
def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue):
    """Score each validation image together with its reference images and queue the metrics.

    Per batch, every ``(img, img_refs)`` pair drawn in lockstep from
    ``batch_data['img_data']`` and ``batch_data['img_refs']`` is run through
    ``segmentation_module``. The outputs, each divided by
    ``len(cfg.DATASET.imgSizes)``, are summed and arg-maxed into a label map.
    ``(acc, pix, intersection, union)`` is then sent to the master process via
    ``result_queue``.

    Args:
        segmentation_module: callable taking ``(feed_dict, segSize=...)``.
        loader: iterable whose items are sequences; element 0 is the batch dict.
        cfg: config providing ``DATASET.num_class``, ``DATASET.imgSizes``,
            ``VAL.visualize`` and ``DIR``.
        gpu_id: device passed to ``async_copy_to``.
        result_queue: queue receiving one metrics tuple per batch.
    """
    segmentation_module.eval()

    for sample in loader:
        sample = sample[0]
        label = as_numpy(sample['seg_label'][0])
        scales = sample['img_data']
        ref_scales = sample['img_refs']

        with torch.no_grad():
            out_size = (label.shape[0], label.shape[1])
            acc_scores = async_copy_to(
                torch.zeros(1, cfg.DATASET.num_class, out_size[0], out_size[1]),
                gpu_id)

            for img, img_refs in zip(scales, ref_scales):
                inputs = sample.copy()
                inputs['img_data'] = img
                inputs['img_refs'] = img_refs
                # host-only entries are not sent to the model
                del inputs['img_ori'], inputs['info']
                inputs = async_copy_to(inputs, gpu_id)

                step_scores = segmentation_module(inputs, segSize=out_size)
                acc_scores = acc_scores + step_scores / len(cfg.DATASET.imgSizes)

            label_map = as_numpy(torch.max(acc_scores, dim=1)[1].squeeze(0).cpu())

        # per-image metrics go to the master process for aggregation
        acc, pix = accuracy(label_map, label)
        intersection, union = intersectionAndUnion(label_map, label,
                                                   cfg.DATASET.num_class)
        result_queue.put_nowait((acc, pix, intersection, union))

        if cfg.VAL.visualize:
            visualize_result(
                (sample['img_ori'], label, sample['info']), label_map,
                os.path.join(cfg.DIR, 'result'))
Пример #25
0
def test(segmentation_module, loader, gpu):
    """Export ``segmentation_module`` to ONNX, traced with the first available test image.

    The first resized image found in the loader (``batch_data['img_data']``) is
    copied to ``gpu`` and used as the example input for ``torch.onnx.export``.
    The export writes ``upernet.9.full.onnx`` (input ``input_1``, output
    ``output_1``, opset 9), and the function returns after that single export.

    Args:
        segmentation_module: model to export.
        loader: iterable whose items are sequences; element 0 is the batch dict.
        gpu: device passed to ``async_copy_to``.
    """
    segmentation_module.eval()

    for batch_data in loader:
        print('test')
        # process data
        batch_data = batch_data[0]
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            for img in img_resized_list:
                feed_dict = async_copy_to(img, gpu)
                torch.onnx.export(segmentation_module, feed_dict,
                                  'upernet.9.full.onnx',
                                  input_names=('input_1',),
                                  output_names=('output_1',),
                                  opset_version=9)
                # One export is enough; a bare `break` would leave only this
                # inner loop and re-export (overwriting the file) per batch.
                return
Пример #26
0
def evaluate(segmentation_module, loader_val, args):
    """Evaluate on ``loader_val``; print per-class IoU, accuracy and mean inference time.

    The module is called as ``segmentation_module(feed_dict, epoch=0, segSize=...)``
    and must return ``(scores, edge, att, loss)``. ``scores`` is arg-maxed over
    dim 1 and compared against ``batch_data["mask"][0]``. Every sample is visualized.

    Args:
        segmentation_module: model returning ``(scores, edge, att, loss)``.
        loader_val: iterable whose items are sequences; element 0 is the sample
            dict with ``"image"``, ``"mask"`` and ``"name"`` entries.
        args: namespace providing ``num_class`` (also passed to ``visualize_result``).
    """
    import time

    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()
    pbar = tqdm(total=len(loader_val))
    for batch_data in loader_val:
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data["mask"][0])
        # Drain queued GPU work so the timed interval starts from an idle device.
        torch.cuda.synchronize()
        tic = time.perf_counter()
        batch_data["image"] = batch_data["image"].unsqueeze(0).cuda()

        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            feed_dict = batch_data.copy()

            # forward pass (single scale: the output is used directly)
            scores, edge, att, loss = segmentation_module(feed_dict, epoch=0, segSize=segSize)
            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())
        torch.cuda.synchronize()
        # The summary formats time_meter.average() with ":.4f", so it must be fed.
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label, args.num_class)
        intersection_meter.update(intersection)
        union_meter.update(union)
        acc_meter.update(acc)

        # visualization (unconditional; args.visualize is not consulted)
        visualize_result(
            (batch_data['image'], seg_label, batch_data["name"]),
            pred, edge, att, args)

        pbar.update(1)

    # summary; small epsilon avoids division by zero for absent classes
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'
          .format(iou.mean(), acc_meter.average()*100, time_meter.average()))
Пример #27
0
def test(segmentation_module, loader, gpu):
    """Run multi-scale inference, write each label map as PNG, and time the forward passes.

    For every batch, scores from each resized image in ``batch_data['img_data']``
    are averaged (each divided by ``len(cfg.DATASET.imgSizes)``), arg-maxed,
    cast to ``uint8`` and saved under ``cfg.TEST.result`` as
    ``<image stem>.png``.

    Args:
        segmentation_module: callable taking ``(feed_dict, segSize=...)``.
        loader: iterable whose items are sequences; element 0 is the batch dict.
        gpu: CUDA device passed to ``async_copy_to``.

    Returns:
        Total seconds spent inside ``segmentation_module`` calls.
    """
    segmentation_module.eval()
    TotalTime = 0.0
    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        segSize = (batch_data['img_ori'].shape[0],
                   batch_data['img_ori'].shape[1])
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0],
                                 segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass. CUDA kernels run asynchronously, so synchronize on
                # both sides to time the GPU work rather than just the launch.
                torch.cuda.synchronize(gpu)
                start_time = time.perf_counter()
                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
                torch.cuda.synchronize(gpu)
                TotalTime += time.perf_counter() - start_time

                scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        # NOTE(review): label ids above 255 do not survive the uint8 cast.
        pred = np.uint8(pred)
        # Always write ".png": PIL picks the encoder from the extension, so keeping
        # an input extension such as ".jpeg" would store the label map lossily.
        img_stem = os.path.splitext(os.path.basename(batch_data['info']))[0]
        Image.fromarray(pred, mode='L').save(
            os.path.join(cfg.TEST.result, img_stem + '.png'))

        pbar.update(1)
    return TotalTime
def evaluate(segmentation_module, loader, args, gpu_id, result_queue):
    """Multi-scale evaluation worker that sends per-image metrics to the master.

    Scores from each image in ``batch_data['img_data']`` are summed after
    dividing by ``len(args.imgSize)``, arg-maxed, and compared with
    ``batch_data['seg_label']``. ``(acc, pix, intersection, union)`` is put on
    ``result_queue``.

    Args:
        segmentation_module: callable taking ``(feed_dict, segSize=...)``.
        loader: iterable whose items are sequences; element 0 is the batch dict.
        args: namespace providing ``num_class``, ``imgSize`` and ``visualize``.
        gpu_id: device passed to ``async_copy_to``.
        result_queue: queue receiving one metrics tuple per batch.
    """
    segmentation_module.eval()

    for sample in loader:
        sample = sample[0]
        label = as_numpy(sample['seg_label'][0])
        scales = sample['img_data']

        with torch.no_grad():
            out_size = (label.shape[0], label.shape[1])
            acc_scores = async_copy_to(
                torch.zeros(1, args.num_class, out_size[0], out_size[1]), gpu_id)

            for img in scales:
                inputs = sample.copy()
                inputs['img_data'] = img
                # host-only entries are not sent to the model
                del inputs['img_ori'], inputs['info']
                inputs = async_copy_to(inputs, gpu_id)

                acc_scores = acc_scores + segmentation_module(
                    inputs, segSize=out_size) / len(args.imgSize)

            label_map = as_numpy(torch.max(acc_scores, dim=1)[1].squeeze(0).cpu())

        # per-image metrics go to the master process for aggregation
        acc, pix = accuracy(label_map, label)
        intersection, union = intersectionAndUnion(label_map, label,
                                                   args.num_class)
        result_queue.put_nowait((acc, pix, intersection, union))

        if args.visualize:
            visualize_result(
                (sample['img_ori'], label, sample['info']), label_map, args)
Пример #29
0
def test(segmentation_module, loader, gpu):
    """Benchmark forward-pass latency on the first resized image of each batch.

    For every batch, the first entry of ``batch_data['img_data']`` is copied to
    ``gpu`` and passed as a bare tensor to ``segmentation_module``. The measured
    latency (ms) and the corresponding FPS are printed.

    Args:
        segmentation_module: callable taking a single input tensor.
        loader: iterable whose items are sequences; element 0 is the batch dict.
        gpu: CUDA device passed to ``async_copy_to``.
    """
    import time

    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        img_resized_list = batch_data['img_data']

        with torch.no_grad():
            for img in img_resized_list:
                feed_dict = async_copy_to(img, gpu)

                # forward pass. CUDA launches are asynchronous: synchronize on
                # both sides so the interval covers the whole forward pass.
                torch.cuda.synchronize(gpu)
                start = time.perf_counter()
                segmentation_module(feed_dict)
                torch.cuda.synchronize(gpu)
                # Float seconds; timedelta.microseconds would hold only the
                # sub-second part and can be 0.
                cost = time.perf_counter() - start
                fps = 1.0 / cost if cost > 0 else float('inf')
                print('latency {} ms, FPS {}'.format(cost * 1000.0, fps))

                # only the first scale is benchmarked
                break

        pbar.update(1)
Пример #30
0
def eval(loader_val, segmentation_module, args, crit):
    """Evaluate on ``loader_val`` and return per-class IoU (class 0 excluded) and mean loss.

    The module is called as ``segmentation_module(feed_dict, epoch=0, segSize=...)``
    and must return ``(scores, loss)``. Each sample's prediction is visualized,
    and intermediate tensors and shapes are printed. ``crit`` is accepted but not
    used in this function.

    Returns:
        ``(iou[1:], mean_loss)``.
    """
    inter_acc = AverageMeter()
    union_acc = AverageMeter()
    loss_acc = AverageMeter()

    segmentation_module.eval()
    for sample in loader_val:
        sample = sample[0]

        label = as_numpy(sample["mask"][0])
        torch.cuda.synchronize()
        sample["image"] = sample["image"].unsqueeze(0).cuda()
        print(sample["image"].shape)

        with torch.no_grad():
            size = (label.shape[0], label.shape[1])
            base = async_copy_to(
                torch.zeros(1, args.num_class, size[0], size[1]), args.gpu)
            print("the score:", base)

            logits, loss = segmentation_module(sample.copy(), epoch=0, segSize=size)
            total = base + logits
            print("the new score:", total)
            loss_acc.update(loss)

            pred = as_numpy(torch.max(total, dim=1)[1].squeeze(0).cpu())
            print("pred shape:", pred.shape)

            visualize_result(sample["image"].cpu().numpy(), label, pred, args)

        torch.cuda.synchronize()
        inter, uni = intersectionAndUnion(pred, label, args.num_class)
        inter_acc.update(inter)
        union_acc.update(uni)

    iou = inter_acc.sum / (union_acc.sum + 1e-10)
    # class 0 (presumably background) is left out of both report and result
    for cls, value in enumerate(iou[1:], start=1):
        print('class [{}], IoU: {:.4f}'.format(cls, value))
    mean_loss = loss_acc.average()
    print('loss: {:.4f}'.format(mean_loss))
    return iou[1:], mean_loss
    def segmentation_frame(self, img):
        """Segment one frame and publish the color-coded label map.

        The frame is resized so its short side targets 600 px without the long
        side exceeding 1000 px. Both sides are rounded via
        ``self.round2nearest_multiple`` to ``self.padding_constant``, and the
        frame is transformed with ``self.img_transform`` and run through
        ``self.segmentation_module``, with output sized back to the original
        frame. The arg-max labels are color-encoded and published on
        ``self.image_pub`` as a ``bgr8`` image stamped with ``self.frame_id`` and
        ``self.stamp``.

        Args:
            img: array indexed ``(height, width, channels)`` (unpacked from
                ``img.shape``); presumably an OpenCV frame — confirm.
        """
        ori_height, ori_width, _ = img.shape

        this_short_size = 600
        imgMaxSize = 1000
        scale = min(this_short_size / float(min(ori_height, ori_width)),
                    imgMaxSize / float(max(ori_height, ori_width)))
        target_height, target_width = int(ori_height * scale), int(ori_width *
                                                                   scale)

        # to avoid rounding in network
        target_height = self.round2nearest_multiple(target_height,
                                                    self.padding_constant)
        target_width = self.round2nearest_multiple(target_width,
                                                   self.padding_constant)

        # resize (cv2.resize takes (width, height))
        image = cv2.resize(img.copy(), (target_width, target_height))

        # image transform, then add a leading batch dimension
        image = self.img_transform(image).cuda()
        image = torch.unsqueeze(image, 0)

        print('image ', image.shape)

        segSize = (ori_height, ori_width)
        feed_dict = {'img_data': image}
        # Inference only: without no_grad an autograd graph would be recorded
        # for every frame.
        with torch.no_grad():
            # forward pass (single scale: the output is used directly)
            scores = self.segmentation_module(feed_dict, segSize=segSize)
            _, pred = torch.max(scores, dim=1)
        pred = as_numpy(pred.squeeze(0).cpu())
        print('pred ', pred.shape)
        pred_color = colorEncode(pred, colors).astype(np.uint8)
        pub_image = CvBridge().cv2_to_imgmsg(pred_color, "bgr8")
        pub_image.header.frame_id = self.frame_id
        pub_image.header.stamp = self.stamp
        self.image_pub.publish(pub_image)
Пример #32
0
def test(segmentation_module, loader, args):
    """Multi-scale, multi-task inference (scene / object / material / part) with visualization.

    For each sample, predictions from every image in ``data['img_data']`` are
    averaged (each divided by ``len(args.imgSize)``). Object, material and
    per-object part scores become label maps by arg-max over the class axis;
    scene scores are only squeezed. The ``as_numpy`` result goes to
    ``visualize_result``.

    Args:
        segmentation_module: callable taking ``(feed_dict, seg_size=...)``.
        loader: iterable whose items are sequences; element 0 is the sample dict.
        args: namespace providing ``nr_classes``, ``imgSize`` and ``gpu_id``.
    """
    segmentation_module.eval()

    for step, sample in enumerate(loader):
        sample = sample[0]
        seg_size = sample['img_ori'].shape[0:2]

        with torch.no_grad():
            # accumulators, created in the order object, material, part, scene
            pred_ms = {
                k: torch.zeros(1, args.nr_classes[k], *seg_size)
                for k in ['object', 'material']
            }
            pred_ms['part'] = [
                torch.zeros(1, len(broden_dataset.object_part[obj]), *seg_size)
                for obj in broden_dataset.object_with_part
            ]
            pred_ms['scene'] = torch.zeros(1, args.nr_classes['scene'])

            for img in sample['img_data']:
                out = segmentation_module(
                    async_copy_to({"img": img}, args.gpu_id), seg_size=seg_size)
                for k in ['scene', 'object', 'material']:
                    pred_ms[k] = pred_ms[k] + out[k].cpu() / len(args.imgSize)
                for j, _ in enumerate(broden_dataset.object_with_part):
                    pred_ms['part'][j] += out['part'][j].cpu() / len(args.imgSize)

            pred_ms['scene'] = pred_ms['scene'].squeeze(0)
            for k in ['object', 'material']:
                pred_ms[k] = torch.max(pred_ms[k].cpu(), dim=1)[1].squeeze(0)
            for j, _ in enumerate(broden_dataset.object_with_part):
                pred_ms['part'][j] = torch.max(
                    pred_ms['part'][j].cpu(), dim=1)[1].squeeze(0)
            pred_ms = as_numpy(pred_ms)

        visualize_result(sample, pred_ms, args)

        print('[{}] iter {}'
              .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), step))
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
    """Evaluation worker: average multi-scale scores on the CPU and queue per-image metrics.

    Each image in ``batch_data['img_data']`` is run on ``dev_id``. Its output is
    moved to the CPU, divided by ``len(args.imgSize)`` and summed; the arg-max
    label map is compared with ``batch_data['seg_label']``.
    ``(acc, pix, intersection, union)`` is put on ``result_queue``.

    Args:
        segmentation_module: callable taking ``(feed_dict, segSize=...)``.
        loader: iterable whose items are sequences; element 0 is the batch dict.
        args: namespace providing ``num_class``, ``imgSize`` and ``visualize``.
        dev_id: device passed to ``async_copy_to``.
        result_queue: queue receiving one metrics tuple per batch.
    """
    segmentation_module.eval()

    for batch in loader:
        batch = batch[0]
        label = as_numpy(batch['seg_label'][0])
        scales = batch['img_data']

        with torch.no_grad():
            out_size = (label.shape[0], label.shape[1])
            # the accumulator lives on the CPU; outputs are moved there first
            total = torch.zeros(1, args.num_class, out_size[0], out_size[1])

            for img in scales:
                inputs = batch.copy()
                inputs['img_data'] = img
                # host-only entries are not sent to the model
                del inputs['img_ori'], inputs['info']
                inputs = async_copy_to(inputs, dev_id)

                scale_out = segmentation_module(inputs, segSize=out_size)
                total = total + scale_out.cpu() / len(args.imgSize)

            labels = as_numpy(torch.max(total.cpu(), dim=1)[1].squeeze(0))

        # per-image metrics go to the master process for aggregation
        acc, pix = accuracy(labels, label)
        intersection, union = intersectionAndUnion(labels, label, args.num_class)
        result_queue.put_nowait((acc, pix, intersection, union))

        if args.visualize:
            visualize_result(
                (batch['img_ori'], label, batch['info']),
                labels, args)
Пример #34
0
def evaluate(segmentation_module, loader, args, dev_id, result_queue):
    """Multi-task evaluation worker: average multi-scale predictions and queue the metrics.

    Per sample, outputs for each image in ``data_torch['img_resized_list']``
    (given a leading batch axis) are averaged (each divided by
    ``len(args.imgSize)``). Scene scores are reduced to one arg-max index, and
    object / material / part scores to per-pixel arg-max maps. The
    ``as_numpy``-converted predictions and ground truth go to ``get_metrics``,
    whose result is put on ``result_queue``.

    Args:
        segmentation_module: callable taking ``(feed_dict, seg_size=...)``.
        loader: iterable whose items are sequences; element 0 is the sample dict.
        args: namespace providing ``nr_classes`` and ``imgSize``.
        dev_id: device passed to ``async_copy_to``.
        result_queue: queue receiving one metrics result per sample.
    """
    segmentation_module.eval()

    for sample_torch in loader:
        sample_torch = sample_torch[0]  # TODO(LYC):: support batch size > 1
        sample_np = as_numpy(sample_torch)
        seg_size = sample_np['seg_object'].shape[0:2]

        with torch.no_grad():
            # accumulators, created in the order object, material, part, scene
            pred_ms = {
                k: torch.zeros(1, args.nr_classes[k], *seg_size)
                for k in ['object', 'material']
            }
            pred_ms['part'] = [
                torch.zeros(1, len(broden_dataset.object_part[obj]), *seg_size)
                for obj in broden_dataset.object_with_part
            ]
            pred_ms['scene'] = torch.zeros(1, args.nr_classes['scene'])

            for img in sample_torch['img_resized_list']:
                out = segmentation_module(
                    async_copy_to({"img": img.unsqueeze(0)}, dev_id),
                    seg_size=seg_size)
                for k in ['scene', 'object', 'material']:
                    pred_ms[k] = pred_ms[k] + out[k].cpu() / len(args.imgSize)
                for j, _ in enumerate(broden_dataset.object_with_part):
                    pred_ms['part'][j] += out['part'][j].cpu() / len(args.imgSize)

            pred_ms['scene'] = torch.argmax(pred_ms['scene'].squeeze(0))
            for k in ['object', 'material']:
                pred_ms[k] = torch.max(pred_ms[k].cpu(), dim=1)[1].squeeze(0)
            for j, _ in enumerate(broden_dataset.object_with_part):
                pred_ms['part'][j] = torch.max(
                    pred_ms['part'][j].cpu(), dim=1)[1].squeeze(0)
            pred_ms = as_numpy(pred_ms)

        # calculate accuracy and SEND THEM TO MASTER
        result_queue.put_nowait(get_metrics(pred_ms, sample_np))