Example #1
def eval(_run, _log):
    cfg = edict(_run.config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not (_run._id is None):
        checkpoint_dir = os.path.join('experiments', str(_run._id),
                                      'checkpoints')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    # build network
    network = UNet(cfg.model)

    if not (cfg.resume_dir == 'None'):
        model_dict = torch.load(cfg.resume_dir,
                                map_location=lambda storage, loc: storage)
        network.load_state_dict(model_dict)

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        network = torch.nn.DataParallel(network)
    network.to(device)
    network.eval()

    # data loader
    data_loader = load_dataset('val', cfg.dataset)

    pixel_recall_curve = np.zeros((13))
    plane_recall_curve = np.zeros((13, 3))

    bin_mean_shift = Bin_Mean_Shift(device=device)
    k_inv_dot_xy1 = get_coordinate_map(device)
    instance_parameter_loss = InstanceParameterLoss(k_inv_dot_xy1)
    match_segmentation = MatchSegmentation()

    with torch.no_grad():
        for iter, sample in enumerate(data_loader):
            image = sample['image'].to(device)
            instance = sample['instance'].to(device)
            gt_seg = sample['gt_seg'].numpy()
            semantic = sample['semantic'].to(device)
            gt_depth = sample['depth'].to(device)
            # gt_plane_parameters = sample['plane_parameters'].to(device)
            valid_region = sample['valid_region'].to(device)
            gt_plane_num = sample['num_planes'].int()
            # gt_plane_instance_parameter = sample['plane_instance_parameter'].numpy()

            # forward pass
            logit, embedding, _, _, param = network(image)

            prob = torch.sigmoid(logit[0])

            # infer per-pixel depth using the per-pixel plane parameters
            _, _, per_pixel_depth = Q_loss(param, k_inv_dot_xy1, gt_depth)

            # fast mean shift
            segmentation, sampled_segmentation, sample_param = bin_mean_shift.test_forward(
                prob, embedding[0], param, mask_threshold=0.1)

            # since the GT plane segmentation is somewhat noisy and plane boundaries
            # in the GT are not well aligned, we use avg_pool2d to smooth the segmentation results
            b = segmentation.t().view(1, -1, 192, 256)
            pooling_b = torch.nn.functional.avg_pool2d(b, (7, 7),
                                                       stride=1,
                                                       padding=(3, 3))
            b = pooling_b.view(-1, 192 * 256).t()
            segmentation = b

            # infer instance depth
            instance_loss, instance_depth, instance_abs_distance, instance_parameter = \
                instance_parameter_loss(segmentation, sampled_segmentation, sample_param,
                                        valid_region, gt_depth, False)

            # greedily match the predicted segmentation to the ground-truth
            # segmentation using cross entropy, for better visualization
            matching = match_segmentation(segmentation, prob.view(-1, 1),
                                          instance[0], gt_plane_num)

            # return cluster results
            predict_segmentation = segmentation.cpu().numpy().argmax(axis=1)

            # reindex to match the GT segmentation, for better visualization
            matching = matching.cpu().numpy().reshape(-1)
            used = set([])
            max_index = max(matching) + 1
            for i, a in zip(range(len(matching)), matching):
                if a in used:
                    matching[i] = max_index
                    max_index += 1
                else:
                    used.add(a)
            predict_segmentation = matching[predict_segmentation]

            # mask out non planar region
            predict_segmentation[prob.cpu().numpy().reshape(-1) <= 0.1] = 20
            predict_segmentation = predict_segmentation.reshape(192, 256)

            # visualization and evaluation
            h, w = 192, 256
            image = tensor_to_image(image.cpu()[0])
            semantic = semantic.cpu().numpy().reshape(h, w)
            mask = (prob > 0.1).float().cpu().numpy().reshape(h, w)
            gt_seg = gt_seg.reshape(h, w)
            depth = instance_depth.cpu().numpy()[0, 0].reshape(h, w)
            per_pixel_depth = per_pixel_depth.cpu().numpy()[0, 0].reshape(h, w)

            # use per pixel depth for non planar region
            depth = depth * (predict_segmentation != 20) + per_pixel_depth * (
                predict_segmentation == 20)
            gt_depth = gt_depth.cpu().numpy()[0, 0].reshape(h, w)

            # evaluate plane segmentation
            pixelStatistics, planeStatistics = eval_plane_prediction(
                predict_segmentation, gt_seg, depth, gt_depth)

            pixel_recall_curve += np.array(pixelStatistics)
            plane_recall_curve += np.array(planeStatistics)

            print("pixel and plane recall of test image ", iter)
            print(pixel_recall_curve / float(iter + 1))
            print(plane_recall_curve[:, 0] / plane_recall_curve[:, 1])
            print("********")

            # for visualization, convert labels to a color image
            # set non-planar regions to zero so they are rendered in black
            gt_seg += 1
            gt_seg[gt_seg == 21] = 0
            predict_segmentation += 1
            predict_segmentation[predict_segmentation == 21] = 0

            gt_seg_image = cv2.resize(
                np.stack(
                    [colors[gt_seg, 0], colors[gt_seg, 1], colors[gt_seg, 2]],
                    axis=2), (w, h))
            pred_seg = cv2.resize(
                np.stack([
                    colors[predict_segmentation,
                           0], colors[predict_segmentation, 1],
                    colors[predict_segmentation, 2]
                ],
                         axis=2), (w, h))

            # blend image
            blend_pred = (pred_seg * 0.7 + image * 0.3).astype(np.uint8)
            blend_gt = (gt_seg_image * 0.7 + image * 0.3).astype(np.uint8)

            semantic = cv2.resize((semantic * 255).astype(np.uint8), (w, h))
            semantic = cv2.cvtColor(semantic, cv2.COLOR_GRAY2BGR)

            mask = cv2.resize((mask * 255).astype(np.uint8), (w, h))
            mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

            depth_diff = np.abs(gt_depth - depth)
            depth_diff[gt_depth == 0.] = 0

            # visualize depth map as PlaneNet
            depth_diff = np.clip(depth_diff / 5 * 255, 0, 255).astype(np.uint8)
            depth_diff = cv2.cvtColor(cv2.resize(depth_diff, (w, h)),
                                      cv2.COLOR_GRAY2BGR)

            depth = 255 - np.clip(depth / 5 * 255, 0, 255).astype(np.uint8)
            depth = cv2.cvtColor(cv2.resize(depth, (w, h)), cv2.COLOR_GRAY2BGR)

            gt_depth = 255 - np.clip(gt_depth / 5 * 255, 0, 255).astype(
                np.uint8)
            gt_depth = cv2.cvtColor(cv2.resize(gt_depth, (w, h)),
                                    cv2.COLOR_GRAY2BGR)

            image_1 = np.concatenate((image, pred_seg, gt_seg_image), axis=1)
            image_2 = np.concatenate((image, blend_pred, blend_gt), axis=1)
            image_3 = np.concatenate((image, mask, semantic), axis=1)
            image_4 = np.concatenate((depth_diff, depth, gt_depth), axis=1)
            image = np.concatenate((image_1, image_2, image_3, image_4),
                                   axis=0)

            # cv2.imshow('image', image)
            # cv2.waitKey(0)
            # cv2.imwrite("%d_segmentation.png"%iter, image)

        print("========================================")
        print("pixel and plane recall of all test image")
        print(pixel_recall_curve / len(data_loader))
        print(plane_recall_curve[:, 0] / plane_recall_curve[:, 1])
        print("****************************************")
Example #2
def predict(_run, _log):
    cfg = edict(_run.config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # build network
    network = UNet(cfg.model)

    if not (cfg.resume_dir == 'None'):
        model_dict = torch.load(cfg.resume_dir,
                                map_location=lambda storage, loc: storage)
        network.load_state_dict(model_dict)

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        network = torch.nn.DataParallel(network)
    network.to(device)
    network.eval()

    transforms = tf.Compose([
        tf.ToTensor(),
        tf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    bin_mean_shift = Bin_Mean_Shift(device=device)
    k_inv_dot_xy1 = get_coordinate_map(device)
    instance_parameter_loss = InstanceParameterLoss(k_inv_dot_xy1)

    h, w = 192, 256

    with torch.no_grad():
        image = cv2.imread(cfg.image_path)
        # the network is trained at 192x256 resolution and the intrinsic parameters are set to ScanNet's
        image = cv2.resize(image, (w, h))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = transforms(image)
        image = image.to(device).unsqueeze(0)
        # forward pass
        logit, embedding, _, _, param = network(image)

        prob = torch.sigmoid(logit[0])

        # infer per-pixel depth using per-pixel plane parameters; currently Q_loss needs a dummy gt_depth as input
        _, _, per_pixel_depth = Q_loss(param, k_inv_dot_xy1,
                                       torch.ones_like(logit))

        # fast mean shift
        segmentation, sampled_segmentation, sample_param = bin_mean_shift.test_forward(
            prob, embedding[0], param, mask_threshold=0.1)

        # since the GT plane segmentation is somewhat noisy and plane boundaries
        # in the GT are not well aligned, we use avg_pool2d to smooth the segmentation results
        b = segmentation.t().view(1, -1, h, w)
        pooling_b = torch.nn.functional.avg_pool2d(b, (7, 7),
                                                   stride=1,
                                                   padding=(3, 3))
        b = pooling_b.view(-1, h * w).t()
        segmentation = b

        # infer instance depth
        instance_loss, instance_depth, instance_abs_distance, instance_parameter = instance_parameter_loss(
            segmentation, sampled_segmentation, sample_param,
            torch.ones_like(logit), torch.ones_like(logit), False)

        # infer instance normal
        _, _, manhattan_norm, instance_norm = surface_normal_loss(
            param, torch.ones_like(image), None)

        # return cluster results
        predict_segmentation = segmentation.cpu().numpy().argmax(axis=1)

        # mask out non planar region
        predict_segmentation[prob.cpu().numpy().reshape(-1) <= 0.1] = 20
        predict_segmentation = predict_segmentation.reshape(h, w)

        # visualization and evaluation
        image = tensor_to_image(image.cpu()[0])
        mask = (prob > 0.1).float().cpu().numpy().reshape(h, w)
        depth = instance_depth.cpu().numpy()[0, 0].reshape(h, w)
        per_pixel_depth = per_pixel_depth.cpu().numpy()[0, 0].reshape(h, w)
        manhattan_normal_2d = manhattan_norm.cpu().numpy().reshape(
            h, w, 3) * np.expand_dims((predict_segmentation != 20), -1)
        instance_normal_2d = instance_norm.cpu().numpy().reshape(
            h, w, 3) * np.expand_dims((predict_segmentation != 20), -1)
        pcd = o3d.geometry.PointCloud()
        norm_colors = cm.Set3(predict_segmentation.reshape(w * h))
        pcd.points = o3d.utility.Vector3dVector(
            np.reshape(manhattan_normal_2d, (w * h, 3)))
        pcd.colors = o3d.utility.Vector3dVector(norm_colors[:, 0:3])
        o3d.io.write_point_cloud('./manhattan_sphere.ply', pcd)
        pcd.points = o3d.utility.Vector3dVector(
            np.reshape(instance_normal_2d, (w * h, 3)))
        o3d.io.write_point_cloud('./instance_sphere.ply', pcd)
        normal_plot = cv2.resize(
            ((manhattan_normal_2d + 1) * 128).astype(np.uint8), (w, h))

        # use per pixel depth for non planar region
        depth = depth * (predict_segmentation !=
                         20) + per_pixel_depth * (predict_segmentation == 20)

        # set non-planar regions to zero so they are rendered in black
        predict_segmentation += 1
        predict_segmentation[predict_segmentation == 21] = 0

        pred_seg = cv2.resize(
            np.stack([
                colors[predict_segmentation, 0],
                colors[predict_segmentation, 1], colors[predict_segmentation,
                                                        2]
            ],
                     axis=2), (w, h))

        # blend image
        blend_pred = (pred_seg * 0.7 + image * 0.3).astype(np.uint8)

        mask = cv2.resize((mask * 255).astype(np.uint8), (w, h))
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

        # visualize depth map as PlaneNet
        depth_real = cv2.cvtColor(cv2.resize(depth, (w, h)),
                                  cv2.COLOR_GRAY2BGR)
        depth = 255 - np.clip(depth / 5 * 255, 0, 255).astype(np.uint8)
        depth = cv2.cvtColor(cv2.resize(depth, (w, h)), cv2.COLOR_GRAY2BGR)

        Camera_fx = 481.2
        Camera_fy = 480.0
        Camera_cx = -319.5
        Camera_cy = 239.50

        ##
        # Camera_fx = 535.4
        # Camera_fy = 539.2
        # Camera_cx = 320.1
        # Camera_cy = 247.6

        # correct parameters
        Camera_fx = 518.8
        Camera_fy = 518.8
        Camera_cx = 320
        Camera_cy = 240

        # cabin parameters (these final values override the ones above)
        Camera_fx = 440.66
        Camera_fy = 531.72
        Camera_cx = 361.77
        Camera_cy = 215.79
        points = []
        points_instance = []
        ratio_x = 640 / 256.0
        ratio_y = 480 / 192.0
        scalingFactor = 1.0
        for v in range(h):
            for u in range(w):
                color = image[v, u]
                color_instance = pred_seg[v, u]
                Z = (depth_real[v, u] / scalingFactor)[0]
                if Z == 0: continue
                X = (u - Camera_cx / ratio_x) * Z / Camera_fx * ratio_x
                Y = (v - Camera_cy / ratio_y) * Z / Camera_fy * ratio_y
                # points.append("%f %f %f %d %d %d 0\n" % (X, Y, Z, color[2], color[1], color[0]))
                points_instance.append("%f %f %f %d %d %d 0\n" %
                                       (X, Y, Z, color_instance[2],
                                        color_instance[1], color_instance[0]))
        # write an ASCII PLY point cloud; header keywords must start at column 0,
        # so the header is built without the source-file indentation
        ply_header = ('ply\n'
                      'format ascii 1.0\n'
                      'element vertex %d\n'
                      'property float x\n'
                      'property float y\n'
                      'property float z\n'
                      'property uchar red\n'
                      'property uchar green\n'
                      'property uchar blue\n'
                      'property uchar alpha\n'
                      'end_header\n' % len(points_instance))
        with open('./pointCloud_instance.ply', 'w') as file1:
            file1.write(ply_header)
            file1.write(''.join(points_instance))

        image = np.concatenate(
            (image, pred_seg, blend_pred, mask, depth, normal_plot), axis=1)

        cv2.imshow('image', image)
        cv2.waitKey(0)
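
This example writes three point clouds (manhattan_sphere.ply, instance_sphere.ply and pointCloud_instance.ply). A minimal sketch for inspecting one of them with Open3D, which is already used above as o3d (this assumes a display is available for the viewer window):

import open3d as o3d

# load one of the point clouds written above and show it interactively
pcd = o3d.io.read_point_cloud('./pointCloud_instance.ply')
print(pcd)  # reports the number of points
o3d.visualization.draw_geometries([pcd])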
Example #3
def train(_run, _log):
    cfg = edict(_run.config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not (_run._id is None):
        checkpoint_dir = os.path.join(_run.observers[0].basedir, str(_run._id),
                                      'checkpoints')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    # build network
    network = UNet(cfg.model)

    if not (cfg.resume_dir == 'None'):
        model_dict = torch.load(cfg.resume_dir,
                                map_location=lambda storage, loc: storage)
        network.load_state_dict(model_dict)

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        network = torch.nn.DataParallel(network)
    network.to(device)

    # set up optimizers
    optimizer = get_optimizer(network.parameters(), cfg.solver)

    # data loader
    data_loader = load_dataset('train', cfg.dataset)

    # save losses per epoch
    history = {
        'losses': [],
        'losses_pull': [],
        'losses_push': [],
        'losses_binary': [],
        'losses_depth': [],
        'ioues': [],
        'rmses': []
    }

    network.train(not cfg.model.fix_bn)

    bin_mean_shift = Bin_Mean_Shift(device=device)
    k_inv_dot_xy1 = get_coordinate_map(device)
    instance_parameter_loss = InstanceParameterLoss(k_inv_dot_xy1)

    # main loop
    for epoch in range(cfg.num_epochs):
        batch_time = AverageMeter()
        losses = AverageMeter()
        losses_pull = AverageMeter()
        losses_push = AverageMeter()
        losses_binary = AverageMeter()
        losses_depth = AverageMeter()
        losses_normal = AverageMeter()
        losses_instance = AverageMeter()
        ioues = AverageMeter()
        rmses = AverageMeter()
        instance_rmses = AverageMeter()
        mean_angles = AverageMeter()

        tic = time.time()
        for iter, sample in enumerate(data_loader):
            image = sample['image'].to(device)
            instance = sample['instance'].to(device)
            semantic = sample['semantic'].to(device)
            gt_depth = sample['depth'].to(device)
            gt_seg = sample['gt_seg'].to(device)
            gt_plane_parameters = sample['plane_parameters'].to(device)
            valid_region = sample['valid_region'].to(device)
            gt_plane_instance_parameter = sample[
                'plane_instance_parameter'].to(device)

            # forward pass
            logit, embedding, _, _, param = network(image)

            segmentations, sample_segmentations, sample_params, centers, sample_probs, sample_gt_segs = \
                bin_mean_shift(logit, embedding, param, gt_seg)

            # calculate loss
            loss, loss_pull, loss_push, loss_binary, loss_depth, loss_normal, loss_parameters, loss_pw, loss_instance \
                = 0., 0., 0., 0., 0., 0., 0., 0., 0.
            batch_size = image.size(0)
            for i in range(batch_size):
                _loss, _loss_pull, _loss_push = hinge_embedding_loss(
                    embedding[i:i + 1], sample['num_planes'][i:i + 1],
                    instance[i:i + 1], device)

                _loss_binary = class_balanced_cross_entropy_loss(
                    logit[i], semantic[i])

                _loss_normal, mean_angle = surface_normal_loss(
                    param[i:i + 1], gt_plane_parameters[i:i + 1],
                    valid_region[i:i + 1])

                _loss_L1 = parameter_loss(param[i:i + 1],
                                          gt_plane_parameters[i:i + 1],
                                          valid_region[i:i + 1])
                _loss_depth, rmse, inferred_depth = Q_loss(
                    param[i:i + 1], k_inv_dot_xy1, gt_depth[i:i + 1])

                if segmentations[i] is None:
                    continue

                _instance_loss, instance_depth, instance_abs_distance, _ = \
                    instance_parameter_loss(segmentations[i], sample_segmentations[i], sample_params[i],
                                            valid_region[i:i+1], gt_depth[i:i+1])

                _loss += _loss_binary + _loss_depth + _loss_normal + _instance_loss + _loss_L1

                # planar segmentation iou
                prob = torch.sigmoid(logit[i])
                mask = (prob > 0.5).float().cpu().numpy()
                iou = eval_iou(mask, semantic[i].cpu().numpy())
                ioues.update(iou * 100)
                instance_rmses.update(instance_abs_distance.item())
                rmses.update(rmse.item())
                mean_angles.update(mean_angle.item())

                loss += _loss
                loss_pull += _loss_pull
                loss_push += _loss_push
                loss_binary += _loss_binary
                loss_depth += _loss_depth
                loss_normal += _loss_normal
                loss_instance += _instance_loss

            loss /= batch_size
            loss_pull /= batch_size
            loss_push /= batch_size
            loss_binary /= batch_size
            loss_depth /= batch_size
            loss_normal /= batch_size
            loss_instance /= batch_size

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update loss
            losses.update(loss.item())
            losses_pull.update(loss_pull.item())
            losses_push.update(loss_push.item())
            losses_binary.update(loss_binary.item())
            losses_depth.update(loss_depth.item())
            losses_normal.update(loss_normal.item())
            losses_instance.update(loss_instance.item())

            # update time
            batch_time.update(time.time() - tic)
            tic = time.time()

            if iter % cfg.print_interval == 0:
                _log.info(
                    f"[{epoch:2d}][{iter:5d}/{len(data_loader):5d}] "
                    f"Time: {batch_time.val:.2f} ({batch_time.avg:.2f}) "
                    f"Loss: {losses.val:.4f} ({losses.avg:.4f}) "
                    f"Pull: {losses_pull.val:.4f} ({losses_pull.avg:.4f}) "
                    f"Push: {losses_push.val:.4f} ({losses_push.avg:.4f}) "
                    f"INS: {losses_instance.val:.4f} ({losses_instance.avg:.4f}) "
                    f"Binary: {losses_binary.val:.4f} ({losses_binary.avg:.4f}) "
                    f"IoU: {ioues.val:.2f} ({ioues.avg:.2f}) "
                    f"LN: {losses_normal.val:.4f} ({losses_normal.avg:.4f}) "
                    f"AN: {mean_angles.val:.4f} ({mean_angles.avg:.4f}) "
                    f"Depth: {losses_depth.val:.4f} ({losses_depth.avg:.4f}) "
                    f"INSDEPTH: {instance_rmses.val:.4f} ({instance_rmses.avg:.4f}) "
                    f"RMSE: {rmses.val:.4f} ({rmses.avg:.4f}) ")

        _log.info(f"* epoch: {epoch:2d}\t"
                  f"Loss: {losses.avg:.6f}\t"
                  f"Pull: {losses_pull.avg:.6f}\t"
                  f"Push: {losses_push.avg:.6f}\t"
                  f"Binary: {losses_binary.avg:.6f}\t"
                  f"Depth: {losses_depth.avg:.6f}\t"
                  f"IoU: {ioues.avg:.2f}\t"
                  f"RMSE: {rmses.avg:.4f}\t")

        # save history
        history['losses'].append(losses.avg)
        history['losses_pull'].append(losses_pull.avg)
        history['losses_push'].append(losses_push.avg)
        history['losses_binary'].append(losses_binary.avg)
        history['losses_depth'].append(losses_depth.avg)
        history['ioues'].append(ioues.avg)
        history['rmses'].append(rmses.avg)

        # save checkpoint
        if not (_run._id is None):
            torch.save(
                network.state_dict(),
                os.path.join(checkpoint_dir, f"network_epoch_{epoch}.pt"))
            pickle.dump(
                history, open(os.path.join(checkpoint_dir, 'history.pkl'),
                              'wb'))
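
For reference, the checkpoints and the history.pkl written above can be loaded back roughly as follows (a minimal sketch; checkpoint_dir and the epoch number are placeholders):

import os
import pickle
import torch

def load_checkpoint(checkpoint_dir, epoch, network):
    # restore the weights saved as network_epoch_<epoch>.pt by train() above;
    # if the model was wrapped in DataParallel when saved, the keys carry a
    # 'module.' prefix and may need to be stripped first
    state = torch.load(os.path.join(checkpoint_dir, f"network_epoch_{epoch}.pt"),
                       map_location="cpu")
    network.load_state_dict(state)
    with open(os.path.join(checkpoint_dir, "history.pkl"), "rb") as f:
        history = pickle.load(f)  # dict of per-epoch averaged losses/metrics
    return network, history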
Example #4
def predict(_run, _log):
    cfg = edict(_run.config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # build network
    network = UNet(cfg.model)

    if not (cfg.resume_dir == 'None'):
        model_dict = torch.load(cfg.resume_dir,
                                map_location=lambda storage, loc: storage)
        network.load_state_dict(model_dict)

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        network = torch.nn.DataParallel(network)
    network.to(device)
    network.eval()

    transforms = tf.Compose([
        tf.ToTensor(),
        tf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    bin_mean_shift = Bin_Mean_Shift(device=device)
    k_inv_dot_xy1 = get_coordinate_map(device)
    instance_parameter_loss = InstanceParameterLoss(k_inv_dot_xy1)

    h, w = 192, 256

    with torch.no_grad():
        image = cv2.imread(cfg.image_path)
        raw_h, raw_w, _ = image.shape
        # the network is trained at 192x256 resolution and the intrinsic parameters are set to ScanNet's
        image = cv2.resize(image, (w, h))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = transforms(image)
        image = image.to(device).unsqueeze(0)
        # forward pass
        logit, embedding, _, _, param = network(image)

        prob = torch.sigmoid(logit[0])

        # infer per-pixel depth using per-pixel plane parameters; currently Q_loss needs a dummy gt_depth as input
        _, _, per_pixel_depth = Q_loss(param, k_inv_dot_xy1,
                                       torch.ones_like(logit))

        # fast mean shift
        segmentation, sampled_segmentation, sample_param = bin_mean_shift.test_forward(
            prob, embedding[0], param, mask_threshold=0.1)

        # since the GT plane segmentation is somewhat noisy and plane boundaries
        # in the GT are not well aligned, we use avg_pool2d to smooth the segmentation results
        b = segmentation.t().view(1, -1, h, w)
        pooling_b = torch.nn.functional.avg_pool2d(b, (7, 7),
                                                   stride=1,
                                                   padding=(3, 3))
        b = pooling_b.view(-1, h * w).t()
        segmentation = b

        # infer instance depth
        instance_loss, instance_depth, instance_abs_distance, instance_parameter = instance_parameter_loss(
            segmentation, sampled_segmentation, sample_param,
            torch.ones_like(logit), torch.ones_like(logit), False)

        # return cluster results
        predict_segmentation = segmentation.cpu().numpy().argmax(axis=1)

        # mask out non planar region
        predict_segmentation[prob.cpu().numpy().reshape(-1) <= 0.1] = 20
        predict_segmentation = predict_segmentation.reshape(h, w)

        # visualization and evaluation
        image = tensor_to_image(image.cpu()[0])
        mask = (prob > 0.1).float().cpu().numpy().reshape(h, w)
        depth = instance_depth.cpu().numpy()[0, 0].reshape(h, w)
        per_pixel_depth = per_pixel_depth.cpu().numpy()[0, 0].reshape(h, w)

        # use per pixel depth for non planar region
        depth_clean = depth * (predict_segmentation !=
                               20) + 0 * (predict_segmentation == 20)
        depth = depth * (predict_segmentation !=
                         20) + per_pixel_depth * (predict_segmentation == 20)
        Image.fromarray(depth * 1000).convert('I').resize(
            (raw_w, raw_h)).save(cfg.image_path.replace('.jpg', '_depth.png'))
        Image.fromarray(depth_clean * 1000).convert('I').resize(
            (raw_w,
             raw_h)).save(cfg.image_path.replace('.jpg', '_depth_clean.png'))
        # np.save("510.npy", Image.fromarray(depth*1000).convert('I'))

        # set non-planar regions to zero so they are rendered in black
        predict_segmentation += 1
        predict_segmentation[predict_segmentation == 21] = 0

        pred_seg = cv2.resize(
            np.stack([
                colors[predict_segmentation, 0],
                colors[predict_segmentation, 1], colors[predict_segmentation,
                                                        2]
            ],
                     axis=2), (w, h))

        # blend image
        blend_pred = (pred_seg * 0.7 + image * 0.3).astype(np.uint8)

        mask = cv2.resize((mask * 255).astype(np.uint8), (w, h))
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

        # visualize depth map as PlaneNet
        depth = 255 - np.clip(depth / 5 * 255, 0, 255).astype(np.uint8)
        depth = cv2.cvtColor(cv2.resize(depth, (w, h)), cv2.COLOR_GRAY2BGR)

        image = np.concatenate((image, pred_seg, blend_pred, mask, depth),
                               axis=1)

        cv2.imwrite(cfg.image_path.replace('.jpg', '_out.jpg'), image)
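
The two depth maps saved above are encoded by scaling metres to millimetres (depth * 1000) and converting to PIL integer mode 'I'. A small sketch of reading one back into metres, assuming the same naming convention (the path below is only an example):

import numpy as np
from PIL import Image

def load_saved_depth(depth_png_path):
    # invert the depth * 1000 / mode 'I' encoding used above
    depth_mm = np.asarray(Image.open(depth_png_path), dtype=np.float32)
    return depth_mm / 1000.0  # metres

# depth = load_saved_depth('example_depth.png')  # hypothetical output file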
Example #5
def eval(_run, _log):
    cfg = edict(_run.config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_dir = os.path.join('experiments', str(_run._id), 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # build network
    network = UNet(cfg.model)

    if not cfg.resume_dir == 'None':
        model_dict = torch.load(cfg.resume_dir)
        network.load_state_dict(model_dict)

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        network = torch.nn.DataParallel(network)
    network.to(device)
    network.eval()

    # `transforms` and `k_inv_dot_xy1` are used below but never defined in this
    # snippet; they are assumed to be set up as in the other examples
    transforms = tf.Compose([
        tf.ToTensor(),
        tf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    k_inv_dot_xy1 = get_coordinate_map(device)

    bin_mean_shift = Bin_Mean_Shift()
    instance_parameter_loss = InstanceParameterLoss()

    h, w = 192, 256

    with torch.no_grad():
        for i in range(1):
            image = cv2.imread(cfg.input_image)
            # the network is trained at 192x256 resolution and the intrinsic parameters are set to ScanNet's
            image = cv2.resize(image, (w, h))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)
            #
            image = transforms(image)
            image = image.cuda().unsqueeze(0)
            # forward pass
            logit, embedding, _, _, param = network(image)

            prob = torch.sigmoid(logit[0])

            # infer per-pixel depth using per-pixel plane parameters; currently Q_loss needs a dummy gt_depth as input
            _, _, per_pixel_depth = Q_loss(param, k_inv_dot_xy1,
                                           torch.ones_like(logit))

            # fast mean shift
            segmentation, sampled_segmentation, sample_param = bin_mean_shift.test_forward(
                prob, embedding[0], param, mask_threshold=0.1)

            # since the GT plane segmentation is somewhat noisy and plane boundaries
            # in the GT are not well aligned, we use avg_pool2d to smooth the segmentation results
            b = segmentation.t().view(1, -1, h, w)
            pooling_b = torch.nn.functional.avg_pool2d(b, (7, 7),
                                                       stride=1,
                                                       padding=(3, 3))
            b = pooling_b.view(-1, h * w).t()
            segmentation = b

            # infer instance depth
            instance_loss, instance_depth, instance_abs_distance, instance_parameter = \
                instance_parameter_loss(segmentation, sampled_segmentation, sample_param,
                                        torch.ones_like(logit), torch.ones_like(logit), False)

            # return cluster results
            predict_segmentation = segmentation.cpu().numpy().argmax(axis=1)

            # mask out non planar region
            predict_segmentation[prob.cpu().numpy().reshape(-1) <= 0.1] = 20
            predict_segmentation = predict_segmentation.reshape(h, w)

            # visualization and evaluation
            image = tensor_to_image(image.cpu()[0])
            mask = (prob > 0.1).float().cpu().numpy().reshape(h, w)
            depth = instance_depth.cpu().numpy()[0, 0].reshape(h, w)
            per_pixel_depth = per_pixel_depth.cpu().numpy()[0, 0].reshape(h, w)

            # use per pixel depth for non planar region
            depth = depth * (predict_segmentation != 20) + per_pixel_depth * (
                predict_segmentation == 20)

            # set non-planar regions to zero so they are rendered in black
            predict_segmentation += 1
            predict_segmentation[predict_segmentation == 21] = 0

            pred_seg = cv2.resize(
                np.stack([
                    colors[predict_segmentation,
                           0], colors[predict_segmentation, 1],
                    colors[predict_segmentation, 2]
                ],
                         axis=2), (w, h))

            # blend image
            blend_pred = (pred_seg * 0.7 + image * 0.3).astype(np.uint8)

            mask = cv2.resize((mask * 255).astype(np.uint8), (w, h))
            mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

            # visualize depth map as PlaneNet
            depth = 255 - np.clip(depth / 5 * 255, 0, 255).astype(np.uint8)
            depth = cv2.cvtColor(cv2.resize(depth, (w, h)), cv2.COLOR_GRAY2BGR)

            image = np.concatenate((image, pred_seg, blend_pred, mask, depth),
                                   axis=1)

            cv2.imshow('image', image)
            cv2.waitKey(0)
Example #6
def eval(_run, _log):
    cfg = edict(_run.config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_dir = os.path.join('experiments', str(_run._id), 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # build network
    network = UNet(cfg.model)

    if not cfg.resume_dir == 'None':
        model_dict = torch.load(cfg.resume_dir)
        network.load_state_dict(model_dict)

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        network = torch.nn.DataParallel(network)
    network.to(device)
    network.eval()

    # `transforms` and `k_inv_dot_xy1` are used below but never defined in this
    # snippet; they are assumed to be set up as in the other examples
    transforms = tf.Compose([
        tf.ToTensor(),
        tf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    k_inv_dot_xy1 = get_coordinate_map(device)

    bin_mean_shift = Bin_Mean_Shift()
    instance_parameter_loss = InstanceParameterLoss()

    h, w = 192, 256

    f = open(cfg.dataset_csv, 'r')
    lines = f.readlines()
    f.close()
    lines.pop(0) # header

    with torch.no_grad():
        #for i, image_path in enumerate(tqdm(files)):
        for line in tqdm(lines):
            splits = line.split(',')
            image_path = splits[0]

            image_name = image_path.split('/')[-1]
            image_path = os.path.join(cfg.image_path, image_name)
            #image = cv2.imread(cfg.input_image)
            image = cv2.imread(image_path)
            oh, ow, _ = image.shape
            # the network is trained at 192x256 resolution and the intrinsic parameters are set to ScanNet's
            image = cv2.resize(image, (w, h))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)
            #
            image = transforms(image)
            image = image.cuda().unsqueeze(0)
            # forward pass
            logit, embedding, _, _, param = network(image)

            prob = torch.sigmoid(logit[0])

            # infer per-pixel depth using per-pixel plane parameters; currently Q_loss needs a dummy gt_depth as input
            _, _, per_pixel_depth = Q_loss(param, k_inv_dot_xy1, torch.ones_like(logit))

            # fast mean shift
            segmentation, sampled_segmentation, sample_param = bin_mean_shift.test_forward(prob, embedding[0], param, mask_threshold=0.1)

            # since the GT plane segmentation is somewhat noisy and plane boundaries
            # in the GT are not well aligned, we use avg_pool2d to smooth the segmentation results
            b = segmentation.t().view(1, -1, h, w)
            pooling_b = torch.nn.functional.avg_pool2d(b, (7, 7), stride=1, padding=(3, 3))
            b = pooling_b.view(-1, h*w).t()
            segmentation = b

            # infer instance depth
            instance_loss, instance_depth, instance_abs_distance, instance_parameter = \
                instance_parameter_loss(segmentation, sampled_segmentation, sample_param,
                                        torch.ones_like(logit), torch.ones_like(logit), False)

            # return cluster results
            predict_segmentation = segmentation.cpu().numpy().argmax(axis=1)
            #import pdb; pdb.set_trace()

            # mask out non planar region
            predict_segmentation[prob.cpu().numpy().reshape(-1) <= 0.1] = 20
            predict_segmentation = predict_segmentation.reshape(h, w)

            # visualization and evaluation
            image = tensor_to_image(image.cpu()[0])
            mask = (prob > 0.1).float().cpu().numpy().reshape(h, w)
            depth = instance_depth.cpu().numpy()[0, 0].reshape(h, w)
            per_pixel_depth = per_pixel_depth.cpu().numpy()[0, 0].reshape(h, w)

            # use per pixel depth for non planar region
            depth = depth * (predict_segmentation != 20) + per_pixel_depth * (predict_segmentation == 20)

            # set non-planar regions to zero so they are rendered in black
            predict_segmentation += 1
            predict_segmentation[predict_segmentation==21] = 0

            num_ins = predict_segmentation.max()
            seg_path = os.path.join(cfg.output_path, 'seg')
            bboxes = []
            for j in range(1, num_ins + 1):
                m = (predict_segmentation == j).astype(float)  # np.float is removed in modern NumPy
                m = m * 255.0
                m = m.astype(np.uint8)
                m = cv2.resize(m, (ow, oh), interpolation=cv2.INTER_NEAREST)
                x, y = np.where(m >= 127)
                if len(x) == 0:
                    continue
                min_x = np.min(x)
                max_x = np.max(x)
                min_y = np.min(y)
                max_y = np.max(y)
                bbox = [min_x, min_y, max_x, max_y, 1.0]
                bboxes.append(bbox)

            #tqdm.write("{} {}".format(ow, oh))
            #tqdm.write(str(bboxes))
            file_obj = open(os.path.join(cfg.output_path, image_name.split('.')[0] + '.pkl'), 'wb') 
            pickle.dump(bboxes, file_obj)
            file_obj.close()

            pred_seg = cv2.resize(np.stack([colors[predict_segmentation, 0],
                                            colors[predict_segmentation, 1],
                                            colors[predict_segmentation, 2]], axis=2), (w, h))

            # blend image
            blend_pred = (pred_seg * 0.7 + image * 0.3).astype(np.uint8)

            mask = cv2.resize((mask * 255).astype(np.uint8), (w, h))
            mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

            # visualize depth map as PlaneNet
            depth = 255 - np.clip(depth / 5 * 255, 0, 255).astype(np.uint8)
            depth = cv2.cvtColor(cv2.resize(depth, (w, h)), cv2.COLOR_GRAY2BGR)

            image = np.concatenate((image, pred_seg, blend_pred, mask, depth), axis=1)

            mask_path = os.path.join(cfg.output_path, 'mask')
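
The loop above writes one pickle per input image containing the detected plane bounding boxes in the original image resolution. A quick sketch of reading such a file back (the file name is a placeholder):

import pickle

with open('scene0_00.pkl', 'rb') as f:  # hypothetical output file
    bboxes = pickle.load(f)

# each entry is [min_x, min_y, max_x, max_y, score] as constructed above
for bbox in bboxes:
    print(bbox)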
Example #7
def predict(_run, _log):
    cfg = edict(_run.config)

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # build network
    network = UNet(cfg.model)

    if not (cfg.resume_dir == 'None'):
        model_dict = torch.load(cfg.resume_dir,
                                map_location=lambda storage, loc: storage)
        network.load_state_dict(model_dict)

    # load nets into gpu
    if cfg.num_gpus > 1 and torch.cuda.is_available():
        network = torch.nn.DataParallel(network)
    network.to(device)
    network.eval()

    transforms = tf.Compose([
        tf.ToTensor(),
        tf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    bin_mean_shift = Bin_Mean_Shift(device=device)
    k_inv_dot_xy1 = get_coordinate_map(device)
    instance_parameter_loss = InstanceParameterLoss(k_inv_dot_xy1)

    h, w = 192, 256

    focal_length = 517.97
    offset_x = 320
    offset_y = 240

    K = [[focal_length, 0, offset_x], [0, focal_length, offset_y], [0, 0, 1]]

    K_inv = np.linalg.inv(np.array(K))

    K_inv_dot_xy_1 = np.zeros((3, h, w))

    for y in range(h):
        for x in range(w):
            yy = float(y) / h * 480
            xx = float(x) / w * 640

            ray = np.dot(K_inv, np.array([xx, yy, 1]).reshape(3, 1))
            K_inv_dot_xy_1[:, y, x] = ray[:, 0]

    with torch.no_grad():
        cam = cv2.VideoCapture(0)
        i = 0
        while i < 1:
            input('Press Enter to capture Image')
            start_time_1 = time.time()
            return_value, image = cam.read()
            i += 1
        #del(cam)
        cam.release()
        #image = cv2.imread("cam_images/opencv0.png")
        # the network is trained at 192x256 resolution and the intrinsic parameters are set to ScanNet's
        image = cv2.resize(image, (w, h))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        image = transforms(image)
        image = image.to(device).unsqueeze(0)
        # forward pass
        logit, embedding, _, _, param = network(image)

        prob = torch.sigmoid(logit[0])

        # infer per-pixel depth using per-pixel plane parameters; currently Q_loss needs a dummy gt_depth as input
        _, _, per_pixel_depth = Q_loss(param, k_inv_dot_xy1,
                                       torch.ones_like(logit))

        # fast mean shift
        segmentation, sampled_segmentation, sample_param = bin_mean_shift.test_forward(
            prob, embedding[0], param, mask_threshold=0.1)

        print("segmentation- ", segmentation.shape)
        # since the GT plane segmentation is somewhat noisy and plane boundaries
        # in the GT are not well aligned, we use avg_pool2d to smooth the segmentation results
        b = segmentation.t().view(1, -1, h, w)
        pooling_b = torch.nn.functional.avg_pool2d(b, (7, 7),
                                                   stride=1,
                                                   padding=(3, 3))
        b = pooling_b.view(-1, h * w).t()
        segmentation = b
        print("segmentation-boundary -", segmentation.shape)
        # infer instance depth
        instance_loss, instance_depth, instance_abs_distance, instance_parameter = instance_parameter_loss(
            segmentation, sampled_segmentation, sample_param,
            torch.ones_like(logit), torch.ones_like(logit), False)

        # return cluster results
        segmentation = segmentation.cpu().numpy().argmax(axis=1)
        print("segmentation-shape -", segmentation.shape)
        # mask out non planar region
        segmentation[prob.cpu().numpy().reshape(-1) <= 0.1] = 20
        segmentation = segmentation.reshape(h, w)
        print("segments-reshape - ", segmentation.shape)
        # visualization and evaluation
        #one
        image = tensor_to_image(image.cpu()[0])
        mask = (prob > 0.1).float().cpu().numpy().reshape(h, w)
        depth = instance_depth.cpu().numpy()[0, 0].reshape(h, w)
        per_pixel_depth = per_pixel_depth.cpu().numpy()[0, 0].reshape(h, w)

        # use per pixel depth for non planar region
        depth = depth * (segmentation != 20) + per_pixel_depth * (segmentation
                                                                  == 20)

        # set non-planar regions to zero so they are rendered in black
        segmentation += 1
        segmentation[segmentation == 21] = 0
        print("segmentation-zero - ", segmentation.shape)
        # two
        pred_seg = cv2.resize(
            np.stack([
                colors[segmentation, 0], colors[segmentation, 1],
                colors[segmentation, 2]
            ],
                     axis=2), (w, h))

        # blend image
        # three
        blend_pred = (pred_seg * 0.4 + image * 0.6).astype(np.uint8)

        #four
        mask = cv2.resize((mask * 255).astype(np.uint8), (w, h))
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

        # visualize depth map as PlaneNet
        #five
        depth = 255 - np.clip(depth / 5 * 255, 0, 255).astype(np.uint8)
        depth = cv2.cvtColor(cv2.resize(depth, (w, h)), cv2.COLOR_GRAY2BGR)
        print("image -", image.shape)
        print("pred_seg -", pred_seg.shape)
        print("blend_pred- ", blend_pred.shape)
        print("mask -", mask.shape)
        print("depth -", depth.shape)
        image_c = np.concatenate((image, pred_seg, blend_pred, mask, depth),
                                 axis=1)

    # `folder` and `index` are not defined in this snippet; the hard-coded paths
    # below ('predictions/0_model.ply') suggest these defaults
    folder, index = 'predictions', 0
    imageFilename = str(index) + '_model_texture.png'
    cv2.imwrite(folder + '/' + imageFilename, image_c)

    print("segmentation-", segmentation.shape)
    print("depth-", depth.shape)
    cv2.imwrite("segmentation.jpg", segmentation)
    # create face from segmentation
    faces = []
    for y in range(h - 1):
        for x in range(w - 1):
            segmentIndex = segmentation[y, x]
            # ignore non planar region
            if segmentIndex == 0:
                continue

            # add a face if three pixels share the same segmentation
            depths = [depth[y][x], depth[y + 1][x], depth[y + 1][x + 1]]
            if segmentation[y + 1, x] == segmentIndex and segmentation[
                    y + 1, x + 1] == segmentIndex and np.array(
                        depths).min() > 0 and np.array(depths).max() < 10:
                faces.append((x, y, x, y + 1, x + 1, y + 1))

            depths = [depth[y][x], depth[y][x + 1], depth[y + 1][x + 1]]
            if segmentation[y][x + 1] == segmentIndex and segmentation[y + 1][
                    x + 1] == segmentIndex and np.array(
                        depths).min() > 0 and np.array(depths).max() < 10:
                faces.append((x, y, x + 1, y + 1, x + 1, y))

    with open(folder + '/' + str(index) + '_model.ply', 'w') as f:
        header = """ply
format ascii 1.0
comment VCGLIB generated
comment TextureFile """
        header += imageFilename
        header += """
element vertex """
        header += str(h * w)
        header += """
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
element face """
        header += str(len(faces))
        header += """
property list uchar int vertex_indices
property list uchar float texcoord
end_header
"""
        f.write(header)
        for y in range(h):
            for x in range(w):
                segmentIndex = segmentation[y][x]
                # non-planar regions were remapped to 0 above; write a dummy
                # vertex (the header declares x y z r g b for every vertex)
                if segmentIndex == 0:
                    f.write("0.0 0.0 0.0 0 0 0\n")
                    continue
                ray = K_inv_dot_xy_1[:, y, x]
                X, Y, Z = ray * depth[y, x]
                R, G, B = image[y, x]
                f.write(
                    str(X) + ' ' + str(Y) + ' ' + str(Z) + ' ' + str(R) + ' ' +
                    str(G) + ' ' + str(B) + '\n')

        for face in faces:
            f.write('3 ')
            for c in range(3):
                f.write(str(face[c * 2 + 1] * w + face[c * 2]) + ' ')
            f.write('6 ')
            for c in range(3):
                f.write(
                    str(float(face[c * 2]) / w) + ' ' +
                    str(1 - float(face[c * 2 + 1]) / h) + ' ')
            f.write('\n')

    # visualization

    filename = 'predictions/0_model.ply'
    print("3D mesh generated in ",
          "--- %s seconds ---" % (time.time() - start_time_1))
    print(filename)

    mesh = pv.read(filename)
    cpos = mesh.plot()

    plotter = pv.Plotter(off_screen=True)
    plotter.add_mesh(mesh)
    plotter.show(screenshot="myscreenshot.png")
    rimage = cv2.imread("predictions/0_model_texture.png")
    cv2.imshow("image,depth,mask,segmentation -", rimage)
    cv2.waitKey(0)
    #closing all open windows
    cv2.destroyAllWindows()
    return
Example #8
class Predictor():
    def __init__(self):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.cfg = Cfg()
        # build network
        self.network = UNet(self.cfg)

        self.model_dict = torch.load('pretrained.pt',
                                     map_location=lambda storage, loc: storage)
        self.network.load_state_dict(self.model_dict)

        self.network.to(self.device)
        self.network.eval()

        self.transforms = tf.Compose([
            tf.ToTensor(),
            tf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.bin_mean_shift = Bin_Mean_Shift(device=self.device)
        self.k_inv_dot_xy1 = get_coordinate_map(self.device)
        self.instance_parameter_loss = InstanceParameterLoss(
            self.k_inv_dot_xy1)

        self.h, self.w = 192, 256

    def predict(self, image_path):
        with torch.no_grad():
            image = cv2.imread(image_path)
            # the network is trained at 192x256 resolution and the intrinsic parameters are set to ScanNet's
            image = cv2.resize(image, (self.w, self.h))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)
            image = self.transforms(image)
            image = image.to(self.device).unsqueeze(0)

            # forward pass
            logit, embedding, _, _, param = self.network(image)

            prob = torch.sigmoid(logit[0])

            # infer per-pixel depth using per-pixel plane parameters; currently Q_loss needs a dummy gt_depth as input
            _, _, per_pixel_depth = Q_loss(param, self.k_inv_dot_xy1,
                                           torch.ones_like(logit))

            # fast mean shift
            segmentation, sampled_segmentation, sample_param = self.bin_mean_shift.test_forward(
                prob, embedding[0], param, mask_threshold=0.1)

            # since the GT plane segmentation is somewhat noisy and plane boundaries
            # in the GT are not well aligned, we use avg_pool2d to smooth the segmentation results
            b = segmentation.t().view(1, -1, self.h, self.w)
            pooling_b = torch.nn.functional.avg_pool2d(b, (7, 7),
                                                       stride=1,
                                                       padding=(3, 3))
            b = pooling_b.view(-1, self.h * self.w).t()
            segmentation = b

            # infer instance depth
            instance_loss, instance_depth, instance_abs_distance, instance_parameter = self.instance_parameter_loss(
                segmentation, sampled_segmentation, sample_param,
                torch.ones_like(logit), torch.ones_like(logit), False)

            # return cluster results
            predict_segmentation = segmentation.cpu().numpy().argmax(axis=1)

            # mask out non planar region
            predict_segmentation[prob.cpu().numpy().reshape(-1) <= 0.1] = 20
            predict_segmentation = predict_segmentation.reshape(self.h, self.w)

            # visualization and evaluation
            image = tensor_to_image(image.cpu()[0])
            mask = (prob > 0.1).float().cpu().numpy().reshape(self.h, self.w)
            depth = instance_depth.cpu().numpy()[0, 0].reshape(self.h, self.w)
            per_pixel_depth = per_pixel_depth.cpu().numpy()[0, 0].reshape(
                self.h, self.w)

            # use per pixel depth for non planar region
            depth = depth * (predict_segmentation != 20) + per_pixel_depth * (
                predict_segmentation == 20)

            # set non-planar regions to zero so they are rendered in black
            predict_segmentation += 1
            predict_segmentation[predict_segmentation == 21] = 0

            return depth, predict_segmentation, mask, instance_parameter
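
A minimal usage sketch for the Predictor class above (the image path is a placeholder; __init__ expects a pretrained.pt checkpoint in the working directory):

if __name__ == '__main__':
    predictor = Predictor()
    depth, segmentation, mask, plane_params = predictor.predict('example.jpg')
    # depth, segmentation and mask are 192x256 arrays; plane_params holds the
    # per-instance plane parameters returned by the instance parameter module
    print(depth.shape, segmentation.shape, mask.shape)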